| """Stage 1 — Speaker Diarization Module. |
| |
| Uses pyannote-audio 4.x (speaker-diarization-3.1) to identify who spoke when. |
| """ |
|
|
| import logging |
| import os |
| from dataclasses import dataclass |
|
|
| import torch |
| import torchaudio |
| from pyannote.audio import Pipeline |
|
|
# Module logger, named after this module per the stdlib logging convention.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class DiarSegment:
    """A single speaker turn produced by diarization.

    Times are measured in seconds from the beginning of the audio.
    """

    # Speaker label; diarize() rewrites it to the form "speaker_<n>".
    speaker_id: str
    # Start of the turn, in seconds.
    start: float
    # End of the turn, in seconds.
    end: float
|
|
|
|
# Module-level cache of the loaded pipeline ...
_pipeline: Pipeline | None = None
# ... and the (model, device) arguments it was built for.
_pipeline_key: tuple[str, str] | None = None


def _load_pipeline(model: str, device: str) -> Pipeline:
    """Load the pyannote diarization pipeline, reusing a cached instance.

    The cache is keyed on ``(model, device)``: asking for a different model or
    device loads a fresh pipeline instead of silently returning the old one.

    Args:
        model: Hugging Face model id passed to ``Pipeline.from_pretrained``.
        device: Requested torch device string (e.g. ``"cuda:0"``). If it names
            CUDA but CUDA is unavailable, or it does not name CUDA at all, the
            pipeline is placed on CPU.

    Returns:
        The loaded (and cached) pyannote ``Pipeline``.

    Raises:
        RuntimeError: if ``HF_TOKEN`` is not set, or if ``from_pretrained``
            returns ``None`` instead of a pipeline.
    """
    global _pipeline, _pipeline_key
    key = (model, device)
    # A pipeline assigned to `_pipeline` from outside (no key recorded) is
    # reused unconditionally, matching the previous caching behaviour.
    if _pipeline is not None and (_pipeline_key is None or _pipeline_key == key):
        return _pipeline

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError(
            "HF_TOKEN environment variable required for pyannote model download. "
            "Set it with: export HF_TOKEN=your_huggingface_token"
        )

    logger.info("Loading pyannote pipeline: %s", model)
    pipeline = Pipeline.from_pretrained(model, token=hf_token)
    # Some pyannote versions report download/auth failures by returning None
    # rather than raising; fail loudly here instead of on `.to()` below.
    if pipeline is None:
        raise RuntimeError(
            f"Failed to load pyannote pipeline {model!r}; check that HF_TOKEN "
            "is valid and has access to the model."
        )

    wants_cuda = "cuda" in device
    use_cuda = wants_cuda and torch.cuda.is_available()
    if wants_cuda and not use_cuda:
        logger.warning("CUDA requested (%s) but not available; using CPU", device)
    torch_device = torch.device(device if use_cuda else "cpu")
    pipeline.to(torch_device)
    logger.info("Pyannote pipeline loaded on %s", torch_device)

    # Only cache once loading and device placement have both succeeded.
    _pipeline, _pipeline_key = pipeline, key
    return pipeline
|
|
|
|
def _merge_adjacent(
    segments: list[DiarSegment], gap_threshold: float
) -> list[DiarSegment]:
    """Merge consecutive same-speaker segments separated by less than a gap.

    Args:
        segments: Segments expected to be sorted by ``start``; only neighbours
            in this order are compared, so a different speaker in between
            prevents a merge.
        gap_threshold: Gap in seconds (``next.start - prev.end``) below which two
            turns of the same speaker are joined. With a non-negative threshold,
            overlapping same-speaker turns are always joined.

    Returns:
        A new list of new ``DiarSegment`` objects. Neither the input list nor
        the segments it contains are modified.
    """
    merged: list[DiarSegment] = []
    for seg in segments:
        prev = merged[-1] if merged else None
        if (
            prev is not None
            and seg.speaker_id == prev.speaker_id
            and (seg.start - prev.end) < gap_threshold
        ):
            # max() so a turn nested inside the current run never shortens it.
            prev.end = max(prev.end, seg.end)
        else:
            # Copy instead of reusing `seg`: the run's `end` may be extended
            # later, which must not leak into the caller's objects.
            merged.append(
                DiarSegment(speaker_id=seg.speaker_id, start=seg.start, end=seg.end)
            )
    return merged
