Spaces:
Running
Running
| """ | |
| Speaker Diarization Pipeline | |
| Combines: pyannote diarization (preferred) -> fallback VAD + ECAPA-TDNN + AHC clustering | |
| """ | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from typing import Optional, List, Union, BinaryIO | |
| from dataclasses import dataclass, field | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| from loguru import logger | |
| from models.embedder import EcapaTDNNEmbedder | |
| from models.clusterer import SpeakerClusterer | |
@dataclass
class DiarizationSegment:
    """A single contiguous speech turn attributed to one speaker.

    Fix: the ``@dataclass`` decorator was missing. The class relies on
    dataclass machinery (``field(init=False)``, a generated keyword
    ``__init__``, and ``__post_init__``), and is constructed elsewhere as
    ``DiarizationSegment(start=…, end=…, speaker=…)`` — without the
    decorator those calls raise TypeError and ``duration`` is never set.
    """

    start: float    # segment start time, seconds
    end: float      # segment end time, seconds
    speaker: str    # stable label, e.g. "SPEAKER_00"
    duration: float = field(init=False)  # derived in __post_init__, not a constructor arg

    def __post_init__(self):
        # Cache the segment length, rounded to milliseconds.
        self.duration = round(self.end - self.start, 3)

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict with times rounded to milliseconds."""
        return {
            "start": round(self.start, 3),
            "end": round(self.end, 3),
            "duration": self.duration,
            "speaker": self.speaker,
        }
@dataclass
class DiarizationResult:
    """Complete output of one diarization run.

    Fix: the ``@dataclass`` decorator was missing — the class is built with
    field keywords (``DiarizationResult(segments=…, num_speakers=…, …)``)
    throughout this file, which requires the generated ``__init__``.
    """

    # String forward reference so this class does not depend on definition
    # order relative to DiarizationSegment.
    segments: List["DiarizationSegment"]
    num_speakers: int
    audio_duration: float   # seconds of input audio
    processing_time: float  # wall-clock seconds spent diarizing
    sample_rate: int        # Hz

    def to_dict(self) -> dict:
        """Return a JSON-friendly summary; speakers are de-duplicated and sorted."""
        speakers = sorted(set(s.speaker for s in self.segments))
        return {
            "num_speakers": self.num_speakers,
            "audio_duration": round(self.audio_duration, 3),
            "processing_time": round(self.processing_time, 3),
            "sample_rate": self.sample_rate,
            "speakers": speakers,
            "segments": [s.to_dict() for s in self.segments],
        }
class DiarizationPipeline:
    """End-to-end speaker diarization with pyannote-first fallback behavior."""

    # Sample rate (Hz) assumed by the fallback path: _energy_vad frames the
    # signal with it and it is passed to the embedder for window extraction.
    SAMPLE_RATE = 16000
    # Length (s) of each sliding embedding window cut from a speech region.
    WINDOW_DURATION = 2.0
    # Hop (s) between consecutive embedding windows.
    WINDOW_STEP = 1.0
    # Regions/tail-windows shorter than this (s) are discarded.
    MIN_SEGMENT_DURATION = 0.8
| def __init__( | |
| self, | |
| device: str = "auto", | |
| use_pyannote_vad: bool = True, | |
| use_pyannote_diarization: bool = True, | |
| pyannote_diarization_model: str = "pyannote/speaker-diarization-3.1", | |
| hf_token: Optional[str] = None, | |
| num_speakers: Optional[int] = None, | |
| max_speakers: int = 6, | |
| cache_dir: str = "./model_cache", | |
| ): | |
| self.device = self._resolve_device(device) | |
| self.use_pyannote_vad = use_pyannote_vad | |
| self.use_pyannote_diarization = use_pyannote_diarization | |
| self.pyannote_diarization_model = pyannote_diarization_model | |
| self.hf_token = hf_token | |
| self.num_speakers = num_speakers | |
| self.max_speakers = max_speakers | |
| self.cache_dir = Path(cache_dir) | |
| self.embedder = EcapaTDNNEmbedder(device=self.device, cache_dir=str(cache_dir)) | |
| self.clusterer = SpeakerClusterer(max_speakers=max_speakers, distance_threshold=0.55) | |
| self._vad_pipeline = None | |
| self._full_diar_pipeline = None | |
| logger.info(f"DiarizationPipeline ready | device={self.device}") | |
| def _resolve_device(self, device: str) -> str: | |
| if device == "auto": | |
| return "cuda" if torch.cuda.is_available() else "cpu" | |
| return device | |
| def _to_mono_1d(self, audio: torch.Tensor) -> torch.Tensor: | |
| if audio.dim() == 1: | |
| return audio | |
| if audio.dim() >= 2: | |
| if audio.shape[0] == 1: | |
| return audio[0] | |
| return audio.mean(dim=0) | |
| return audio.reshape(-1) | |
    def _load_pyannote_pipeline(self, model_id: str):
        """Load a pyannote ``Pipeline`` by model id, tolerating API drift.

        pyannote versions disagree on the auth keyword name, so loading is
        attempted in order: ``use_auth_token=`` -> ``token=`` -> no auth
        keyword at all, each TypeError falling through to the next spelling.

        Raises:
            RuntimeError: if ``from_pretrained`` returns None.
        """
        from pyannote.audio import Pipeline
        try:
            if self.hf_token:
                try:
                    # Older pyannote releases accept use_auth_token=...
                    pipeline = Pipeline.from_pretrained(model_id, use_auth_token=self.hf_token)
                except TypeError:
                    # Newer releases renamed the keyword to token=...
                    pipeline = Pipeline.from_pretrained(model_id, token=self.hf_token)
            else:
                pipeline = Pipeline.from_pretrained(model_id)
        except TypeError:
            # Last resort: call without any auth keyword.
            pipeline = Pipeline.from_pretrained(model_id)
        if pipeline is None:
            # Fail loudly here rather than with an opaque AttributeError later.
            raise RuntimeError(f"Pipeline.from_pretrained returned None for {model_id}")
        try:
            # Best-effort device placement; swallow failures on purpose
            # (presumably some pipeline versions don't support .to() — the
            # pipeline still works on its default device).
            pipeline.to(torch.device(self.device))
        except Exception:
            pass
        return pipeline
| def _load_full_diarization(self): | |
| if self._full_diar_pipeline is not None: | |
| return | |
| try: | |
| logger.info(f"Loading pyannote diarization pipeline: {self.pyannote_diarization_model}") | |
| self._full_diar_pipeline = self._load_pyannote_pipeline(self.pyannote_diarization_model) | |
| logger.success("Pyannote speaker diarization pipeline loaded.") | |
| except Exception as e: | |
| logger.warning(f"Could not load pyannote diarization pipeline: {e}.") | |
| self._full_diar_pipeline = "unavailable" | |
| def _load_vad(self): | |
| if self._vad_pipeline is not None: | |
| return | |
| try: | |
| logger.info("Loading pyannote VAD pipeline...") | |
| self._vad_pipeline = self._load_pyannote_pipeline("pyannote/voice-activity-detection") | |
| logger.success("Pyannote VAD loaded.") | |
| except Exception as e: | |
| logger.warning(f"Could not load pyannote VAD: {e}. Falling back to energy-based VAD.") | |
| self._vad_pipeline = "energy" | |
| def _merge_named_segments( | |
| self, segments: List[DiarizationSegment], gap_tolerance: float = 0.35 | |
| ) -> List[DiarizationSegment]: | |
| if not segments: | |
| return [] | |
| merged = [segments[0]] | |
| for seg in segments[1:]: | |
| last = merged[-1] | |
| if seg.speaker == last.speaker and seg.start - last.end <= gap_tolerance: | |
| merged[-1] = DiarizationSegment(start=last.start, end=seg.end, speaker=last.speaker) | |
| else: | |
| merged.append(seg) | |
| return merged | |
    def _run_full_pyannote(
        self,
        audio: Union[str, Path, torch.Tensor],
        sample_rate: int,
        num_speakers: Optional[int],
        audio_duration: float,
        t_start: float,
    ) -> Optional[DiarizationResult]:
        """Run the full pyannote diarization pipeline, when enabled and loadable.

        Returns a DiarizationResult on success, or None when the pipeline is
        disabled, failed to load, produced no usable segments, or raised —
        the caller then falls back to the ECAPA+AHC path.
        """
        if not self.use_pyannote_diarization:
            return None
        self._load_full_diarization()
        if self._full_diar_pipeline == "unavailable":
            # Earlier load attempt failed; sentinel set by _load_full_diarization.
            return None
        tmp_path = None
        source = audio
        try:
            # The pipeline is called with a path string; tensor input is first
            # written out as a temporary mono WAV (delete=False so the file
            # survives the `with` and can be reopened by name).
            if not isinstance(audio, (str, Path)):
                mono = self._to_mono_1d(audio).detach().cpu().float()
                wav = mono.unsqueeze(0)  # (channels, samples) layout for torchaudio.save
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    tmp_path = tmp.name
                torchaudio.save(tmp_path, wav, sample_rate)
                source = tmp_path
            kwargs = {}
            if num_speakers is not None:
                # Only constrain the speaker count when explicitly requested.
                kwargs["num_speakers"] = int(num_speakers)
            diar_output = self._full_diar_pipeline(str(source), **kwargs)
            raw_segments = []
            # pyannote label -> stable "SPEAKER_NN" name, numbered in order
            # of first appearance.
            speaker_map = {}
            next_id = 0
            for turn, _, speaker in diar_output.itertracks(yield_label=True):
                start = float(turn.start)
                end = float(turn.end)
                if end - start < 0.2:
                    # Drop very short turns.
                    continue
                if speaker not in speaker_map:
                    speaker_map[speaker] = f"SPEAKER_{next_id:02d}"
                    next_id += 1
                raw_segments.append(
                    DiarizationSegment(start=start, end=end, speaker=speaker_map[speaker])
                )
            if not raw_segments:
                # Nothing usable; let the caller fall back.
                return None
            raw_segments.sort(key=lambda s: (s.start, s.end))
            merged_segments = self._merge_named_segments(raw_segments)
            num_unique = len(set(s.speaker for s in merged_segments))
            logger.success(
                f"Pyannote diarization complete: {num_unique} speakers, {len(merged_segments)} segments"
            )
            return DiarizationResult(
                segments=merged_segments,
                num_speakers=num_unique,
                audio_duration=audio_duration,
                processing_time=time.time() - t_start,
                sample_rate=sample_rate,
            )
        except Exception as e:
            # Any failure here is non-fatal by design: log and fall back.
            logger.warning(f"Full pyannote diarization failed: {e}. Falling back to ECAPA+AHC.")
            return None
        finally:
            # Always remove the temporary WAV, success or not.
            if tmp_path:
                Path(tmp_path).unlink(missing_ok=True)
| def _energy_vad( | |
| self, audio: torch.Tensor, frame_duration: float = 0.02, threshold_db: float = -40.0 | |
| ) -> List[tuple]: | |
| frame_samples = int(frame_duration * self.SAMPLE_RATE) | |
| audio_np = audio.numpy() | |
| frames = [ | |
| audio_np[i: i + frame_samples] | |
| for i in range(0, len(audio_np) - frame_samples, frame_samples) | |
| ] | |
| energies_db = [] | |
| for frame in frames: | |
| rms = np.sqrt(np.mean(frame ** 2) + 1e-10) | |
| energies_db.append(20 * np.log10(rms)) | |
| is_speech = np.array(energies_db) > threshold_db | |
| speech_regions = [] | |
| in_speech = False | |
| start = 0.0 | |
| for i, active in enumerate(is_speech): | |
| t = i * frame_duration | |
| if active and not in_speech: | |
| start = t | |
| in_speech = True | |
| elif not active and in_speech: | |
| speech_regions.append((start, t)) | |
| in_speech = False | |
| if in_speech: | |
| speech_regions.append((start, len(audio_np) / self.SAMPLE_RATE)) | |
| return speech_regions | |
| def _get_speech_regions(self, audio: torch.Tensor) -> List[tuple]: | |
| if self.use_pyannote_vad: | |
| self._load_vad() | |
| if self._vad_pipeline == "energy" or not self.use_pyannote_vad: | |
| return self._energy_vad(audio) | |
| try: | |
| audio_dict = { | |
| "waveform": audio.unsqueeze(0).to(self.device), | |
| "sample_rate": self.SAMPLE_RATE, | |
| } | |
| vad_output = self._vad_pipeline(audio_dict) | |
| regions = [(seg.start, seg.end) for seg in vad_output.get_timeline().support()] | |
| logger.info(f"Pyannote VAD: {len(regions)} speech regions found") | |
| return regions | |
| except Exception as e: | |
| logger.warning(f"Pyannote VAD failed: {e}. Using energy VAD.") | |
| return self._energy_vad(audio) | |
| def _sliding_window_segments(self, speech_regions: List[tuple]) -> List[tuple]: | |
| segments = [] | |
| for region_start, region_end in speech_regions: | |
| duration = region_end - region_start | |
| if duration < self.MIN_SEGMENT_DURATION: | |
| continue | |
| t = region_start | |
| while t + self.WINDOW_DURATION <= region_end: | |
| segments.append((t, t + self.WINDOW_DURATION)) | |
| t += self.WINDOW_STEP | |
| if region_end - t >= self.MIN_SEGMENT_DURATION: | |
| segments.append((t, region_end)) | |
| return segments | |
| def load_audio(self, path: Union[str, Path, BinaryIO]) -> tuple: | |
| waveform, sample_rate = torchaudio.load(path) | |
| return waveform, sample_rate | |
| def process( | |
| self, | |
| audio: Union[str, Path, torch.Tensor], | |
| sample_rate: int = None, | |
| num_speakers: Optional[int] = None, | |
| ) -> DiarizationResult: | |
| t_start = time.time() | |
| if isinstance(audio, (str, Path)): | |
| waveform, sample_rate = self.load_audio(audio) | |
| audio_tensor = self._to_mono_1d(waveform) | |
| else: | |
| assert sample_rate is not None, "sample_rate required when passing tensor" | |
| audio_tensor = self._to_mono_1d(audio) | |
| num_samples = int(audio_tensor.numel()) | |
| audio_duration = num_samples / float(sample_rate) | |
| logger.info(f"Processing {audio_duration:.1f}s audio at {sample_rate}Hz") | |
| if num_samples == 0: | |
| logger.warning("Received empty audio input.") | |
| return DiarizationResult( | |
| segments=[], | |
| num_speakers=0, | |
| audio_duration=0.0, | |
| processing_time=time.time() - t_start, | |
| sample_rate=sample_rate, | |
| ) | |
| k = num_speakers or self.num_speakers | |
| pyannote_result = self._run_full_pyannote( | |
| audio=audio, | |
| sample_rate=sample_rate, | |
| num_speakers=k, | |
| audio_duration=audio_duration, | |
| t_start=t_start, | |
| ) | |
| if pyannote_result is not None: | |
| return pyannote_result | |
| processed = self.embedder.preprocess_audio(audio_tensor, sample_rate) | |
| speech_regions = self._get_speech_regions(processed) | |
| if not speech_regions: | |
| logger.warning("No speech detected in audio.") | |
| return DiarizationResult( | |
| segments=[], | |
| num_speakers=0, | |
| audio_duration=audio_duration, | |
| processing_time=time.time() - t_start, | |
| sample_rate=sample_rate, | |
| ) | |
| windows = self._sliding_window_segments(speech_regions) | |
| logger.info(f"Generated {len(windows)} embedding windows") | |
| embeddings, valid_windows = self.embedder.extract_embeddings_from_segments( | |
| processed, self.SAMPLE_RATE, windows | |
| ) | |
| if len(embeddings) == 0: | |
| logger.warning("No valid embeddings extracted.") | |
| return DiarizationResult( | |
| segments=[], | |
| num_speakers=0, | |
| audio_duration=audio_duration, | |
| processing_time=time.time() - t_start, | |
| sample_rate=sample_rate, | |
| ) | |
| labels = self.clusterer.cluster(embeddings, num_speakers=k) | |
| merged = self.clusterer.merge_consecutive_same_speaker( | |
| valid_windows, labels, gap_tolerance=0.45 | |
| ) | |
| speaker_names = {i: f"SPEAKER_{i:02d}" for i in range(self.max_speakers)} | |
| segments = [ | |
| DiarizationSegment(start=start, end=end, speaker=speaker_names[spk_id]) | |
| for start, end, spk_id in merged | |
| ] | |
| num_unique = len(set(labels)) | |
| processing_time = time.time() - t_start | |
| logger.success( | |
| f"Fallback diarization complete: {num_unique} speakers, " | |
| f"{len(segments)} segments, {processing_time:.2f}s" | |
| ) | |
| return DiarizationResult( | |
| segments=segments, | |
| num_speakers=num_unique, | |
| audio_duration=audio_duration, | |
| processing_time=processing_time, | |
| sample_rate=sample_rate, | |
| ) | |