"""
Speaker Separation Service

Performs speaker diarization and separation using pyannote.audio.
Extracts individual speakers from multi-speaker audio files.
"""

import json
import logging
import os
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional

import numpy as np
import torch

try:
    import spaces
except ImportError:
    # Create a no-op decorator for environments without spaces package
    class spaces:
        @staticmethod
        def GPU(duration=60):
            def decorator(func):
                return func

            return decorator


# Workaround for PyTorch 2.6+ weights_only security feature.
# pyannote models are from trusted source (HuggingFace).
# Monkey-patch torch.load to use weights_only=False for pyannote models.
# NOTE(review): this patch is process-global — it forces weights_only=False for
# EVERY torch.load call made anywhere in this process, overriding even callers
# that explicitly requested weights_only=True. Only safe while all checkpoints
# loaded by this process come from trusted sources.
_original_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    """Delegate to the real torch.load with weights_only forced to False."""
    # Force weights_only=False since we trust pyannote models from HuggingFace
    kwargs["weights_only"] = False
    return _original_torch_load(*args, **kwargs)


torch.load = _patched_torch_load

from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook

from ..config.gpu_config import GPUConfig
from ..lib.audio_io import (
    AudioIOError,
    convert_m4a_to_wav,
    convert_wav_to_m4a,
    extract_segment,
    get_audio_duration,
    read_audio,
    write_audio,
)
from ..lib.progress import SPEAKER_SEPARATION_STAGES
from ..models.audio_segment import AudioSegment, SegmentType
from ..models.error_report import ErrorReport
from ..models.speaker_profile import SpeakerProfile

logger = logging.getLogger(__name__)


# Module-level function for GPU-accelerated diarization.
# This avoids pickling issues with ZeroGPU by not depending on class instance state.
@spaces.GPU(duration=90)
def _run_diarization_on_gpu(
    audio_dict: Dict,
    hf_token: str,
    min_speakers: int,
    max_speakers: int,
    progress_callback: Optional[Callable] = None,
):
    """
    Run diarization on GPU (or CPU if unavailable).

    This is a module-level function to avoid pickling issues with ZeroGPU.
    The pipeline is loaded fresh within this GPU context.

    Args:
        audio_dict: Audio data dict with 'waveform' and 'sample_rate'
        hf_token: HuggingFace token for model access
        min_speakers: Minimum number of speakers
        max_speakers: Maximum number of speakers
        progress_callback: Optional progress callback

    Returns:
        Diarization result from pyannote
    """
    # Load pipeline fresh in GPU context (avoids pickling)
    logger.info("Loading pyannote pipeline in GPU context...")
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=hf_token)

    # Move to available device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.to(device)
    logger.info(f"Pipeline loaded on {device}")

    try:
        # Custom progress hook that bridges pyannote progress to our callback
        class CustomProgressHook(ProgressHook):
            def __init__(self, callback=None):
                super().__init__()
                self.callback = callback

            def __call__(self, step_name, step_artefact, file=None, total=None, completed=None):
                # Call parent to maintain pyannote's internal tracking
                result = super().__call__(step_name, step_artefact, file, total, completed)

                # Forward progress to our callback
                if self.callback and completed is not None and total is not None and total > 0:
                    # Map step names to user-friendly descriptions
                    stage = SPEAKER_SEPARATION_STAGES.get(step_name, step_name)
                    # Calculate percentage within this step (0.0 to 1.0)
                    step_progress = completed / total
                    # Scale to 0.3-0.8 range (30% to 80% of overall progress)
                    overall_progress = 0.3 + (step_progress * 0.5)
                    self.callback(stage, overall_progress, 1.0)

                return result

        # Use custom hook for pyannote progress with callback forwarding
        with CustomProgressHook(callback=progress_callback) as hook:
            diarization = pipeline(
                audio_dict, min_speakers=min_speakers, max_speakers=max_speakers, hook=hook
            )

        if progress_callback:
            progress_callback("Speaker detection complete", 0.8, 1.0)

        # Count speakers by iterating through speaker_diarization.
        # NOTE(review): assumes iterating `diarization.speaker_diarization` yields
        # (turn, speaker) pairs (pyannote.audio 4.0 style) — confirm against the
        # installed pyannote version; 3.x Annotation objects iterate differently.
        speakers = set()
        for turn, speaker in diarization.speaker_diarization:
            speakers.add(speaker)

        logger.info(f"Detected {len(speakers)} speakers: {', '.join(sorted(speakers))}")

        return diarization
    finally:
        # Clean up
        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


