Spaces:
Running
Running
import os
import torch
from typing import Dict, List, Optional
from pyannote.audio import Pipeline  # third-party: requires HF token for gated models
import logging

# Module-level logging setup. NOTE(review): basicConfig at import time is a
# no-op if the embedding application already configured the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SpeakerDiarizer:
    """Handles speaker diarization using pyannote.audio.

    Wraps the ``pyannote/speaker-diarization-3.1`` pipeline: lazy model
    loading, per-file diarization, and alignment of speaker turns with
    timestamped transcription chunks.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """
        Initialize speaker diarization pipeline.

        Args:
            hf_token: Hugging Face access token (required for pyannote
                models). Falls back to the ``HF_TOKEN`` environment variable.
        """
        self.hf_token = hf_token or os.getenv('HF_TOKEN')
        if not self.hf_token:
            logger.warning("No HF_TOKEN provided. Diarization may fail.")
        # Pipeline is loaded lazily by load_pipeline() / diarize().
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_pipeline(self, progress_callback=None):
        """Load the diarization pipeline.

        Args:
            progress_callback: Optional callable taking a status string,
                used for UI progress updates.

        Raises:
            Exception: If the model cannot be downloaded/loaded (missing or
                invalid token, or model terms not accepted on the Hub).
        """
        if progress_callback:
            progress_callback("Loading speaker diarization model...")
        try:
            self.pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=self.hf_token
            )
            # Move to GPU if available
            if self.device == "cuda":
                self.pipeline.to(torch.device("cuda"))
            if progress_callback:
                progress_callback("Diarization model loaded successfully")
            # Lazy %-style args: formatting is skipped when the level is off.
            logger.info("Diarization pipeline loaded on %s", self.device)
        except Exception as e:
            logger.error("Failed to load diarization pipeline: %s", e)
            # Chain the original error (``from e``) so the root cause stays
            # visible in the traceback instead of being swallowed.
            raise Exception(
                f"Failed to load diarization model. "
                f"Make sure you have accepted the terms at: "
                f"https://huggingface.co/pyannote/speaker-diarization-3.1 "
                f"and provided a valid HF_TOKEN. Error: {str(e)}"
            ) from e

    def diarize(self, audio_path: str, progress_callback=None) -> Dict:
        """
        Perform speaker diarization on audio file.

        Args:
            audio_path: Path to audio file.
            progress_callback: Optional callback for progress updates.

        Returns:
            Dict with a ``'segments'`` key: a list of
            ``{'start': float, 'end': float, 'speaker': str}`` dicts.

        Raises:
            Exception: If the pipeline fails to load or diarization fails.
        """
        if self.pipeline is None:
            self.load_pipeline(progress_callback)
        if progress_callback:
            progress_callback("Analyzing speakers in audio...")
        try:
            # Run diarization
            diarization = self.pipeline(audio_path)
            # Flatten pyannote's annotation into plain segment dicts.
            segments = []
            for turn, _, speaker in diarization.itertracks(yield_label=True):
                segments.append({
                    'start': turn.start,
                    'end': turn.end,
                    'speaker': speaker
                })
            if progress_callback:
                num_speakers = len(set(seg['speaker'] for seg in segments))
                progress_callback(f"Diarization complete. Found {num_speakers} speakers")
            logger.info("Diarization found %d segments", len(segments))
            return {'segments': segments}
        except Exception as e:
            logger.error("Diarization failed: %s", e)
            # Preserve the underlying cause for debugging.
            raise Exception(f"Speaker diarization failed: {str(e)}") from e

    def align_with_transcription(
        self,
        diarization_result: Dict,
        transcription_result: Dict,
        progress_callback=None
    ) -> Dict[int, str]:
        """
        Align speaker labels with transcription chunks.

        Each chunk is labeled with the speaker whose segment contains the
        chunk's midpoint; if no segment contains the midpoint, the speaker
        with the largest time overlap wins. Chunks with no start timestamp
        or no overlapping speaker are omitted from the result.

        Args:
            diarization_result: Result from diarize() (``'segments'`` list).
            transcription_result: Transcription dict with a ``'chunks'`` list;
                each chunk carries a ``'timestamp': (start, end)`` tuple.
            progress_callback: Optional callback for progress updates.

        Returns:
            Dictionary mapping chunk index to speaker label.
        """
        if progress_callback:
            progress_callback("Aligning speakers with transcription...")
        speaker_labels: Dict[int, str] = {}
        diarization_segments = diarization_result.get('segments', [])
        transcription_chunks = transcription_result.get('chunks', [])
        for chunk_idx, chunk in enumerate(transcription_chunks):
            timestamp = chunk.get('timestamp', (None, None))
            if timestamp[0] is None:
                # Chunk has no usable start time; skip it.
                continue
            chunk_start = timestamp[0]
            # Open-ended chunks (e.g. Whisper's final chunk) get a nominal
            # one-second duration so the midpoint/overlap math still works.
            chunk_end = timestamp[1] if timestamp[1] is not None else chunk_start + 1.0
            # Find overlapping speaker segments
            chunk_mid = (chunk_start + chunk_end) / 2
            best_speaker = None
            best_overlap = 0
            for seg in diarization_segments:
                seg_start = seg['start']
                seg_end = seg['end']
                # A segment containing the chunk midpoint wins outright.
                if seg_start <= chunk_mid <= seg_end:
                    best_speaker = seg['speaker']
                    break
                # Otherwise fall back to the largest time overlap.
                overlap_start = max(chunk_start, seg_start)
                overlap_end = min(chunk_end, seg_end)
                overlap = max(0, overlap_end - overlap_start)
                if overlap > best_overlap:
                    best_overlap = overlap
                    best_speaker = seg['speaker']
            # ``is not None`` (not truthiness) so an empty-string label would
            # still be recorded rather than silently dropped.
            if best_speaker is not None:
                speaker_labels[chunk_idx] = best_speaker
        if progress_callback:
            progress_callback("Speaker alignment complete")
        logger.info("Aligned %d chunks with speakers", len(speaker_labels))
        return speaker_labels
def is_available() -> bool:
    """Return True when diarization can run, i.e. HF_TOKEN is set.

    NOTE(review): defined without ``self``; appears to be a module-level
    helper. If it was meant to live on SpeakerDiarizer it would need
    ``@staticmethod`` — confirm against callers.
    """
    return 'HF_TOKEN' in os.environ