Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Speaker Extraction Service | |
| Extracts specific speaker from audio using reference clip and cosine similarity matching. | |
| Uses pyannote.audio embedding model for speaker verification. | |
| """ | |
| import logging | |
| import time | |
| from pathlib import Path | |
| from typing import Callable, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
try:
    import spaces
except ImportError:

    class spaces:  # noqa: N801 - name must mirror the real package
        """No-op stand-in used when the ``spaces`` package is not installed.

        Only the ``GPU`` decorator is provided; it returns the decorated
        function unchanged so the module can run without the package.
        """

        @staticmethod
        def GPU(func=None, duration=60, **_ignored):
            """Return the decorated function untouched.

            Supports both decorator spellings:

            * ``@spaces.GPU`` - ``func`` is the decorated callable and is
              returned as-is.
            * ``@spaces.GPU(duration=120)`` (or ``@spaces.GPU()``) - a
              pass-through decorator is returned.

            Args:
                func: The decorated callable when used without parentheses.
                    A non-callable positional value (e.g. ``spaces.GPU(30)``)
                    is treated like the parameterised form, which keeps the
                    previous ``GPU(duration)`` call style working.
                duration: Accepted for signature compatibility; ignored.
                **_ignored: Any other keyword options; ignored.
                    NOTE(review): assumed to cover extra options of the real
                    decorator - confirm against the installed version.

            Returns:
                ``func`` itself, or a decorator that returns its argument.
            """
            if callable(func):
                # Bare usage: there is nothing to configure, hand it back.
                return func

            def decorator(wrapped):
                return wrapped

            return decorator
# PyTorch 2.6+ enables the ``weights_only`` safety check in ``torch.load``.
# The pyannote checkpoints come from a trusted source (HuggingFace), so the
# loader is wrapped to switch that check off before pyannote is imported.
_original_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    Any ``weights_only`` value supplied by the caller is overridden.
    NOTE(review): the override applies to every ``torch.load`` call made in
    this process after the patch, not only to pyannote's.
    """
    forced_kwargs = {**kwargs, "weights_only": False}
    return _original_torch_load(*args, **forced_kwargs)


