Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Pyannote Speaker Diarization Wrapper | |
| Optimized for accuracy and performance | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import List, Dict, Optional, Tuple | |
| import time | |
| from pathlib import Path | |
class SpeakerDiarization:
    """
    Production-ready Pyannote speaker diarization wrapper.

    Features:
    - State-of-the-art speaker diarization
    - GPU acceleration support
    - Configurable parameters for accuracy/speed tradeoff
    - Overlap detection
    """

    def __init__(
        self,
        model_name: str = "pyannote/speaker-diarization-3.1",
        use_auth_token: Optional[str] = None,
        token: Optional[str] = None,
        device: Optional[str] = None,
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None
    ):
        """
        Initialize speaker diarization pipeline.

        Args:
            model_name: Hugging Face model name.
            use_auth_token: (Deprecated) Hugging Face authentication token.
            token: Hugging Face authentication token (preferred name).
            device: Device to use ('cuda' or 'cpu'); auto-detected when None.
            num_speakers: Fixed number of speakers (if known).
            min_speakers: Minimum number of speakers.
            max_speakers: Maximum number of speakers.
        """
        self.model_name = model_name
        self.num_speakers = num_speakers
        self.min_speakers = min_speakers
        self.max_speakers = max_speakers

        # Accept both the new ('token') and the deprecated ('use_auth_token')
        # keyword; 'token' takes precedence when both are supplied.
        auth_token = token if token is not None else use_auth_token

        # Auto-select GPU when available unless the caller pinned a device.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        # Load pipeline (raises with guidance if the model is gated/unauthorized).
        self.pipeline = self._load_pipeline(auth_token)
        print(f"✓ Speaker diarization initialized on {self.device}")

    def _load_pipeline(self, auth_token: Optional[str]):
        """Load the Pyannote diarization pipeline and move it to self.device.

        Raises:
            Exception: Re-raises whatever Pipeline.from_pretrained raised,
                after printing troubleshooting hints (gated model / bad token).
        """
        from pyannote.audio import Pipeline

        try:
            # Use 'token' parameter for pyannote.audio 4.0+
            pipeline = Pipeline.from_pretrained(
                self.model_name,
                token=auth_token
            )
            pipeline.to(self.device)
            return pipeline
        except Exception as e:
            print(f"❌ Error loading pipeline: {e}")
            print("Make sure you have:")
            print("1. Accepted model conditions at https://huggingface.co/pyannote/speaker-diarization-3.1")
            print("2. Valid HF token from https://huggingface.co/settings/tokens")
            raise

    def process_file(
        self,
        audio_path: str,
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None
    ) -> Tuple[List[Dict], float, Dict]:
        """
        Process an audio file and return speaker segments.

        Args:
            audio_path: Path to audio file.
            num_speakers: Override number of speakers.
            min_speakers: Override minimum speakers.
            max_speakers: Override maximum speakers.

        Returns:
            Tuple of (segments, processing_time_ms, metadata) where each
            segment is a dict with 'start', 'end', 'speaker', 'duration'.

        Raises:
            ValueError: If the pipeline returns an unrecognized output type.
        """
        # Fall back to instance defaults only when the caller passed nothing.
        # (Fix: 'is None' instead of 'or', so an explicit falsy override such
        # as 0 is not silently replaced by the instance default.)
        if num_speakers is None:
            num_speakers = self.num_speakers
        if min_speakers is None:
            min_speakers = self.min_speakers
        if max_speakers is None:
            max_speakers = self.max_speakers

        # Forward only the constraints that were actually provided.
        params = {
            key: value
            for key, value in (
                ('num_speakers', num_speakers),
                ('min_speakers', min_speakers),
                ('max_speakers', max_speakers),
            )
            if value is not None
        }

        # Time the inference with a monotonic clock (fix: time.time() can jump
        # backwards/forwards with NTP adjustments; perf_counter cannot).
        start_time = time.perf_counter()
        diarization = self.pipeline(audio_path, **params)
        processing_time = (time.perf_counter() - start_time) * 1000  # Convert to ms

        # Handle different output formats from pyannote.audio:
        # version 4.0+ wraps the annotation in a DiarizeOutput object,
        # earlier versions return the Annotation directly.
        if hasattr(diarization, 'speaker_diarization'):
            # pyannote.audio 4.0+ format - DiarizeOutput object
            annotation = diarization.speaker_diarization
        elif hasattr(diarization, 'itertracks'):
            # pyannote.audio 3.x format - Annotation object
            annotation = diarization
        else:
            raise ValueError(f"Unknown diarization output format: {type(diarization)}")

        # Extract segments from annotation.
        segments = []
        speakers = set()
        for turn, _, speaker in annotation.itertracks(yield_label=True):
            segments.append({
                'start': turn.start,
                'end': turn.end,
                'speaker': speaker,
                'duration': turn.end - turn.start
            })
            speakers.add(speaker)

        metadata = {
            'num_speakers': len(speakers),
            'total_speech_time': sum(seg['duration'] for seg in segments),
            'num_segments': len(segments)
        }
        return segments, processing_time, metadata

    def process_with_vad_segments(
        self,
        audio_path: str,
        vad_segments: List[Dict],
        **kwargs
    ) -> List[Dict]:
        """
        Process audio using VAD segments to optimize diarization.

        Args:
            audio_path: Path to audio file.
            vad_segments: List of VAD segments with 'start' and 'end'.
            **kwargs: Additional parameters for diarization.

        Returns:
            List of speaker segments overlapping at least one VAD region.
        """
        # For now, process the full file.
        # TODO: Implement segment-wise processing for optimization
        segments, _, _ = self.process_file(audio_path, **kwargs)

        # Keep only segments that overlap some VAD region. Two intervals
        # overlap iff each starts before the other ends.
        return [
            seg
            for seg in segments
            if any(
                seg['start'] < vad_seg['end'] and seg['end'] > vad_seg['start']
                for vad_seg in vad_segments
            )
        ]

    def get_speaker_statistics(self, segments: List[Dict]) -> Dict:
        """
        Calculate speaker statistics from segments.

        Args:
            segments: List of speaker segments.

        Returns:
            Dict mapping speaker label to a dict with 'total_time',
            'num_segments' and 'avg_segment_duration'.
        """
        stats = {}
        for seg in segments:
            speaker = seg['speaker']
            if speaker not in stats:
                stats[speaker] = {
                    'total_time': 0.0,
                    'num_segments': 0,
                    'avg_segment_duration': 0.0
                }
            stats[speaker]['total_time'] += seg['duration']
            stats[speaker]['num_segments'] += 1

        # Fill in averages (num_segments is always >= 1 for present speakers).
        for speaker in stats:
            stats[speaker]['avg_segment_duration'] = (
                stats[speaker]['total_time'] / stats[speaker]['num_segments']
            )
        return stats

    def format_timeline(self, segments: List[Dict]) -> str:
        """
        Format segments as a readable timeline.

        Args:
            segments: List of speaker segments.

        Returns:
            Formatted timeline string, one line per segment.
        """
        lines = ["Speaker Timeline:", "=" * 50]
        for seg in segments:
            line = f"{seg['start']:7.2f}s - {seg['end']:7.2f}s: {seg['speaker']} ({seg['duration']:.2f}s)"
            lines.append(line)
        return "\n".join(lines)

    def calculate_der(
        self,
        predicted_segments: List[Dict],
        reference_segments: List[Dict],
        collar: float = 0.25
    ) -> float:
        """
        Calculate Diarization Error Rate (DER).

        Args:
            predicted_segments: Predicted speaker segments.
            reference_segments: Ground truth segments.
            collar: Collar size in seconds for forgiveness around boundaries.

        Returns:
            DER value (0.0-1.0), or -1.0 when pyannote.metrics is unavailable.
        """
        # Delegates to pyannote.metrics; returns a sentinel rather than
        # raising so evaluation is best-effort in minimal installs.
        try:
            from pyannote.metrics.diarization import DiarizationErrorRate
            from pyannote.core import Annotation, Segment

            # Convert both segment lists to pyannote Annotation objects.
            reference = Annotation()
            for seg in reference_segments:
                reference[Segment(seg['start'], seg['end'])] = seg['speaker']

            hypothesis = Annotation()
            for seg in predicted_segments:
                hypothesis[Segment(seg['start'], seg['end'])] = seg['speaker']

            metric = DiarizationErrorRate(collar=collar)
            return metric(reference, hypothesis)
        except ImportError:
            print("⚠️ pyannote.metrics not available, skipping DER calculation")
            return -1.0
