"""Stage 1 — Speaker Diarization Module. Uses pyannote-audio 4.x (speaker-diarization-3.1) to identify who spoke when. """ import logging import os from dataclasses import dataclass import torch import torchaudio from pyannote.audio import Pipeline logger = logging.getLogger(__name__) @dataclass class DiarSegment: speaker_id: str start: float end: float _pipeline: Pipeline | None = None def _load_pipeline(model: str, device: str) -> Pipeline: """Load and cache the pyannote diarization pipeline.""" global _pipeline if _pipeline is not None: return _pipeline hf_token = os.environ.get("HF_TOKEN") if not hf_token: raise RuntimeError( "HF_TOKEN environment variable required for pyannote model download. " "Set it with: export HF_TOKEN=your_huggingface_token" ) logger.info("Loading pyannote pipeline: %s", model) _pipeline = Pipeline.from_pretrained(model, token=hf_token) torch_device = torch.device(device if torch.cuda.is_available() and "cuda" in device else "cpu") _pipeline.to(torch_device) logger.info("Pyannote pipeline loaded on %s", torch_device) return _pipeline def _merge_adjacent( segments: list[DiarSegment], gap_threshold: float ) -> list[DiarSegment]: """Merge same-speaker segments closer than gap_threshold seconds.""" if not segments: return segments merged: list[DiarSegment] = [segments[0]] for seg in segments[1:]: prev = merged[-1] if seg.speaker_id == prev.speaker_id and (seg.start - prev.end) < gap_threshold: prev.end = max(prev.end, seg.end) else: merged.append(seg) return merged def diarize(audio_path: str, config: dict) -> list[DiarSegment]: """Run speaker diarization on audio. Args: audio_path: Path to 16kHz mono WAV file. config: stage1.diarization config section. Returns: List of DiarSegment sorted by start time. """ model = config.get("model", "pyannote/speaker-diarization-3.1") num_speakers = config.get("num_speakers", 2) merge_gap = config.get("merge_gap_sec", 0.15) device = config.get("device", "cuda:0") pipeline = _load_pipeline(model, device) logger.info("Running diarization (num_speakers=%d) on %s", num_speakers, audio_path) # pyannote 4.x: pass in-memory waveform to avoid torchcodec dependency waveform, sample_rate = torchaudio.load(audio_path) audio_dict = {"waveform": waveform, "sample_rate": sample_rate} output = pipeline(audio_dict, num_speakers=num_speakers) # pyannote 4.x returns DiarizeOutput with .speaker_diarization Annotation if hasattr(output, 'speaker_diarization'): annotation = output.speaker_diarization elif hasattr(output, 'itertracks'): annotation = output else: raise RuntimeError(f"Unexpected pyannote output type: {type(output)}") segments: list[DiarSegment] = [] for turn, _, speaker in annotation.itertracks(yield_label=True): segments.append(DiarSegment( speaker_id=speaker, start=round(turn.start, 3), end=round(turn.end, 3), )) segments.sort(key=lambda s: s.start) segments = _merge_adjacent(segments, merge_gap) # Normalize speaker labels to speaker_0 / speaker_1 speaker_map: dict[str, str] = {} for seg in segments: if seg.speaker_id not in speaker_map: speaker_map[seg.speaker_id] = f"speaker_{len(speaker_map)}" seg.speaker_id = speaker_map[seg.speaker_id] logger.info( "Diarization complete: %d segments, %d speakers", len(segments), len(speaker_map), ) if len(speaker_map) < 2: logger.warning("Only %d speaker(s) detected", len(speaker_map)) return segments