class SpeakerSeparationService:
    """
    Service for speaker diarization and separation.

    Uses pyannote.audio for speaker diarization to identify and separate
    individual speakers from multi-speaker audio files.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """
        Initialize speaker separation service.

        Args:
            hf_token: HuggingFace API token (required for pyannote models)
                If None, will try to get from HF_TOKEN env var

        Raises:
            ValueError: If HuggingFace token not provided
        """
        if hf_token is None:
            hf_token = os.getenv("HF_TOKEN")

        if not hf_token:
            raise ValueError(
                "HuggingFace token required. Set HF_TOKEN environment "
                "variable or pass hf_token parameter."
            )

        self.hf_token = hf_token

    def convert_to_wav(self, input_path: str, sample_rate: int = 16000) -> str:
        """
        Convert M4A/AAC to WAV for pyannote processing.

        Args:
            input_path: Path to M4A file
            sample_rate: Target sample rate (default: 16000 for pyannote)

        Returns:
            Path to converted WAV file
        """
        return convert_m4a_to_wav(input_path, sample_rate=sample_rate)

    def separate_speakers(
        self,
        audio_path: str,
        min_speakers: int = 2,
        max_speakers: int = 5,
        progress_callback: Optional[Callable] = None,
    ):
        """
        Perform speaker diarization on audio file.

        Args:
            audio_path: Path to audio file (M4A or WAV)
            min_speakers: Minimum number of speakers to detect
            max_speakers: Maximum number of speakers to detect
            progress_callback: Optional callback for progress updates

        Returns:
            Diarization result from pyannote

        Raises:
            AudioIOError: If file cannot be read
            ValueError: If parameters are invalid
        """
        if min_speakers > max_speakers:
            raise ValueError(
                f"min_speakers ({min_speakers}) cannot exceed max_speakers ({max_speakers})"
            )

        # Convert M4A to WAV if needed
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise AudioIOError(f"Audio file not found: {audio_path}")

        if audio_path.suffix.lower() in [".m4a", ".aac", ".mp4"]:
            logger.info(f"Converting {audio_path.name} to WAV for processing...")
            audio_path = Path(self.convert_to_wav(str(audio_path)))

        # Run diarization with progress reporting
        logger.info(f"Performing speaker diarization (min={min_speakers}, max={max_speakers})...")

        if progress_callback:
            progress_callback("Starting speaker detection", 0.0, 1.0)

        # Load audio ourselves and pass as dict to avoid torchcodec issues
        audio_data, sr = read_audio(str(audio_path), target_sr=16000)
        audio_dict = {
            "waveform": torch.from_numpy(audio_data).unsqueeze(0),  # Add channel dimension
            "sample_rate": sr,
        }

        # Call the module-level GPU function (avoids pickling self)
        diarization = _run_diarization_on_gpu(
            audio_dict=audio_dict,
            hf_token=self.hf_token,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
            progress_callback=progress_callback,
        )

        return diarization

    def extract_speaker_segments(self, diarization, speaker_id: str) -> List[AudioSegment]:
        """
        Extract audio segments for a specific speaker.

        Args:
            diarization: Diarization result from pyannote
            speaker_id: Speaker ID to extract (e.g., "SPEAKER_00")

        Returns:
            List of AudioSegment objects for this speaker
        """
        segments = []

        # pyannote.audio 4.0 API - iterate over speaker_diarization
        for turn, speaker in diarization.speaker_diarization:
            if speaker == speaker_id:
                audio_segment = AudioSegment(
                    start_time=turn.start,
                    end_time=turn.end,
                    speaker_id=speaker_id,
                    confidence=1.0,  # pyannote doesn't provide per-segment confidence
                    segment_type=SegmentType.SPEECH,
                )
                segments.append(audio_segment)

        logger.debug(f"Extracted {len(segments)} segments for {speaker_id}")
        return segments

    def export_speaker_audio(
        self,
        audio: np.ndarray,
        sample_rate: int,
        output_path: str,
        output_sample_rate: int = 44100,
        bitrate: str = "192k",
    ) -> str:
        """
        Export speaker audio to M4A format.

        Args:
            audio: Audio array
            sample_rate: Input sample rate
            output_path: Output M4A file path
            output_sample_rate: Output sample rate (default: 44100)
            bitrate: Output bitrate (default: "192k")

        Returns:
            Path to exported M4A file
        """
        output_path = Path(output_path)

        # Create output directory
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # First write to temporary WAV, then convert to M4A.
        # try/finally ensures the temp file is removed even if conversion fails
        # (previously it leaked on a convert_wav_to_m4a error).
        temp_wav = output_path.with_suffix(".temp.wav")
        try:
            write_audio(str(temp_wav), audio, sample_rate)
            m4a_path = convert_wav_to_m4a(
                str(temp_wav), str(output_path), sample_rate=output_sample_rate, bitrate=bitrate
            )
        finally:
            # Clean up temp file
            temp_wav.unlink(missing_ok=True)

        logger.info(f"Exported speaker audio to {output_path.name}")
        return m4a_path

    def generate_separation_report(
        self,
        input_file: str,
        speakers: List[str],
        segments: Dict[str, List[AudioSegment]],
        processing_time: float,
        output_files: List[Dict],
        input_duration: float,
    ) -> Dict:
        """
        Generate separation report JSON.

        Args:
            input_file: Input file path
            speakers: List of speaker IDs
            segments: Dict mapping speaker IDs to their segments
            processing_time: Processing time in seconds
            output_files: List of output file information
            input_duration: Input audio duration in seconds

        Returns:
            Report dictionary
        """
        # Calculate quality metrics
        total_segments = sum(len(segs) for segs in segments.values())
        avg_confidence = sum(seg.confidence for segs in segments.values() for seg in segs) / max(
            total_segments, 1
        )

        # Count overlapping segments (pairwise check; segment counts are small
        # enough that the O(n^2) scan is acceptable)
        overlapping = 0
        all_segs = [seg for segs in segments.values() for seg in segs]
        for i, seg1 in enumerate(all_segs):
            for seg2 in all_segs[i + 1 :]:
                if seg1.overlaps_with(seg2):
                    overlapping += 1

        report = {
            "input_file": str(input_file),
            "input_duration_seconds": input_duration,
            "speakers_detected": len(speakers),
            "processing_time_seconds": processing_time,
            "output_files": output_files,
            "overlapping_segments": overlapping,
            "quality_metrics": {
                "average_confidence": round(avg_confidence, 3),
                "total_segments": total_segments,
                "low_confidence_segments": sum(
                    1 for segs in segments.values() for seg in segs if seg.confidence < 0.7
                ),
            },
        }

        return report

    def separate_and_export(
        self,
        input_file: str,
        output_dir: str,
        min_speakers: int = 2,
        max_speakers: int = 5,
        output_format: str = "m4a",
        sample_rate: int = 44100,
        bitrate: str = "192k",
        progress_callback: Optional[Callable] = None,
    ) -> Dict:
        """
        Complete workflow: separate speakers and export to individual files.

        Args:
            input_file: Input M4A audio file
            output_dir: Output directory for separated files
            min_speakers: Minimum speakers to detect
            max_speakers: Maximum speakers to detect
            output_format: Output format - m4a, wav, or mp3 (default: "m4a")
            sample_rate: Output sample rate (default: 44100)
            bitrate: Output bitrate (default: "192k")
            progress_callback: Optional progress callback

        Returns:
            Separation report dictionary or ErrorReport on failure
        """
        start_time = time.time()

        try:
            input_file = Path(input_file)
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

            # Get input duration
            input_duration = get_audio_duration(str(input_file))
        except Exception as e:
            logger.error(f"Failed to initialize speaker separation: {e}")
            error_report: ErrorReport = {
                "status": "failed",
                "error": f"Failed to initialize speaker separation: {e}",
                "error_type": "audio_io",
            }
            return error_report

        try:
            # Perform speaker diarization
            if progress_callback:
                progress_callback("Loading audio", 0.1, 1.0)

            # Note: progress_callback cannot be passed due to ZeroGPU pickling constraints
            diarization = self.separate_speakers(
                str(input_file),
                min_speakers=min_speakers,
                max_speakers=max_speakers,
                progress_callback=None,  # Cannot pass callback to avoid pickling errors
            )
        except Exception as e:
            logger.error(f"Speaker diarization failed: {e}")
            error_report: ErrorReport = {
                "status": "failed",
                "error": f"Speaker diarization failed: {e}",
                "error_type": "processing",
            }
            return error_report

        try:
            # Get unique speakers by iterating through speaker_diarization
            speakers = set()
            for turn, speaker in diarization.speaker_diarization:
                speakers.add(speaker)
            speakers = sorted(speakers)

            # Extract segments for each speaker
            segments = {}
            for speaker_id in speakers:
                segments[speaker_id] = self.extract_speaker_segments(diarization, speaker_id)

            # Load full audio for extraction
            # NOTE(review): this progress label fires after diarization already
            # completed; kept as-is to preserve the user-visible message sequence.
            if progress_callback:
                progress_callback("Performing speaker diarization", 0.2, 1.0)

            # Convert to WAV for processing if needed
            wav_path = input_file
            if input_file.suffix.lower() in [".m4a", ".aac", ".mp4"]:
                wav_path = Path(self.convert_to_wav(str(input_file), sample_rate=sample_rate))

            audio, sr = read_audio(str(wav_path), target_sr=sample_rate)
        except Exception as e:
            logger.error(f"Failed to load and process audio: {e}")
            error_report: ErrorReport = {
                "status": "failed",
                "error": f"Failed to load and process audio: {e}",
                "error_type": "audio_io",
            }
            return error_report

        try:
            # Export each speaker
            output_files = []
            for i, speaker_id in enumerate(speakers):
                if progress_callback:
                    # Progress from 0.8 to 1.0 for speaker exports
                    export_progress = 0.8 + (0.2 * (i + 1) / len(speakers))
                    progress_callback(
                        f"Exporting speaker {i + 1}/{len(speakers)}", export_progress, 1.0
                    )

                # Extract and concatenate all segments for this speaker
                speaker_segments = segments[speaker_id]
                speaker_audio_parts = []

                for segment in speaker_segments:
                    segment_audio = extract_segment(audio, sr, segment.start_time, segment.end_time)
                    speaker_audio_parts.append(segment_audio)

                # Concatenate segments (skip speakers with no extracted audio)
                if speaker_audio_parts:
                    speaker_audio = np.concatenate(speaker_audio_parts)

                    # Export to M4A
                    output_file = output_dir / f"speaker_{i:02d}.m4a"
                    self.export_speaker_audio(
                        speaker_audio,
                        sr,
                        str(output_file),
                        output_sample_rate=sample_rate,
                        bitrate=bitrate,
                    )

                    output_files.append(
                        {
                            "speaker_id": speaker_id,
                            "file": str(output_file),
                            "duration": len(speaker_audio) / sr,
                            "segments_count": len(speaker_segments),
                        }
                    )

            # Generate and save report
            processing_time = time.time() - start_time
            report = self.generate_separation_report(
                input_file=str(input_file),
                speakers=speakers,
                segments=segments,
                processing_time=processing_time,
                output_files=output_files,
                input_duration=input_duration,
            )

            # Write report JSON
            report_file = output_dir / "separation_report.json"
            with open(report_file, "w") as f:
                json.dump(report, f, indent=2)

            logger.info(f"Separation complete: {len(speakers)} speakers in {processing_time:.1f}s")

            if progress_callback:
                progress_callback("Complete", 1.0, 1.0)

            return report
        except Exception as e:
            logger.error(f"Failed to export speakers: {e}")
            error_report: ErrorReport = {
                "status": "failed",
                "error": f"Failed to export speakers: {e}",
                "error_type": "processing",
            }
            return error_report