def demo():
    """Demo diarization functionality.

    Requires a Hugging Face token in the HF_TOKEN environment variable and
    accepted model conditions for pyannote/speaker-diarization-3.1; prints
    instructions and returns early when no token is available.
    """
    print("\n" + "="*60)
    print("SPEAKER DIARIZATION DEMO")
    print("="*60)
    print("\n⚠️ This demo requires:")
    print("1. Hugging Face account")
    print("2. Accepted model conditions at:")
    print("   https://huggingface.co/pyannote/speaker-diarization-3.1")
    print("3. Valid HF token from:")
    print("   https://huggingface.co/settings/tokens")

    # Check for token
    import os
    token = os.environ.get('HF_TOKEN')
    if not token:
        print("\n❌ No HF_TOKEN found in environment")
        print("Set it with: export HF_TOKEN='your_token_here'")
        return

    try:
        # Fix: pass the token via the non-deprecated 'token' keyword instead
        # of the deprecated 'use_auth_token' (the class accepts both, but
        # 'token' is the documented, preferred name).
        diarization = SpeakerDiarization(token=token)
        print("\n✅ Diarization pipeline loaded successfully")
    except Exception as e:
        print(f"\n❌ Failed to load pipeline: {e}")

    print("\n" + "="*60)
# Run the demo only when executed as a script, never on import.
if __name__ == "__main__":
    demo()