""" Speaker Extraction Service Extracts specific speaker from audio using reference clip and cosine similarity matching. Uses pyannote.audio embedding model for speaker verification. """ import logging import time from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple import numpy as np import torch try: import spaces except ImportError: # Create a no-op decorator for environments without spaces package class spaces: @staticmethod def GPU(duration=60): def decorator(func): return func return decorator # Workaround for PyTorch 2.6+ weights_only security feature # pyannote models are from trusted source (HuggingFace) # Monkey-patch torch.load to use weights_only=False for pyannote models _original_torch_load = torch.load def _patched_torch_load(*args, **kwargs): # Force weights_only=False since we trust pyannote models from HuggingFace kwargs["weights_only"] = False return _original_torch_load(*args, **kwargs) torch.load = _patched_torch_load from pyannote.audio import Pipeline from src.config.gpu_config import GPUConfig from src.lib.audio_io import get_audio_duration, read_audio, write_audio from src.lib.progress import SPEAKER_EXTRACTION_STAGES from src.models.audio_segment import AudioSegment, SegmentType from src.models.error_report import ErrorReport from src.services.audio_concatenation import AudioConcatenationUtility logger = logging.getLogger(__name__) # Module-level GPU functions to avoid pickling issues with ZeroGPU @spaces.GPU(duration=60) def _extract_embedding_on_gpu(audio_dict: Dict, hf_token: str) -> np.ndarray: """ Extract speaker embedding on GPU (or CPU if unavailable). This is a module-level function to avoid pickling issues with ZeroGPU. The model is loaded fresh within this GPU context. Args: audio_dict: Audio data dict with 'waveform' and 'sample_rate' hf_token: HuggingFace token for model access Returns: Speaker embedding vector """ from pyannote.audio import Inference, Model # Load model fresh in GPU context (avoids pickling) logger.info("Loading embedding model in GPU context...") model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", token=hf_token) # Move to available device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) logger.info(f"Embedding model loaded on {device}") # Create inference wrapper embedding_model = Inference(model, window="whole") try: embedding = embedding_model(audio_dict) # Embedding is already a numpy array from Inference if isinstance(embedding, torch.Tensor): embedding = embedding.detach().cpu().numpy() # Flatten if needed if len(embedding.shape) > 1: embedding = embedding.flatten() logger.info(f"Extracted {len(embedding)}-dimensional embedding") return embedding finally: # Clean up del embedding_model del model if torch.cuda.is_available(): torch.cuda.empty_cache() @spaces.GPU(duration=60) def _extract_embeddings_batch_on_gpu( audio_data: np.ndarray, sample_rate: int, segments: List[AudioSegment], hf_token: str, progress_callback: Optional[Callable] = None, ) -> List[Tuple[AudioSegment, np.ndarray]]: """ Extract embeddings for multiple segments on GPU. This is a module-level function to avoid pickling issues with ZeroGPU. The model is loaded fresh within this GPU context. Args: audio_data: Full audio array sample_rate: Sample rate segments: List of AudioSegment objects to process hf_token: HuggingFace token for model access progress_callback: Optional progress callback Returns: List of (AudioSegment, embedding) tuples """ from pyannote.audio import Inference, Model # Load model fresh in GPU context (avoids pickling) logger.info("Loading embedding model in GPU context...") model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", token=hf_token) # Move to available device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) logger.info(f"Embedding model loaded on {device}") # Create inference wrapper embedding_model = Inference(model, window="whole") try: segments_with_embeddings = [] for i, segment in enumerate(segments): if progress_callback: # Progress from 0.15 to 0.40 for embedding computation embed_progress = 0.15 + (0.25 * (i + 1) / len(segments)) progress_callback( SPEAKER_EXTRACTION_STAGES[1], embed_progress, 1.0 ) # "Computing embeddings" # Extract segment audio start_sample = int(segment.start_time * sample_rate) end_sample = int(segment.end_time * sample_rate) segment_audio = audio_data[start_sample:end_sample] # Skip if segment too short if len(segment_audio) < sample_rate * 0.5: # 0.5 second minimum continue # Extract embedding audio_tensor = torch.from_numpy(segment_audio).unsqueeze(0) audio_dict = {"waveform": audio_tensor, "sample_rate": sample_rate} embedding = embedding_model(audio_dict) # Embedding is already a numpy array from Inference if isinstance(embedding, torch.Tensor): embedding = embedding.detach().cpu().numpy() # Flatten if needed if len(embedding.shape) > 1: embedding = embedding.flatten() segments_with_embeddings.append((segment, embedding)) logger.info(f"Extracted embeddings from {len(segments_with_embeddings)} segments") return segments_with_embeddings finally: # Clean up del embedding_model del model if torch.cuda.is_available(): torch.cuda.empty_cache() class SpeakerExtractionService: """ Service for extracting specific speaker from audio files using reference clips. Uses speaker embeddings and cosine similarity to match segments. """ def __init__(self): """Initialize speaker extraction service""" import os # Store HF token for GPU functions to use self.hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN") if not self.hf_token: raise ValueError( "HuggingFace token required. Set HF_TOKEN or HUGGINGFACE_TOKEN environment variable." ) # Initialize audio concatenation utility self.audio_concatenator = AudioConcatenationUtility() logger.info("Speaker extraction service initialized") def extract_reference_embedding(self, reference_clip_path: str) -> np.ndarray: """ Extract speaker embedding from reference clip. Args: reference_clip_path: Path to reference audio clip Returns: Speaker embedding vector (512-dimensional) Raises: ValueError: If reference clip is too short or invalid """ # Validate reference clip duration duration = get_audio_duration(reference_clip_path) if duration < 3.0: raise ValueError( f"Reference clip is {duration:.1f}s (minimum 3.0s required for reliable matching)" ) logger.info(f"Extracting embedding from reference clip ({duration:.1f}s)") # Read audio audio_data, sample_rate = read_audio(reference_clip_path, target_sr=16000) # Check audio quality rms = np.sqrt(np.mean(audio_data**2)) if rms < 0.01: logger.warning( f"Reference clip has low amplitude (RMS={rms:.4f}). " "Consider using a cleaner sample for better results." ) # Convert to torch tensor audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # Add batch dimension # Extract embedding using Inference model audio_dict = {"waveform": audio_tensor, "sample_rate": sample_rate} # Call module-level GPU function (avoids pickling self) embedding = _extract_embedding_on_gpu(audio_dict, self.hf_token) return embedding def detect_voice_segments( self, audio_path: str, min_duration: float = 0.5 ) -> List[AudioSegment]: """ Detect voice activity segments in audio file using simple chunking. For now, we use fixed-size chunks since VAD requires additional model access. In production, this should use proper VAD. Args: audio_path: Path to audio file min_duration: Minimum segment duration in seconds Returns: List of AudioSegment objects for voice activity """ logger.info(f"Detecting voice segments in {Path(audio_path).name}...") # Simple approach: split into fixed chunks (can be improved with VAD) duration = get_audio_duration(audio_path) # Create 5-second chunks (good balance for embedding extraction) chunk_duration = 5.0 segments = [] current_time = 0.0 while current_time < duration: end_time = min(current_time + chunk_duration, duration) if end_time - current_time >= min_duration: audio_segment = AudioSegment( start_time=current_time, end_time=end_time, speaker_id="UNKNOWN", confidence=1.0, segment_type=SegmentType.SPEECH, ) segments.append(audio_segment) current_time = end_time logger.info(f"Created {len(segments)} segments ({chunk_duration}s chunks)") return segments def extract_target_embeddings( self, target_audio_path: str, progress_callback: Optional[Callable] = None ) -> List[Tuple[AudioSegment, np.ndarray]]: """ Extract embeddings from all voice segments in target audio. Args: target_audio_path: Path to target audio file progress_callback: Optional callback for progress updates (stage, current, total) Returns: List of tuples (AudioSegment, embedding) """ # Detect voice segments segments = self.detect_voice_segments(target_audio_path) if len(segments) == 0: logger.warning("No voice segments detected in target audio") return [] # Load full audio audio_data, sample_rate = read_audio(target_audio_path, target_sr=16000) # Call module-level GPU function (avoids pickling self) segments_with_embeddings = _extract_embeddings_batch_on_gpu( audio_data=audio_data, sample_rate=sample_rate, segments=segments, hf_token=self.hf_token, progress_callback=progress_callback, ) return segments_with_embeddings def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float: """ Compute cosine similarity between two embeddings. Args: embedding1: First embedding vector embedding2: Second embedding vector Returns: Cosine similarity score (-1 to 1, higher is more similar) """ # Normalize embeddings norm1 = np.linalg.norm(embedding1) norm2 = np.linalg.norm(embedding2) if norm1 == 0 or norm2 == 0: return 0.0 # Compute cosine similarity similarity = np.dot(embedding1, embedding2) / (norm1 * norm2) return float(similarity) def match_segments( self, reference_embedding: np.ndarray, segments_with_embeddings: List[Tuple[AudioSegment, np.ndarray]], threshold: float = 0.40, min_confidence: float = 0.30, ) -> List[Tuple[AudioSegment, float]]: """ Match segments against reference embedding using similarity threshold. Args: reference_embedding: Reference speaker embedding segments_with_embeddings: List of (segment, embedding) tuples threshold: Similarity threshold (lower is stricter, 0.0-1.0) min_confidence: Minimum segment confidence to include Returns: List of (segment, similarity_score) tuples for matched segments """ matched = [] for segment, embedding in segments_with_embeddings: # Filter by segment confidence if segment.confidence < min_confidence: continue # Compute similarity similarity = self.compute_similarity(reference_embedding, embedding) # Match if similarity exceeds threshold # Note: threshold is inverted - lower threshold = stricter matching # We use (1 - threshold) as the actual similarity threshold similarity_threshold = 1.0 - threshold if similarity >= similarity_threshold: matched.append((segment, similarity)) logger.info( f"Matched {len(matched)}/{len(segments_with_embeddings)} segments " f"(threshold={threshold:.2f}, min_confidence={min_confidence:.2f})" ) return matched def validate_reference_clip(self, reference_clip_path: str) -> Tuple[bool, str]: """ Validate reference clip quality and duration. Args: reference_clip_path: Path to reference clip Returns: Tuple of (is_valid, message) """ try: # Check duration duration = get_audio_duration(reference_clip_path) if duration < 3.0: return False, f"Reference clip is {duration:.1f}s (minimum 3.0s required)" # Check audio quality audio_data, sample_rate = read_audio(reference_clip_path) rms = np.sqrt(np.mean(audio_data**2)) if rms < 0.01: return ( True, f"Warning: Low audio quality (RMS={rms:.4f}). Consider using cleaner sample.", ) return True, "Reference clip is valid" except Exception as e: return False, f"Error validating reference clip: {str(e)}" def extract_and_export( self, reference_clip: str, target_file: str, output_path: str, threshold: float = 0.40, min_confidence: float = 0.30, concatenate: bool = True, silence_duration_ms: int = 150, crossfade_duration_ms: int = 75, sample_rate: int = 44100, bitrate: str = "192k", progress_callback: Optional[Callable] = None, ) -> Dict: """ Extract speaker from target file and export to audio file. Args: reference_clip: Path to reference clip of target speaker target_file: Path to target audio file output_path: Path for output file(s) threshold: Speaker matching threshold (0.0-1.0, lower is stricter) min_confidence: Minimum confidence for including segments concatenate: If True, concatenate segments; if False, export separately silence_duration_ms: Silence duration between concatenated segments crossfade_duration_ms: Crossfade duration for smooth transitions sample_rate: Output sample rate bitrate: Output bitrate progress_callback: Optional callback for progress updates Returns: Extraction report dictionary or ErrorReport on failure """ start_time = time.time() logger.info(f"Extracting speaker from {Path(target_file).name}") logger.info(f"Reference clip: {Path(reference_clip).name}") logger.info(f"Threshold: {threshold:.2f}, Min confidence: {min_confidence:.2f}") try: if progress_callback: progress_callback( SPEAKER_EXTRACTION_STAGES[0], 0.0, 1.0 ) # "Loading reference audio" # Extract reference embedding reference_embedding = self.extract_reference_embedding(reference_clip) except Exception as e: logger.error(f"Failed to extract reference embedding: {e}") error_report: ErrorReport = { "status": "failed", "error": f"Failed to extract reference embedding: {e}", "error_type": "audio_io", } return error_report try: if progress_callback: progress_callback(SPEAKER_EXTRACTION_STAGES[1], 0.15, 1.0) # "Computing embeddings" # Extract target embeddings # Note: progress_callback cannot be passed due to ZeroGPU pickling constraints segments_with_embeddings = self.extract_target_embeddings( target_file, progress_callback=None, # Cannot pass callback to avoid pickling errors ) if progress_callback: progress_callback( SPEAKER_EXTRACTION_STAGES[2], 0.4, 1.0 ) # "Matching voice segments" # Match segments matched_segments = self.match_segments( reference_embedding, segments_with_embeddings, threshold=threshold, min_confidence=min_confidence, ) except Exception as e: logger.error(f"Failed to process and match segments: {e}") error_report: ErrorReport = { "status": "failed", "error": f"Failed to process and match segments: {e}", "error_type": "processing", } return error_report if len(matched_segments) == 0: logger.warning("No matching segments found") report = self.generate_extraction_report( reference_clip=reference_clip, target_file=target_file, threshold=threshold, matched_segments=[], processing_time=time.time() - start_time, output_file=None, ) return report try: if progress_callback: progress_callback(SPEAKER_EXTRACTION_STAGES[3], 0.75, 1.0) # "Extracting segments" # Load target audio target_audio, target_sr = read_audio(target_file) except Exception as e: logger.error(f"Failed to load target audio: {e}") error_report: ErrorReport = { "status": "failed", "error": f"Failed to load target audio: {e}", "error_type": "audio_io", } return error_report try: # Export matched segments output_path_obj = Path(output_path) if concatenate: # Concatenate all matched segments segment_audio_list = [] for segment, similarity in matched_segments: start_sample = int(segment.start_time * target_sr) end_sample = int(segment.end_time * target_sr) segment_audio = target_audio[start_sample:end_sample] segment_audio_list.append(segment_audio) # Concatenate with crossfade concatenated = self.audio_concatenator.concatenate_segments( segment_audio_list, sample_rate=target_sr, silence_duration_ms=silence_duration_ms, crossfade_duration_ms=crossfade_duration_ms, ) # Resample if needed if sample_rate != target_sr: from src.lib.audio_io import resample_audio concatenated = resample_audio(concatenated, target_sr, sample_rate) output_sr = sample_rate else: output_sr = target_sr # Write output write_audio(str(output_path), concatenated, output_sr) logger.info(f"Exported concatenated audio to {output_path}") output_file = str(output_path) else: # Export segments separately output_path_obj.mkdir(parents=True, exist_ok=True) for i, (segment, similarity) in enumerate(matched_segments, start=1): start_sample = int(segment.start_time * target_sr) end_sample = int(segment.end_time * target_sr) segment_audio = target_audio[start_sample:end_sample] # Resample if needed if sample_rate != target_sr: from src.lib.audio_io import resample_audio segment_audio = resample_audio(segment_audio, target_sr, sample_rate) output_sr = sample_rate else: output_sr = target_sr segment_file = output_path_obj / f"segment_{i:03d}.m4a" write_audio(str(segment_file), segment_audio, output_sr) logger.info(f"Exported {len(matched_segments)} segments to {output_path_obj}") output_file = str(output_path_obj) if progress_callback: progress_callback("Complete", 1.0, 1.0) # Generate report report = self.generate_extraction_report( reference_clip=reference_clip, target_file=target_file, threshold=threshold, matched_segments=matched_segments, processing_time=time.time() - start_time, output_file=output_file, ) return report except Exception as e: logger.error(f"Failed to export speaker segments: {e}") error_report: ErrorReport = { "status": "failed", "error": f"Failed to export speaker segments: {e}", "error_type": "processing", } return error_report def generate_extraction_report( self, reference_clip: str, target_file: str, threshold: float, matched_segments: List[Tuple[AudioSegment, float]], processing_time: float, output_file: Optional[str], ) -> Dict: """ Generate extraction report JSON. Args: reference_clip: Reference clip path target_file: Target file path threshold: Matching threshold used matched_segments: List of matched (segment, similarity) tuples processing_time: Processing time in seconds output_file: Output file path Returns: Report dictionary """ total_duration = sum(seg.duration for seg, _ in matched_segments) avg_confidence = ( sum(similarity for _, similarity in matched_segments) / len(matched_segments) if matched_segments else 0.0 ) low_confidence = sum( 1 for _, similarity in matched_segments if similarity < (1.0 - threshold + 0.1) # Within 0.1 of threshold ) report = { "reference_clip": str(reference_clip), "target_file": str(target_file), "threshold": threshold, "segments_found": len(matched_segments), "segments_included": len(matched_segments), "total_duration_seconds": round(total_duration, 2), "average_confidence": round(avg_confidence, 3), "low_confidence_segments": low_confidence, "processing_time_seconds": round(processing_time, 1), "output_file": output_file, } return report