torch.load = _patched_torch_load
| from pyannote.audio import Pipeline | |
| from src.config.gpu_config import GPUConfig | |
| from src.lib.audio_io import get_audio_duration, read_audio, write_audio | |
| from src.lib.progress import SPEAKER_EXTRACTION_STAGES | |
| from src.models.audio_segment import AudioSegment, SegmentType | |
| from src.models.error_report import ErrorReport | |
| from src.services.audio_concatenation import AudioConcatenationUtility | |
| logger = logging.getLogger(__name__) | |
| # Module-level GPU functions to avoid pickling issues with ZeroGPU | |
def _extract_embedding_on_gpu(audio_dict: Dict, hf_token: str) -> np.ndarray:
    """
    Compute one speaker embedding for a whole audio clip.

    Kept at module level (rather than as a method) to avoid pickling issues
    with ZeroGPU. The embedding model is loaded on every call and released
    again before returning.

    Args:
        audio_dict: Audio data dict with 'waveform' and 'sample_rate'
        hf_token: HuggingFace token for model access

    Returns:
        One-dimensional speaker embedding vector
    """
    from pyannote.audio import Inference, Model

    # Load the model inside this call so nothing has to be pickled.
    logger.info("Loading embedding model in GPU context...")
    net = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", token=hf_token)

    # Prefer CUDA, fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)
    logger.info(f"Embedding model loaded on {device}")

    # window="whole": one embedding for the entire input.
    inference = Inference(net, window="whole")

    try:
        vector = inference(audio_dict)

        # Normalise to a NumPy array in case a tensor comes back.
        if isinstance(vector, torch.Tensor):
            vector = vector.detach().cpu().numpy()

        # Collapse any extra dimensions into a flat vector.
        if vector.ndim > 1:
            vector = vector.flatten()

        logger.info(f"Extracted {len(vector)}-dimensional embedding")
        return vector
    finally:
        # Drop the model references and release cached CUDA memory.
        del inference
        del net
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def _extract_embeddings_batch_on_gpu(
    audio_data: np.ndarray,
    sample_rate: int,
    segments: List[AudioSegment],
    hf_token: str,
    progress_callback: Optional[Callable] = None,
) -> List[Tuple[AudioSegment, np.ndarray]]:
    """
    Compute a speaker embedding for each segment of an audio array.

    Kept at module level (rather than as a method) to avoid pickling issues
    with ZeroGPU. The embedding model is loaded once per call, reused for
    every segment, and released before returning.

    Args:
        audio_data: Full audio array
        sample_rate: Sample rate
        segments: List of AudioSegment objects to process
        hf_token: HuggingFace token for model access
        progress_callback: Optional callback invoked as
            (stage, current, total) once per segment

    Returns:
        List of (AudioSegment, embedding) tuples; segments shorter than
        0.5 seconds are left out
    """
    from pyannote.audio import Inference, Model

    # Load the model inside this call so nothing has to be pickled.
    logger.info("Loading embedding model in GPU context...")
    net = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", token=hf_token)

    # Prefer CUDA, fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)
    logger.info(f"Embedding model loaded on {device}")

    # window="whole": one embedding per clip passed in.
    inference = Inference(net, window="whole")

    try:
        results = []
        total = len(segments)
        min_samples = sample_rate * 0.5  # 0.5 second minimum

        for position, segment in enumerate(segments, start=1):
            if progress_callback:
                # Embedding computation spans progress 0.15 .. 0.40.
                fraction = 0.15 + (0.25 * position / total)
                progress_callback(
                    SPEAKER_EXTRACTION_STAGES[1], fraction, 1.0
                )  # "Computing embeddings"

            # Cut this segment out of the full recording.
            first = int(segment.start_time * sample_rate)
            last = int(segment.end_time * sample_rate)
            clip = audio_data[first:last]

            # Too short to embed: skip it.
            if len(clip) < min_samples:
                continue

            waveform = torch.from_numpy(clip).unsqueeze(0)
            vector = inference({"waveform": waveform, "sample_rate": sample_rate})

            # Normalise to a NumPy array in case a tensor comes back.
            if isinstance(vector, torch.Tensor):
                vector = vector.detach().cpu().numpy()

            # Collapse any extra dimensions into a flat vector.
            if vector.ndim > 1:
                vector = vector.flatten()

            results.append((segment, vector))

        logger.info(f"Extracted embeddings from {len(results)} segments")
        return results
    finally:
        # Drop the model references and release cached CUDA memory.
        del inference
        del net
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
class SpeakerExtractionService:
    """
    Service for extracting a specific speaker from audio files using a reference clip.

    A speaker embedding is computed for the reference clip and for fixed-size
    chunks of the target audio; chunks whose cosine similarity to the
    reference is high enough are exported.
    """

    # Shortest reference clip (seconds) accepted for matching.
    MIN_REFERENCE_DURATION_S = 3.0
    # RMS level below which a reference clip is reported as low amplitude.
    LOW_RMS_THRESHOLD = 0.01
    # Sample rate (Hz) requested from read_audio before embedding extraction.
    EMBEDDING_SAMPLE_RATE = 16000

    def __init__(self):
        """
        Initialize speaker extraction service.

        Raises:
            ValueError: If neither HF_TOKEN nor HUGGINGFACE_TOKEN is set
        """
        import os

        # Stored so the module-level GPU functions can load the model.
        self.hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
        if not self.hf_token:
            raise ValueError(
                "HuggingFace token required. Set HF_TOKEN or HUGGINGFACE_TOKEN environment variable."
            )

        self.audio_concatenator = AudioConcatenationUtility()
        logger.info("Speaker extraction service initialized")

    def extract_reference_embedding(self, reference_clip_path: str) -> np.ndarray:
        """
        Extract speaker embedding from reference clip.

        Args:
            reference_clip_path: Path to reference audio clip

        Returns:
            Speaker embedding vector

        Raises:
            ValueError: If the reference clip is shorter than 3.0 seconds
        """
        duration = get_audio_duration(reference_clip_path)
        if duration < self.MIN_REFERENCE_DURATION_S:
            raise ValueError(
                f"Reference clip is {duration:.1f}s (minimum 3.0s required for reliable matching)"
            )

        logger.info(f"Extracting embedding from reference clip ({duration:.1f}s)")

        audio_data, sample_rate = read_audio(
            reference_clip_path, target_sr=self.EMBEDDING_SAMPLE_RATE
        )

        # A quiet clip is still processed; the caller is only warned.
        rms = np.sqrt(np.mean(audio_data**2))
        if rms < self.LOW_RMS_THRESHOLD:
            logger.warning(
                f"Reference clip has low amplitude (RMS={rms:.4f}). "
                "Consider using a cleaner sample for better results."
            )

        # Add a leading dimension before handing the waveform to the model.
        audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)
        audio_dict = {"waveform": audio_tensor, "sample_rate": sample_rate}

        # Module-level function so that self is never pickled.
        return _extract_embedding_on_gpu(audio_dict, self.hf_token)

    def detect_voice_segments(
        self, audio_path: str, min_duration: float = 0.5, chunk_duration: float = 5.0
    ) -> List[AudioSegment]:
        """
        Split an audio file into consecutive fixed-size segments.

        No real voice activity detection is performed: the file is cut into
        back-to-back chunks that are all labelled as speech. In production,
        this should use proper VAD.

        Args:
            audio_path: Path to audio file
            min_duration: Minimum segment duration in seconds; a shorter
                trailing chunk is dropped
            chunk_duration: Length of each chunk in seconds

        Returns:
            List of AudioSegment objects covering the file

        Raises:
            ValueError: If chunk_duration is not positive
        """
        if chunk_duration <= 0:
            # A non-positive step would never advance the loop below.
            raise ValueError("chunk_duration must be positive")

        logger.info(f"Detecting voice segments in {Path(audio_path).name}...")

        duration = get_audio_duration(audio_path)

        segments = []
        current_time = 0.0
        while current_time < duration:
            end_time = min(current_time + chunk_duration, duration)
            if end_time - current_time >= min_duration:
                segments.append(
                    AudioSegment(
                        start_time=current_time,
                        end_time=end_time,
                        speaker_id="UNKNOWN",
                        confidence=1.0,
                        segment_type=SegmentType.SPEECH,
                    )
                )
            current_time = end_time

        logger.info(f"Created {len(segments)} segments ({chunk_duration}s chunks)")
        return segments

    def extract_target_embeddings(
        self, target_audio_path: str, progress_callback: Optional[Callable] = None
    ) -> List[Tuple[AudioSegment, np.ndarray]]:
        """
        Extract embeddings from all segments of the target audio.

        Args:
            target_audio_path: Path to target audio file
            progress_callback: Optional callback for progress updates (stage, current, total)

        Returns:
            List of tuples (AudioSegment, embedding)
        """
        segments = self.detect_voice_segments(target_audio_path)
        if not segments:
            logger.warning("No voice segments detected in target audio")
            return []

        audio_data, sample_rate = read_audio(
            target_audio_path, target_sr=self.EMBEDDING_SAMPLE_RATE
        )

        # Module-level function so that self is never pickled.
        return _extract_embeddings_batch_on_gpu(
            audio_data=audio_data,
            sample_rate=sample_rate,
            segments=segments,
            hf_token=self.hf_token,
            progress_callback=progress_callback,
        )

    def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding vector
            embedding2: Second embedding vector

        Returns:
            Cosine similarity score (-1 to 1, higher is more similar);
            0.0 if either vector has zero norm
        """
        norm1 = np.linalg.norm(embedding1)
        norm2 = np.linalg.norm(embedding2)

        # A zero vector has no direction; avoid dividing by zero.
        if norm1 == 0 or norm2 == 0:
            return 0.0

        return float(np.dot(embedding1, embedding2) / (norm1 * norm2))

    def match_segments(
        self,
        reference_embedding: np.ndarray,
        segments_with_embeddings: List[Tuple[AudioSegment, np.ndarray]],
        threshold: float = 0.40,
        min_confidence: float = 0.30,
    ) -> List[Tuple[AudioSegment, float]]:
        """
        Match segments against reference embedding using similarity threshold.

        Args:
            reference_embedding: Reference speaker embedding
            segments_with_embeddings: List of (segment, embedding) tuples
            threshold: Matching threshold (lower is stricter, 0.0-1.0); a
                segment matches when its similarity is >= (1 - threshold)
            min_confidence: Minimum segment confidence to include

        Returns:
            List of (segment, similarity_score) tuples for matched segments
        """
        # The threshold is inverted: the required similarity is 1 - threshold.
        similarity_threshold = 1.0 - threshold

        matched = []
        for segment, embedding in segments_with_embeddings:
            if segment.confidence < min_confidence:
                continue

            similarity = self.compute_similarity(reference_embedding, embedding)
            if similarity >= similarity_threshold:
                matched.append((segment, similarity))

        logger.info(
            f"Matched {len(matched)}/{len(segments_with_embeddings)} segments "
            f"(threshold={threshold:.2f}, min_confidence={min_confidence:.2f})"
        )
        return matched

    def validate_reference_clip(self, reference_clip_path: str) -> Tuple[bool, str]:
        """
        Validate reference clip quality and duration.

        Never raises: any error while reading the clip is reported through
        the returned message.

        Args:
            reference_clip_path: Path to reference clip

        Returns:
            Tuple of (is_valid, message); a low-amplitude clip is still
            valid but carries a warning message
        """
        try:
            duration = get_audio_duration(reference_clip_path)
            if duration < self.MIN_REFERENCE_DURATION_S:
                return False, f"Reference clip is {duration:.1f}s (minimum 3.0s required)"

            audio_data, sample_rate = read_audio(reference_clip_path)
            rms = np.sqrt(np.mean(audio_data**2))
            if rms < self.LOW_RMS_THRESHOLD:
                return (
                    True,
                    f"Warning: Low audio quality (RMS={rms:.4f}). Consider using cleaner sample.",
                )

            return True, "Reference clip is valid"
        except Exception as e:
            # Deliberately broad: the result is shown to the user as text.
            return False, f"Error validating reference clip: {str(e)}"

    @staticmethod
    def _failure_report(message: str, error_type: str) -> ErrorReport:
        """Log *message* as an error and wrap it in a failed ErrorReport."""
        logger.error(message)
        error_report: ErrorReport = {
            "status": "failed",
            "error": message,
            "error_type": error_type,
        }
        return error_report

    @staticmethod
    def _slice_segment(audio: np.ndarray, segment: AudioSegment, sample_rate: int) -> np.ndarray:
        """Return the samples of *audio* between the segment's start and end times."""
        start_sample = int(segment.start_time * sample_rate)
        end_sample = int(segment.end_time * sample_rate)
        return audio[start_sample:end_sample]

    def extract_and_export(
        self,
        reference_clip: str,
        target_file: str,
        output_path: str,
        threshold: float = 0.40,
        min_confidence: float = 0.30,
        concatenate: bool = True,
        silence_duration_ms: int = 150,
        crossfade_duration_ms: int = 75,
        sample_rate: int = 44100,
        bitrate: str = "192k",
        progress_callback: Optional[Callable] = None,
    ) -> Dict:
        """
        Extract speaker from target file and export to audio file.

        Failures are not raised; each stage returns an ErrorReport dictionary
        with status "failed" instead.

        Args:
            reference_clip: Path to reference clip of target speaker
            target_file: Path to target audio file
            output_path: Output file when concatenate is True, otherwise the
                directory that receives segment_NNN.m4a files
            threshold: Speaker matching threshold (0.0-1.0, lower is stricter)
            min_confidence: Minimum confidence for including segments
            concatenate: If True, concatenate segments; if False, export separately
            silence_duration_ms: Silence duration between concatenated segments
            crossfade_duration_ms: Crossfade duration for smooth transitions
            sample_rate: Output sample rate
            bitrate: Output bitrate (accepted but not used by this method)
            progress_callback: Optional callback for progress updates,
                called as (stage, current, total)

        Returns:
            Extraction report dictionary or ErrorReport on failure
        """
        start_time = time.time()

        logger.info(f"Extracting speaker from {Path(target_file).name}")
        logger.info(f"Reference clip: {Path(reference_clip).name}")
        logger.info(f"Threshold: {threshold:.2f}, Min confidence: {min_confidence:.2f}")

        def notify(stage, fraction):
            # Progress is reported as a fraction of 1.0.
            if progress_callback:
                progress_callback(stage, fraction, 1.0)

        # Stage 1: reference embedding.
        try:
            notify(SPEAKER_EXTRACTION_STAGES[0], 0.0)  # "Loading reference audio"
            reference_embedding = self.extract_reference_embedding(reference_clip)
        except Exception as e:
            return self._failure_report(
                f"Failed to extract reference embedding: {e}", "audio_io"
            )

        # Stage 2: target embeddings and matching.
        try:
            notify(SPEAKER_EXTRACTION_STAGES[1], 0.15)  # "Computing embeddings"

            # The callback is not forwarded, to avoid ZeroGPU pickling errors.
            segments_with_embeddings = self.extract_target_embeddings(
                target_file,
                progress_callback=None,
            )

            notify(SPEAKER_EXTRACTION_STAGES[2], 0.4)  # "Matching voice segments"

            matched_segments = self.match_segments(
                reference_embedding,
                segments_with_embeddings,
                threshold=threshold,
                min_confidence=min_confidence,
            )
        except Exception as e:
            return self._failure_report(
                f"Failed to process and match segments: {e}", "processing"
            )

        if not matched_segments:
            # Nothing to export; this is a normal (non-error) outcome.
            logger.warning("No matching segments found")
            return self.generate_extraction_report(
                reference_clip=reference_clip,
                target_file=target_file,
                threshold=threshold,
                matched_segments=[],
                processing_time=time.time() - start_time,
                output_file=None,
            )

        # Stage 3: reload the target audio without a forced sample rate.
        try:
            notify(SPEAKER_EXTRACTION_STAGES[3], 0.75)  # "Extracting segments"
            target_audio, target_sr = read_audio(target_file)
        except Exception as e:
            return self._failure_report(f"Failed to load target audio: {e}", "audio_io")

        # Stage 4: export.
        try:
            output_path_obj = Path(output_path)

            needs_resample = sample_rate != target_sr
            if needs_resample:
                from src.lib.audio_io import resample_audio
            output_sr = sample_rate if needs_resample else target_sr

            clips = [
                self._slice_segment(target_audio, segment, target_sr)
                for segment, _ in matched_segments
            ]

            if concatenate:
                concatenated = self.audio_concatenator.concatenate_segments(
                    clips,
                    sample_rate=target_sr,
                    silence_duration_ms=silence_duration_ms,
                    crossfade_duration_ms=crossfade_duration_ms,
                )
                if needs_resample:
                    concatenated = resample_audio(concatenated, target_sr, sample_rate)

                # Make sure the destination directory exists, as the
                # per-segment branch below already does for its directory.
                output_path_obj.parent.mkdir(parents=True, exist_ok=True)
                write_audio(str(output_path), concatenated, output_sr)
                logger.info(f"Exported concatenated audio to {output_path}")
                output_file = str(output_path)
            else:
                output_path_obj.mkdir(parents=True, exist_ok=True)
                for i, clip in enumerate(clips, start=1):
                    if needs_resample:
                        clip = resample_audio(clip, target_sr, sample_rate)
                    segment_file = output_path_obj / f"segment_{i:03d}.m4a"
                    write_audio(str(segment_file), clip, output_sr)

                logger.info(f"Exported {len(matched_segments)} segments to {output_path_obj}")
                output_file = str(output_path_obj)

            notify("Complete", 1.0)

            return self.generate_extraction_report(
                reference_clip=reference_clip,
                target_file=target_file,
                threshold=threshold,
                matched_segments=matched_segments,
                processing_time=time.time() - start_time,
                output_file=output_file,
            )
        except Exception as e:
            return self._failure_report(
                f"Failed to export speaker segments: {e}", "processing"
            )

    def generate_extraction_report(
        self,
        reference_clip: str,
        target_file: str,
        threshold: float,
        matched_segments: List[Tuple[AudioSegment, float]],
        processing_time: float,
        output_file: Optional[str],
    ) -> Dict:
        """
        Build the extraction report dictionary.

        Args:
            reference_clip: Reference clip path
            target_file: Target file path
            threshold: Matching threshold used
            matched_segments: List of matched (segment, similarity) tuples
            processing_time: Processing time in seconds
            output_file: Output file path, or None when nothing was exported

        Returns:
            Report dictionary
        """
        similarities = [similarity for _, similarity in matched_segments]

        total_duration = sum(seg.duration for seg, _ in matched_segments)
        avg_confidence = sum(similarities) / len(similarities) if similarities else 0.0

        # "Low confidence" = matched, but within 0.1 of the similarity cut-off.
        low_cutoff = 1.0 - threshold + 0.1
        low_confidence = sum(1 for similarity in similarities if similarity < low_cutoff)

        return {
            "reference_clip": str(reference_clip),
            "target_file": str(target_file),
            "threshold": threshold,
            "segments_found": len(matched_segments),
            "segments_included": len(matched_segments),
            "total_duration_seconds": round(total_duration, 2),
            "average_confidence": round(avg_confidence, 3),
            "low_confidence_segments": low_confidence,
            "processing_time_seconds": round(processing_time, 1),
            "output_file": output_file,
        }