|
|
|
|
def _extract_annotation(output):
    """Return the object exposing ``itertracks`` from a pyannote pipeline result."""
    # The result may be a wrapper exposing `.speaker_diarization`, or the
    # annotation itself (which exposes `itertracks`).
    if hasattr(output, "speaker_diarization"):
        return output.speaker_diarization
    if hasattr(output, "itertracks"):
        return output
    raise RuntimeError(f"Unexpected pyannote output type: {type(output)}")


def _relabel_speakers(segments: list[DiarSegment]) -> int:
    """Rename speakers in place to speaker_0, speaker_1, ... by first appearance.

    Returns:
        The number of distinct speakers found.
    """
    speaker_map: dict[str, str] = {}
    for seg in segments:
        if seg.speaker_id not in speaker_map:
            speaker_map[seg.speaker_id] = f"speaker_{len(speaker_map)}"
        seg.speaker_id = speaker_map[seg.speaker_id]
    return len(speaker_map)


def diarize(audio_path: str, config: dict) -> list[DiarSegment]:
    """Run speaker diarization on an audio file.

    Args:
        audio_path: Path to the audio file (expected: 16kHz mono WAV). It is
            decoded with ``torchaudio.load`` and passed to pyannote in memory.
        config: ``stage1.diarization`` config section. Keys read (defaults):
            ``model`` ("pyannote/speaker-diarization-3.1"), ``num_speakers`` (2;
            ``None`` is forwarded unchanged, i.e. no fixed speaker count),
            ``merge_gap_sec`` (0.15), ``device`` ("cuda:0").

    Returns:
        DiarSegment list sorted by start time, with same-speaker segments closer
        than ``merge_gap_sec`` merged and speakers relabelled ``speaker_0``,
        ``speaker_1``, ... in order of first appearance.

    Raises:
        RuntimeError: if ``HF_TOKEN`` is unset (raised while loading the
            pipeline) or the pipeline returns an unexpected output type.
    """
    model = config.get("model", "pyannote/speaker-diarization-3.1")
    num_speakers = config.get("num_speakers", 2)
    merge_gap = config.get("merge_gap_sec", 0.15)
    device = config.get("device", "cuda:0")

    pipeline = _load_pipeline(model, device)

    # %s rather than %d: num_speakers may legitimately be None, which would
    # make a %d placeholder fail inside the logging call.
    logger.info("Running diarization (num_speakers=%s) on %s", num_speakers, audio_path)

    # Decode with torchaudio and hand pyannote an in-memory waveform dict.
    waveform, sample_rate = torchaudio.load(audio_path)
    audio_dict = {"waveform": waveform, "sample_rate": sample_rate}
    output = pipeline(audio_dict, num_speakers=num_speakers)
    annotation = _extract_annotation(output)

    segments = sorted(
        (
            DiarSegment(
                speaker_id=speaker,
                start=round(turn.start, 3),
                end=round(turn.end, 3),
            )
            for turn, _, speaker in annotation.itertracks(yield_label=True)
        ),
        key=lambda s: s.start,
    )
    segments = _merge_adjacent(segments, merge_gap)
    n_speakers = _relabel_speakers(segments)

    logger.info(
        "Diarization complete: %d segments, %d speakers",
        len(segments), n_speakers,
    )
    if n_speakers < 2:
        logger.warning("Only %d speaker(s) detected", n_speakers)

    return segments