#!/usr/bin/env python3
"""
Pyannote Speaker Diarization Wrapper
Optimized for accuracy and performance
"""

import inspect
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch


class SpeakerDiarization:
    """
    Production-ready Pyannote speaker diarization wrapper.

    Features:
    - State-of-the-art speaker diarization
    - GPU acceleration support
    - Configurable parameters for accuracy/speed tradeoff
    - Overlap detection
    """

    def __init__(
        self,
        model_name: str = "pyannote/speaker-diarization-3.1",
        use_auth_token: Optional[str] = None,
        token: Optional[str] = None,
        device: Optional[str] = None,
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
    ):
        """
        Initialize speaker diarization pipeline.

        Args:
            model_name: Hugging Face model name
            use_auth_token: (Deprecated) Hugging Face authentication token;
                ignored when ``token`` is given
            token: Hugging Face authentication token (new parameter name)
            device: Device to use ('cuda' or 'cpu'); defaults to CUDA when
                ``torch.cuda.is_available()``, else CPU
            num_speakers: Fixed number of speakers (if known)
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers

        Raises:
            Whatever ``_load_pipeline`` raises (pipeline download/load errors).
        """
        self.model_name = model_name
        self.num_speakers = num_speakers
        self.min_speakers = min_speakers
        self.max_speakers = max_speakers

        # Accept both the old and the new keyword; the new one wins.
        auth_token = token or use_auth_token

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.pipeline = self._load_pipeline(auth_token)

        print(f"✓ Speaker diarization initialized on {self.device}")

    @staticmethod
    def _token_kwarg(from_pretrained) -> str:
        """
        Return the auth-token keyword accepted by ``from_pretrained``.

        pyannote.audio 4.0+ takes ``token=``; 3.x takes ``use_auth_token=``.
        ``"token"`` is returned unless the signature positively shows only the
        legacy name, so opaque ``**kwargs`` wrappers keep the modern keyword.
        """
        try:
            params = inspect.signature(from_pretrained).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: keep the modern keyword.
            return "token"
        if "token" not in params and "use_auth_token" in params:
            return "use_auth_token"
        return "token"

    def _load_pipeline(self, auth_token: Optional[str]):
        """
        Load the Pyannote diarization pipeline and move it to ``self.device``.

        Args:
            auth_token: Hugging Face token (may be None for public models).

        Returns:
            The loaded pyannote ``Pipeline``.

        Raises:
            RuntimeError: if ``Pipeline.from_pretrained`` returns ``None``
                (some pyannote.audio 3.x releases do this instead of raising,
                e.g. for a gated model without a valid token).
            Exception: any error raised by pyannote while loading or moving
                the pipeline is re-raised after printing setup hints.
        """
        from pyannote.audio import Pipeline

        # process_file handles both 3.x and 4.x outputs, so loading must
        # support both keyword spellings as well.
        token_kwarg = self._token_kwarg(Pipeline.from_pretrained)

        try:
            pipeline = Pipeline.from_pretrained(
                self.model_name,
                **{token_kwarg: auth_token},
            )
            if pipeline is None:
                raise RuntimeError(
                    f"Pipeline.from_pretrained returned None for '{self.model_name}'"
                )
            pipeline.to(self.device)
            return pipeline

        except Exception as e:
            print(f"❌ Error loading pipeline: {e}")
            print("Make sure you have:")
            print(f"1. Accepted model conditions at https://huggingface.co/{self.model_name}")
            print("2. Valid HF token from https://huggingface.co/settings/tokens")
            raise

    def process_file(
        self,
        audio_path: str,
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
    ) -> Tuple[List[Dict], float, Dict]:
        """
        Run diarization on an audio file and return its speaker segments.

        Args:
            audio_path: Path to audio file (passed as-is to the pipeline)
            num_speakers: Override number of speakers
            min_speakers: Override minimum speakers
            max_speakers: Override maximum speakers

        Each override that is None (or otherwise falsy) falls back to the
        value given to ``__init__``. Only non-None values are forwarded to the
        pipeline.

        Returns:
            Tuple of (segments, processing_time_ms, metadata):
            - segments: dicts with 'start', 'end', 'speaker', 'duration'
            - processing_time_ms: elapsed time of the pipeline call in ms
            - metadata: 'num_speakers', 'total_speech_time', 'num_segments'

        Raises:
            ValueError: if the pipeline output is neither a pyannote.audio
                4.0+ ``DiarizeOutput`` nor a 3.x ``Annotation``.
        """
        speaker_limits = {
            'num_speakers': num_speakers or self.num_speakers,
            'min_speakers': min_speakers or self.min_speakers,
            'max_speakers': max_speakers or self.max_speakers,
        }
        params = {name: value for name, value in speaker_limits.items() if value is not None}

        # perf_counter is monotonic; time.time() can jump with clock changes.
        start_time = time.perf_counter()
        diarization = self.pipeline(audio_path, **params)
        processing_time = (time.perf_counter() - start_time) * 1000  # ms

        # pyannote.audio 4.0+ wraps the Annotation in a DiarizeOutput;
        # earlier versions return the Annotation directly.
        if hasattr(diarization, 'speaker_diarization'):
            annotation = diarization.speaker_diarization
        elif hasattr(diarization, 'itertracks'):
            annotation = diarization
        else:
            raise ValueError(f"Unknown diarization output format: {type(diarization)}")

        segments = []
        speakers = set()
        for turn, _, speaker in annotation.itertracks(yield_label=True):
            segments.append({
                'start': turn.start,
                'end': turn.end,
                'speaker': speaker,
                'duration': turn.end - turn.start,
            })
            speakers.add(speaker)

        metadata = {
            'num_speakers': len(speakers),
            'total_speech_time': sum(seg['duration'] for seg in segments),
            'num_segments': len(segments),
        }

        return segments, processing_time, metadata

    def process_with_vad_segments(
        self,
        audio_path: str,
        vad_segments: List[Dict],
        **kwargs,
    ) -> List[Dict]:
        """
        Diarize a file and keep only speaker segments that overlap VAD regions.

        Args:
            audio_path: Path to audio file
            vad_segments: List of VAD segments with 'start' and 'end'
            **kwargs: Forwarded to ``process_file`` (speaker-count overrides)

        Returns:
            Speaker segments (in original order) that strictly overlap at
            least one VAD segment; touching boundaries do not count.
        """
        # TODO: Implement segment-wise processing for optimization; for now
        # the full file is diarized and the result filtered afterwards.
        segments, _, _ = self.process_file(audio_path, **kwargs)

        return [
            seg
            for seg in segments
            if any(
                seg['start'] < vad['end'] and seg['end'] > vad['start']
                for vad in vad_segments
            )
        ]

    def get_speaker_statistics(self, segments: List[Dict]) -> Dict:
        """
        Calculate per-speaker statistics from segments.

        Args:
            segments: List of speaker segments ('speaker' and 'duration' keys)

        Returns:
            Dict mapping speaker -> {'total_time', 'num_segments',
            'avg_segment_duration'}; empty for empty input.
        """
        stats = {}

        for seg in segments:
            entry = stats.setdefault(seg['speaker'], {
                'total_time': 0.0,
                'num_segments': 0,
                'avg_segment_duration': 0.0,
            })
            entry['total_time'] += seg['duration']
            entry['num_segments'] += 1

        # Every entry has num_segments >= 1, so the division is safe.
        for entry in stats.values():
            entry['avg_segment_duration'] = entry['total_time'] / entry['num_segments']

        return stats

    def format_timeline(self, segments: List[Dict]) -> str:
        """
        Format segments as a readable timeline.

        Args:
            segments: List of speaker segments

        Returns:
            Formatted timeline string: a header, a 50-char rule, then one
            line per segment.
        """
        lines = ["Speaker Timeline:", "=" * 50]
        lines.extend(
            f"{seg['start']:7.2f}s - {seg['end']:7.2f}s: {seg['speaker']} ({seg['duration']:.2f}s)"
            for seg in segments
        )
        return "\n".join(lines)

    def calculate_der(
        self,
        predicted_segments: List[Dict],
        reference_segments: List[Dict],
        collar: float = 0.25,
    ) -> float:
        """
        Calculate Diarization Error Rate (DER) with ``pyannote.metrics``.

        Args:
            predicted_segments: Predicted speaker segments
            reference_segments: Ground truth segments
            collar: Collar size in seconds, passed to ``DiarizationErrorRate``

        Returns:
            DER value (>= 0.0; can exceed 1.0 when the hypothesis has many
            false alarms), or -1.0 if pyannote.metrics / pyannote.core are
            not installed.
        """
        try:
            from pyannote.core import Annotation, Segment
            from pyannote.metrics.diarization import DiarizationErrorRate
        except ImportError:
            print("⚠️ pyannote.metrics not available, skipping DER calculation")
            return -1.0

        def to_annotation(segments: List[Dict]) -> "Annotation":
            # Explicit per-segment track ids: with the default single track,
            # two speakers sharing identical boundaries (overlapped speech)
            # would overwrite each other.
            annotation = Annotation()
            for index, seg in enumerate(segments):
                annotation[Segment(seg['start'], seg['end']), str(index)] = seg['speaker']
            return annotation

        metric = DiarizationErrorRate(collar=collar)
        return float(metric(to_annotation(reference_segments), to_annotation(predicted_segments)))


def demo():
    """Demo diarization functionality: load the pipeline using $HF_TOKEN."""
    import os

    print("\n" + "="*60)
    print("SPEAKER DIARIZATION DEMO")
    print("="*60)

    print("\n⚠️ This demo requires:")
    print("1. Hugging Face account")
    print("2. Accepted model conditions at:")
    print("   https://huggingface.co/pyannote/speaker-diarization-3.1")
    print("3. Valid HF token from:")
    print("   https://huggingface.co/settings/tokens")

    token = os.environ.get('HF_TOKEN')
    if not token:
        print("\n❌ No HF_TOKEN found in environment")
        print("Set it with: export HF_TOKEN='your_token_here'")
        return

    try:
        # Constructing the object loads the pipeline; that is all we check.
        SpeakerDiarization(token=token)
        print("\n✅ Diarization pipeline loaded successfully")
    except Exception as e:
        print(f"\n❌ Failed to load pipeline: {e}")

    print("\n" + "="*60)


if __name__ == "__main__":
    demo()