Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Speaker Separation Service | |
| Performs speaker diarization and separation using pyannote.audio. | |
| Extracts individual speakers from multi-speaker audio files. | |
| """ | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from pathlib import Path | |
| from typing import Callable, Dict, List, Optional | |
| import numpy as np | |
| import torch | |
try:
    import spaces
except ImportError:
    # No-op stand-in for environments without the HuggingFace `spaces`
    # package (e.g. local development). Only `spaces.GPU` is provided.
    class spaces:  # noqa: N801 - deliberately mimics the module name
        """Minimal substitute exposing a pass-through ``spaces.GPU``."""

        @staticmethod
        def GPU(duration=60, **kwargs):
            """Return the function unchanged (or a decorator that does so).

            Supports both ``@spaces.GPU`` and ``@spaces.GPU(duration=...)``.

            Args:
                duration: Ignored GPU time budget. When the decorator is
                    applied bare, this slot receives the decorated function.
                **kwargs: Any further decorator options; ignored.

            Returns:
                The decorated function itself when used bare, otherwise an
                identity decorator.
            """
            if callable(duration):
                # Applied bare as ``@spaces.GPU``: the function arrived
                # positionally, so hand it straight back.
                return duration

            def decorator(func):
                return func

            return decorator
# Workaround for PyTorch 2.6+, where ``torch.load`` defaults to the safer
# ``weights_only=True`` loader, which pyannote checkpoints do not pass.
# The pyannote models come from a trusted source (HuggingFace).
# NOTE(review): this replaces ``torch.load`` for the whole process, not only
# for pyannote. Every later ``torch.load`` call uses full (unsafe) unpickling,
# even one that explicitly asks for ``weights_only=True``.
_original_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    """Delegate to the original ``torch.load`` with ``weights_only=False``.

    Positional arguments and all other keyword arguments pass through
    untouched; a caller-supplied ``weights_only`` value is overridden.
    """
    forced_kwargs = dict(kwargs, weights_only=False)
    return _original_torch_load(*args, **forced_kwargs)


torch.load = _patched_torch_load
| from pyannote.audio import Pipeline | |
| from pyannote.audio.pipelines.utils.hook import ProgressHook | |
| from ..config.gpu_config import GPUConfig | |
| from ..lib.audio_io import ( | |
| AudioIOError, | |
| convert_m4a_to_wav, | |
| convert_wav_to_m4a, | |
| extract_segment, | |
| get_audio_duration, | |
| read_audio, | |
| write_audio, | |
| ) | |
| from ..lib.progress import SPEAKER_SEPARATION_STAGES | |
| from ..models.audio_segment import AudioSegment, SegmentType | |
| from ..models.error_report import ErrorReport | |
| from ..models.speaker_profile import SpeakerProfile | |
| logger = logging.getLogger(__name__) | |
# Module-level function for GPU-accelerated diarization.
# Kept outside the service class so running it under ZeroGPU does not
# require pickling any class instance state.
# NOTE(review): no ``@spaces.GPU`` decorator is applied here. On ZeroGPU this
# presumably runs without a GPU allocation and falls back to CPU; confirm
# that is intentional.
def _run_diarization_on_gpu(
    audio_dict: Dict,
    hf_token: str,
    min_speakers: int,
    max_speakers: int,
    progress_callback: Optional[Callable] = None,
):
    """
    Run diarization on GPU (or CPU if unavailable).

    This is a module-level function to avoid pickling issues with ZeroGPU.
    The pipeline is loaded fresh on every call and released afterwards, and
    the CUDA cache is emptied even if loading or inference fails.

    Args:
        audio_dict: Audio data dict with 'waveform' and 'sample_rate' keys,
            passed directly to the pyannote pipeline.
        hf_token: HuggingFace token for model access.
        min_speakers: Minimum number of speakers.
        max_speakers: Maximum number of speakers.
        progress_callback: Optional callable invoked as
            ``callback(description, progress, total)``.

    Returns:
        Diarization result from pyannote (exposes ``speaker_diarization``).

    Raises:
        RuntimeError: If ``Pipeline.from_pretrained`` returns no pipeline.
    """
    # Bound before ``try`` so the ``finally`` cleanup never references an
    # unbound name when loading itself fails.
    pipeline = None
    try:
        # Load pipeline fresh in this context (avoids pickling it)
        logger.info("Loading pyannote pipeline in GPU context...")
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=hf_token)
        if pipeline is None:
            # Guard: presumably some pyannote versions return None instead of
            # raising when a gated model cannot be fetched. Fail clearly here
            # rather than with an AttributeError on ``.to``.
            raise RuntimeError(
                "Failed to load pyannote/speaker-diarization-3.1; check that the "
                "HuggingFace token is valid and the model's access conditions "
                "have been accepted."
            )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device)
        logger.info("Pipeline loaded on %s", device)

        class CustomProgressHook(ProgressHook):
            """ProgressHook that also forwards step progress to ``callback``."""

            def __init__(self, callback=None):
                super().__init__()
                self.callback = callback

            def __call__(self, step_name, step_artefact, file=None, total=None, completed=None):
                # Call parent to maintain pyannote's internal tracking
                result = super().__call__(step_name, step_artefact, file, total, completed)
                if self.callback and completed is not None and total is not None and total > 0:
                    # Map step names to user-friendly descriptions
                    stage = SPEAKER_SEPARATION_STAGES.get(step_name, step_name)
                    # Each step's own completion fraction is mapped onto the
                    # 0.3-0.8 band. If pyannote reports several steps, the value
                    # presumably restarts at 0.3 for each one.
                    overall_progress = 0.3 + (completed / total) * 0.5
                    self.callback(stage, overall_progress, 1.0)
                return result

        with CustomProgressHook(callback=progress_callback) as hook:
            diarization = pipeline(
                audio_dict, min_speakers=min_speakers, max_speakers=max_speakers, hook=hook
            )

        if progress_callback:
            progress_callback("Speaker detection complete", 0.8, 1.0)

        speakers = {speaker for _, speaker in diarization.speaker_diarization}
        logger.info("Detected %d speakers: %s", len(speakers), ", ".join(sorted(speakers)))
        return diarization
    finally:
        # Drop the pipeline reference before releasing cached GPU memory
        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
class SpeakerSeparationService:
    """
    Service for speaker diarization and separation.

    Uses pyannote.audio for speaker diarization to identify and separate
    individual speakers from multi-speaker audio files.
    """

    # Input extensions that are converted to WAV before being read/diarized.
    _COMPRESSED_EXTENSIONS = (".m4a", ".aac", ".mp4")

    # Output formats that ``separate_and_export`` can actually produce.
    _SUPPORTED_OUTPUT_FORMATS = ("m4a", "wav")

    def __init__(self, hf_token: Optional[str] = None):
        """
        Initialize speaker separation service.

        Args:
            hf_token: HuggingFace API token (required for pyannote models).
                If None, the HF_TOKEN environment variable is used.

        Raises:
            ValueError: If no HuggingFace token is provided or found.
        """
        if hf_token is None:
            hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            raise ValueError(
                "HuggingFace token required. Set HF_TOKEN environment "
                "variable or pass hf_token parameter."
            )
        self.hf_token = hf_token

    def convert_to_wav(self, input_path: str, sample_rate: int = 16000) -> str:
        """
        Convert M4A/AAC to WAV for pyannote processing.

        Args:
            input_path: Path to M4A file.
            sample_rate: Target sample rate (default: 16000 for pyannote).

        Returns:
            Path to converted WAV file.
        """
        return convert_m4a_to_wav(input_path, sample_rate=sample_rate)

    def separate_speakers(
        self,
        audio_path: str,
        min_speakers: int = 2,
        max_speakers: int = 5,
        progress_callback: Optional[Callable] = None,
    ):
        """
        Perform speaker diarization on audio file.

        Args:
            audio_path: Path to audio file (M4A or WAV).
            min_speakers: Minimum number of speakers to detect.
            max_speakers: Maximum number of speakers to detect.
            progress_callback: Optional callback for progress updates.

        Returns:
            Diarization result from pyannote.

        Raises:
            AudioIOError: If file cannot be found or read.
            ValueError: If min_speakers exceeds max_speakers.
        """
        if min_speakers > max_speakers:
            raise ValueError(
                f"min_speakers ({min_speakers}) cannot exceed max_speakers ({max_speakers})"
            )

        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise AudioIOError(f"Audio file not found: {audio_path}")

        # Convert M4A to WAV if needed
        if audio_path.suffix.lower() in self._COMPRESSED_EXTENSIONS:
            logger.info(f"Converting {audio_path.name} to WAV for processing...")
            # NOTE(review): the converted WAV is never deleted here; where
            # convert_m4a_to_wav writes it is not visible from this file.
            audio_path = Path(self.convert_to_wav(str(audio_path)))

        logger.info(f"Performing speaker diarization (min={min_speakers}, max={max_speakers})...")
        if progress_callback:
            progress_callback("Starting speaker detection", 0.0, 1.0)

        # Load audio ourselves and pass it as a dict to avoid torchcodec issues.
        audio_data, sr = read_audio(str(audio_path), target_sr=16000)
        # Cast defensively to float32 (a no-op if read_audio already returns it).
        # The model presumably expects float32 input; a float64 tensor would
        # otherwise reach it unchanged.
        waveform = torch.from_numpy(np.asarray(audio_data, dtype=np.float32))
        audio_dict = {
            "waveform": waveform.unsqueeze(0),  # Add channel dimension
            "sample_rate": sr,
        }

        # Call the module-level GPU function (avoids pickling self)
        return _run_diarization_on_gpu(
            audio_dict=audio_dict,
            hf_token=self.hf_token,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
            progress_callback=progress_callback,
        )

    def extract_speaker_segments(self, diarization, speaker_id: str) -> List[AudioSegment]:
        """
        Extract audio segments for a specific speaker.

        Args:
            diarization: Diarization result from pyannote.
            speaker_id: Speaker ID to extract (e.g., "SPEAKER_00").

        Returns:
            List of AudioSegment objects for this speaker, in diarization order.
        """
        # pyannote.audio 4.0 API - iterate over speaker_diarization
        segments = [
            AudioSegment(
                start_time=turn.start,
                end_time=turn.end,
                speaker_id=speaker_id,
                confidence=1.0,  # pyannote doesn't provide per-segment confidence
                segment_type=SegmentType.SPEECH,
            )
            for turn, speaker in diarization.speaker_diarization
            if speaker == speaker_id
        ]
        logger.debug(f"Extracted {len(segments)} segments for {speaker_id}")
        return segments

    def export_speaker_audio(
        self,
        audio: np.ndarray,
        sample_rate: int,
        output_path: str,
        output_sample_rate: int = 44100,
        bitrate: str = "192k",
    ) -> str:
        """
        Export speaker audio to M4A format.

        The audio is first written to a temporary ``.temp.wav`` file next to
        the output, which is removed even if the M4A conversion fails.

        Args:
            audio: Audio array.
            sample_rate: Input sample rate.
            output_path: Output M4A file path.
            output_sample_rate: Output sample rate (default: 44100).
            bitrate: Output bitrate (default: "192k").

        Returns:
            Path to exported M4A file.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        temp_wav = output_path.with_suffix(".temp.wav")
        try:
            write_audio(str(temp_wav), audio, sample_rate)
            m4a_path = convert_wav_to_m4a(
                str(temp_wav), str(output_path), sample_rate=output_sample_rate, bitrate=bitrate
            )
        finally:
            # Always remove the intermediate WAV, including on failure
            temp_wav.unlink(missing_ok=True)

        logger.info(f"Exported speaker audio to {output_path.name}")
        return m4a_path

    def generate_separation_report(
        self,
        input_file: str,
        speakers: List[str],
        segments: Dict[str, List[AudioSegment]],
        processing_time: float,
        output_files: List[Dict],
        input_duration: float,
    ) -> Dict:
        """
        Generate separation report JSON.

        Args:
            input_file: Input file path.
            speakers: List of speaker IDs.
            segments: Dict mapping speaker IDs to their segments.
            processing_time: Processing time in seconds.
            output_files: List of output file information.
            input_duration: Input audio duration in seconds.

        Returns:
            Report dictionary. ``overlapping_segments`` counts unordered
            pairs of segments (across all speakers) that overlap.
        """
        all_segs = [seg for segs in segments.values() for seg in segs]
        total_segments = len(all_segs)
        avg_confidence = sum(seg.confidence for seg in all_segs) / max(total_segments, 1)

        # Pairwise check: O(n^2) in the number of segments
        overlapping = 0
        for i, first in enumerate(all_segs):
            for second in all_segs[i + 1 :]:
                if first.overlaps_with(second):
                    overlapping += 1

        return {
            "input_file": str(input_file),
            "input_duration_seconds": input_duration,
            "speakers_detected": len(speakers),
            "processing_time_seconds": processing_time,
            "output_files": output_files,
            "overlapping_segments": overlapping,
            "quality_metrics": {
                "average_confidence": round(avg_confidence, 3),
                "total_segments": total_segments,
                "low_confidence_segments": sum(1 for seg in all_segs if seg.confidence < 0.7),
            },
        }

    @staticmethod
    def _error_report(message: str, error_type: str) -> ErrorReport:
        """Log *message* with the active traceback and build a failure report."""
        # Must be called from inside an ``except`` block so the traceback is logged
        logger.exception(message)
        return {"status": "failed", "error": message, "error_type": error_type}

    def separate_and_export(
        self,
        input_file: str,
        output_dir: str,
        min_speakers: int = 2,
        max_speakers: int = 5,
        output_format: str = "m4a",
        sample_rate: int = 44100,
        bitrate: str = "192k",
        progress_callback: Optional[Callable] = None,
    ) -> Dict:
        """
        Complete workflow: separate speakers and export to individual files.

        Args:
            input_file: Input audio file (M4A/AAC/MP4 or WAV).
            output_dir: Output directory for separated files.
            min_speakers: Minimum speakers to detect.
            max_speakers: Maximum speakers to detect.
            output_format: "m4a" (default) or "wav". Any other value, such as
                "mp3", is not supported here: a warning is logged and M4A is
                written instead.
            sample_rate: Output sample rate (default: 44100).
            bitrate: Output bitrate for M4A (default: "192k").
            progress_callback: Optional progress callback.

        Returns:
            Separation report dictionary (also written to
            ``separation_report.json``), or an ErrorReport on failure.
        """
        start_time = time.time()

        try:
            input_file = Path(input_file)
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            input_duration = get_audio_duration(str(input_file))
        except Exception as e:
            return self._error_report(f"Failed to initialize speaker separation: {e}", "audio_io")

        # Resolve the requested output format (previously ignored entirely)
        extension = (output_format or "m4a").lower().lstrip(".")
        if extension not in self._SUPPORTED_OUTPUT_FORMATS:
            logger.warning("Unsupported output_format %r; exporting M4A instead", output_format)
            extension = "m4a"

        try:
            if progress_callback:
                progress_callback("Performing speaker diarization", 0.1, 1.0)
            # Note: progress_callback cannot be passed due to ZeroGPU pickling constraints
            diarization = self.separate_speakers(
                str(input_file),
                min_speakers=min_speakers,
                max_speakers=max_speakers,
                progress_callback=None,  # Cannot pass callback to avoid pickling errors
            )
        except Exception as e:
            return self._error_report(f"Speaker diarization failed: {e}", "processing")

        try:
            speakers = sorted({speaker for _, speaker in diarization.speaker_diarization})
            segments = {
                speaker_id: self.extract_speaker_segments(diarization, speaker_id)
                for speaker_id in speakers
            }

            # Load full audio for extraction
            if progress_callback:
                progress_callback("Loading audio", 0.2, 1.0)

            wav_path = input_file
            if input_file.suffix.lower() in self._COMPRESSED_EXTENSIONS:
                wav_path = Path(self.convert_to_wav(str(input_file), sample_rate=sample_rate))
            audio, sr = read_audio(str(wav_path), target_sr=sample_rate)
        except Exception as e:
            return self._error_report(f"Failed to load and process audio: {e}", "audio_io")

        try:
            output_files = []
            speaker_count = len(speakers)
            for index, speaker_id in enumerate(speakers):
                if progress_callback:
                    # Progress from 0.8 to 1.0 for speaker exports
                    export_progress = 0.8 + (0.2 * (index + 1) / speaker_count)
                    progress_callback(
                        f"Exporting speaker {index + 1}/{speaker_count}", export_progress, 1.0
                    )

                # Extract and concatenate all segments for this speaker
                speaker_segments = segments[speaker_id]
                parts = [
                    extract_segment(audio, sr, segment.start_time, segment.end_time)
                    for segment in speaker_segments
                ]
                if not parts:
                    continue
                speaker_audio = np.concatenate(parts)

                output_file = output_dir / f"speaker_{index:02d}.{extension}"
                if extension == "wav":
                    write_audio(str(output_file), speaker_audio, sr)
                else:
                    self.export_speaker_audio(
                        speaker_audio,
                        sr,
                        str(output_file),
                        output_sample_rate=sample_rate,
                        bitrate=bitrate,
                    )

                output_files.append(
                    {
                        "speaker_id": speaker_id,
                        "file": str(output_file),
                        "duration": len(speaker_audio) / sr,
                        "segments_count": len(speaker_segments),
                    }
                )

            processing_time = time.time() - start_time
            report = self.generate_separation_report(
                input_file=str(input_file),
                speakers=speakers,
                segments=segments,
                processing_time=processing_time,
                output_files=output_files,
                input_duration=input_duration,
            )

            report_file = output_dir / "separation_report.json"
            with open(report_file, "w", encoding="utf-8") as f:
                json.dump(report, f, indent=2)

            logger.info(f"Separation complete: {len(speakers)} speakers in {processing_time:.1f}s")
            if progress_callback:
                progress_callback("Complete", 1.0, 1.0)
            return report
        except Exception as e:
            return self._error_report(f"Failed to export speakers: {e}", "processing")