| """ | |
| Speaker Diarization - identify who spoke when. | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import List, Dict, Optional | |
| from dataclasses import dataclass, field | |
| from sklearn.cluster import AgglomerativeClustering | |


@dataclass
class SpeakerSegment:
    """A segment of speech from a specific speaker."""
    start: float
    end: float
    speaker_id: str

    @property
    def duration(self) -> float:
        return self.end - self.start


@dataclass
class SpeakerInfo:
    """Information about a speaker."""
    speaker_id: str
    total_seconds: float = 0.0
    segments: List[SpeakerSegment] = field(default_factory=list)
    embedding: Optional[np.ndarray] = None

    def add_segment(self, segment: SpeakerSegment):
        self.segments.append(segment)
        self.total_seconds += segment.duration


class SpeakerDiarizer:
    """Speaker diarization using embedding clustering."""

    def __init__(self, device: Optional[str] = None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self._embedding_model = None

    @property
    def embedding_model(self):
        """Lazily load the speaker embedding model on first access."""
        if self._embedding_model is None:
            from speechbrain.inference.speaker import SpeakerRecognition
            import os
            model_dir = os.environ.get("MODEL_DIR", "pretrained_models")
            self._embedding_model = SpeakerRecognition.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=os.path.join(model_dir, "spkrec"),
                run_opts={"device": self.device}
            )
        return self._embedding_model

    def diarize(self, audio_path: str,
                speech_segments: Optional[List] = None,
                window_size: float = 2.0,
                hop_size: float = 0.5,
                num_speakers: Optional[int] = None,
                min_speakers: int = 1,
                max_speakers: int = 5) -> Dict[str, SpeakerInfo]:
        """
        Perform speaker diarization.

        Args:
            audio_path: Path to the audio file
            speech_segments: Optional list of speech segments (from VAD),
                each exposing `start` and `end` attributes in seconds
            window_size: Window size in seconds for embedding extraction
            hop_size: Hop size in seconds between consecutive windows
            num_speakers: Known number of speakers (None to estimate)
            min_speakers: Minimum number of speakers to detect
            max_speakers: Maximum number of speakers to detect

        Returns:
            Dict mapping speaker_id to SpeakerInfo, main speaker first
        """
        import torchaudio

        # Load audio (use soundfile backend to avoid torchcodec dependency)
        waveform, sample_rate = torchaudio.load(audio_path, backend="soundfile")
        if waveform.shape[0] > 1:
            # Downmix multi-channel audio to mono
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # The ECAPA embedding model expects 16 kHz input
        if sample_rate != 16000:
            waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
            sample_rate = 16000
        duration = waveform.shape[1] / sample_rate

        # Extract embeddings over sliding windows
        windows = []
        embeddings = []
        current = 0.0
        while current + window_size <= duration:
            start_sample = int(current * sample_rate)
            end_sample = int((current + window_size) * sample_rate)
            window_audio = waveform[:, start_sample:end_sample]

            # Check if this window has speech (if VAD segments were provided):
            # keep the window when its midpoint falls inside a speech segment
            has_speech = True
            if speech_segments:
                has_speech = any(
                    s.start <= current + window_size / 2 <= s.end
                    for s in speech_segments
                )

            if has_speech and window_audio.shape[1] > 0:
                # Extract a speaker embedding for this window
                emb = self.embedding_model.encode_batch(window_audio)
                emb = emb.squeeze().cpu().numpy()
                windows.append({'start': current, 'end': current + window_size})
                embeddings.append(emb)

            current += hop_size
| if len(embeddings) < 2: | |
| # Not enough data for clustering | |
| speaker_info = SpeakerInfo(speaker_id="speaker_A") | |
| for seg in (speech_segments or []): | |
| speaker_info.add_segment(SpeakerSegment( | |
| start=seg.start, end=seg.end, speaker_id="speaker_A" | |
| )) | |
| if embeddings: | |
| speaker_info.embedding = embeddings[0] | |
| return {"speaker_A": speaker_info} | |

        embeddings_array = np.array(embeddings)

        # Cluster embeddings
        if num_speakers is None:
            # Estimate the number of speakers via a distance threshold
            clustering = AgglomerativeClustering(
                n_clusters=None,
                distance_threshold=0.7,
                metric='cosine',
                linkage='average'
            )
        else:
            clustering = AgglomerativeClustering(
                n_clusters=num_speakers,
                metric='cosine',
                linkage='average'
            )
        labels = clustering.fit_predict(embeddings_array)

        # Clamp the estimated speaker count to [min_speakers, max_speakers]
        unique_labels = np.unique(labels)
        if len(unique_labels) > max_speakers:
            # Re-cluster with exactly max_speakers clusters
            clustering = AgglomerativeClustering(
                n_clusters=max_speakers,
                metric='cosine',
                linkage='average'
            )
            labels = clustering.fit_predict(embeddings_array)
            unique_labels = np.unique(labels)
        elif len(unique_labels) < min_speakers <= len(embeddings):
            # Re-cluster with exactly min_speakers clusters
            clustering = AgglomerativeClustering(
                n_clusters=min_speakers,
                metric='cosine',
                linkage='average'
            )
            labels = clustering.fit_predict(embeddings_array)
            unique_labels = np.unique(labels)

        # Build speaker info
        speakers = {}
        speaker_names = ['speaker_A', 'speaker_B', 'speaker_C', 'speaker_D', 'speaker_E']
        for label in unique_labels:
            speaker_id = speaker_names[label] if label < len(speaker_names) else f"speaker_{label}"
            speakers[speaker_id] = SpeakerInfo(speaker_id=speaker_id)
            # Calculate the mean embedding for this speaker
            mask = labels == label
            speaker_embeddings = embeddings_array[mask]
            speakers[speaker_id].embedding = np.mean(speaker_embeddings, axis=0)

        # Assign windows to speakers
        for window, label in zip(windows, labels):
            speaker_id = speaker_names[label] if label < len(speaker_names) else f"speaker_{label}"
            segment = SpeakerSegment(
                start=window['start'],
                end=window['end'],
                speaker_id=speaker_id
            )
            speakers[speaker_id].add_segment(segment)

        # Sort by total speech time (main speaker first)
        speakers = dict(sorted(
            speakers.items(),
            key=lambda x: x[1].total_seconds,
            reverse=True
        ))
        return speakers

    def get_main_speaker(self, speakers: Dict[str, SpeakerInfo]) -> Optional[SpeakerInfo]:
        """Get the speaker with the most speech time."""
        if not speakers:
            return None
        # diarize() returns speakers sorted by speech time, so the first is the main one
        return next(iter(speakers.values()))

    def get_additional_speakers(self, speakers: Dict[str, SpeakerInfo]) -> List[SpeakerInfo]:
        """Get all speakers except the main one."""
        items = list(speakers.values())
        return items[1:] if len(items) > 1 else []
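

# ---------------------------------------------------------------------------
# Usage sketch (illustrative): a minimal way to drive the diarizer, assuming
# a hypothetical "example.wav" and a stand-in VAD segment type. Any object
# exposing `start`/`end` attributes in seconds works as a speech segment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    @dataclass
    class VadSegment:
        # Hypothetical stand-in for a VAD output segment
        start: float
        end: float

    diarizer = SpeakerDiarizer()
    speakers = diarizer.diarize(
        "example.wav",  # hypothetical input path
        speech_segments=[VadSegment(0.0, 4.5), VadSegment(5.0, 9.0)],
        max_speakers=3,
    )
    for speaker_id, info in speakers.items():
        print(f"{speaker_id}: {info.total_seconds:.1f}s across {len(info.segments)} windows")

    main = diarizer.get_main_speaker(speakers)
    if main is not None:
        print(f"Main speaker: {main.speaker_id}")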