import logging
import os
from typing import Callable, Dict, List, Optional, Tuple

import torch
from pyannote.audio import Pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SpeakerDiarizer:
    """Handles speaker diarization using pyannote.audio.

    The pyannote pipeline is loaded lazily (on the first ``diarize`` call, or
    explicitly via ``load_pipeline``) and moved to CUDA when it is available.
    """

    # Hugging Face id of the diarization pipeline used when none is given.
    DEFAULT_MODEL_ID = "pyannote/speaker-diarization-3.1"

    # Duration (seconds) assumed for a transcription chunk whose end
    # timestamp is missing.
    _FALLBACK_CHUNK_DURATION = 1.0

    def __init__(
        self,
        hf_token: Optional[str] = None,
        model_id: str = DEFAULT_MODEL_ID,
    ):
        """
        Initialize speaker diarization pipeline

        Args:
            hf_token: Hugging Face access token (required for pyannote models).
                Falls back to the ``HF_TOKEN`` environment variable.
            model_id: Hugging Face id of the pyannote pipeline to load.
        """
        self.hf_token = hf_token or os.getenv('HF_TOKEN')
        if not self.hf_token:
            logger.warning("No HF_TOKEN provided. Diarization may fail.")

        self.model_id = model_id
        self.pipeline = None  # loaded lazily by load_pipeline()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_pipeline(
        self, progress_callback: Optional[Callable[[str], None]] = None
    ) -> None:
        """Load the diarization pipeline into ``self.pipeline``.

        ``self.pipeline`` is only assigned once the pipeline has been created
        and moved to the target device, so a failed load leaves it ``None``
        and the next ``diarize`` call retries.

        Args:
            progress_callback: Optional hook called with status strings.

        Raises:
            Exception: If the model cannot be loaded (e.g. terms not accepted,
                missing/invalid token, download or device errors). The
                underlying error is chained as ``__cause__``.
        """
        if progress_callback:
            progress_callback("Loading speaker diarization model...")

        try:
            pipeline = Pipeline.from_pretrained(
                self.model_id,
                use_auth_token=self.hf_token,
            )
            # pyannote may return None instead of raising when a gated model
            # cannot be fetched; fail now rather than on first use.
            if pipeline is None:
                raise RuntimeError(
                    f"Pipeline.from_pretrained returned None for {self.model_id!r}"
                )

            # Move to GPU if available
            if self.device == "cuda":
                pipeline.to(torch.device("cuda"))
        except Exception as e:
            logger.error("Failed to load diarization pipeline: %s", e)
            raise Exception(
                "Failed to load diarization model. "
                "Make sure you have accepted the terms at: "
                f"https://huggingface.co/{self.model_id} "
                "and provided a valid HF_TOKEN.\n"
                f"Error: {e}"
            ) from e

        self.pipeline = pipeline
        if progress_callback:
            progress_callback("Diarization model loaded successfully")
        logger.info("Diarization pipeline loaded on %s", self.device)

    def diarize(
        self,
        audio_path: str,
        progress_callback: Optional[Callable[[str], None]] = None,
    ) -> Dict[str, List[Dict]]:
        """
        Perform speaker diarization on audio file

        Args:
            audio_path: Path to audio file
            progress_callback: Optional callback for progress updates

        Returns:
            ``{'segments': [...]}`` where each segment is a dict with keys
            ``'start'``, ``'end'`` (taken from pyannote turn bounds) and
            ``'speaker'`` (the pyannote label).

        Raises:
            Exception: If loading the pipeline or running diarization fails.
        """
        if self.pipeline is None:
            self.load_pipeline(progress_callback)

        if progress_callback:
            progress_callback("Analyzing speakers in audio...")

        try:
            diarization = self.pipeline(audio_path)
            segments = [
                {'start': turn.start, 'end': turn.end, 'speaker': speaker}
                for turn, _, speaker in diarization.itertracks(yield_label=True)
            ]
        except Exception as e:
            logger.error("Diarization failed: %s", e)
            raise Exception(f"Speaker diarization failed: {e}") from e

        if progress_callback:
            num_speakers = len({seg['speaker'] for seg in segments})
            progress_callback(f"Diarization complete. Found {num_speakers} speakers")

        logger.info("Diarization found %d segments", len(segments))
        return {'segments': segments}

    @classmethod
    def _chunk_bounds(cls, chunk: Dict) -> Optional[Tuple[float, float]]:
        """Return ``(start, end)`` of a transcription chunk, or None if it has no start."""
        # Treat a missing, None or empty timestamp as "no timing information".
        timestamp = chunk.get('timestamp') or (None, None)
        start = timestamp[0]
        if start is None:
            return None
        end = timestamp[1] if len(timestamp) > 1 else None
        if end is None:
            end = start + cls._FALLBACK_CHUNK_DURATION
        return start, end

    @staticmethod
    def _speaker_for_span(
        start: float, end: float, segments: List[Dict]
    ) -> Optional[str]:
        """Pick the speaker for the span ``[start, end]``.

        The first segment containing the span's midpoint wins outright;
        otherwise the segment with the largest positive overlap is used.
        Returns None when no segment overlaps the span.
        """
        midpoint = (start + end) / 2
        best_speaker = None
        best_overlap = 0.0

        for seg in segments:
            seg_start = seg['start']
            seg_end = seg['end']

            if seg_start <= midpoint <= seg_end:
                return seg['speaker']

            overlap = min(end, seg_end) - max(start, seg_start)
            if overlap > best_overlap:
                best_overlap = overlap
                best_speaker = seg['speaker']

        return best_speaker

    def align_with_transcription(
        self,
        diarization_result: Dict,
        transcription_result: Dict,
        progress_callback: Optional[Callable[[str], None]] = None,
    ) -> Dict[int, str]:
        """
        Align speaker labels with transcription chunks

        Args:
            diarization_result: Result from diarize() (``{'segments': [...]}``).
            transcription_result: Result from transcription; its ``'chunks'``
                entries carry a ``'timestamp'`` ``(start, end)`` pair
                (presumably in the same time units as the diarization
                segments — confirm with the transcription caller).
            progress_callback: Optional callback for progress updates

        Returns:
            Dictionary mapping chunk index to speaker label. Chunks with no
            start time or no overlapping speaker segment are omitted.
        """
        if progress_callback:
            progress_callback("Aligning speakers with transcription...")

        diarization_segments = diarization_result.get('segments', [])
        transcription_chunks = transcription_result.get('chunks', [])

        speaker_labels: Dict[int, str] = {}
        for chunk_idx, chunk in enumerate(transcription_chunks):
            bounds = self._chunk_bounds(chunk)
            if bounds is None:
                continue
            speaker = self._speaker_for_span(bounds[0], bounds[1], diarization_segments)
            if speaker is not None:
                speaker_labels[chunk_idx] = speaker

        if progress_callback:
            progress_callback("Speaker alignment complete")

        logger.info("Aligned %d chunks with speakers", len(speaker_labels))
        return speaker_labels

    @staticmethod
    def is_available() -> bool:
        """Check if diarization is available (HF_TOKEN env var set).

        Note: only the environment variable is checked; a token passed
        directly to the constructor is not considered.
        """
        return os.getenv('HF_TOKEN') is not None