from __future__ import annotations import threading from collections import Counter from typing import Any try: from .audio import round_sec, seconds_from_samples from .config import VoiceRuntimeConfig except ImportError: # HF flat-root execution fallback from audio import round_sec, seconds_from_samples from config import VoiceRuntimeConfig class DiarizationRuntime: _lock = threading.Lock() _pipeline = None _loaded_id: str | None = None @classmethod def get_pipeline(cls, config: VoiceRuntimeConfig): with cls._lock: if cls._pipeline is not None and cls._loaded_id == config.diarization_model_id: return cls._pipeline if not config.hf_token: raise RuntimeError("HF_TOKEN is required for diarization model download/use.") try: from pyannote.audio import Pipeline except Exception as exc: raise RuntimeError("pyannote.audio is not installed; install it to enable diarization.") from exc pipeline = Pipeline.from_pretrained(config.diarization_model_id, token=config.hf_token) try: pipeline.to("cpu") except Exception: pass cls._pipeline = pipeline cls._loaded_id = config.diarization_model_id return cls._pipeline def _segment_to_payload(start_sec: float, end_sec: float, speaker: str, sample_rate: int) -> dict[str, Any]: start_sample = int(round(start_sec * sample_rate)) end_sample = int(round(end_sec * sample_rate)) end_sample = max(start_sample + 1, end_sample) return { "speaker": speaker, "start_sec": round_sec(start_sec), "end_sec": round_sec(end_sec), "duration_sec": round_sec(max(0.0, end_sec - start_sec)), "start_sample": start_sample, "end_sample": end_sample, } def _resolve_annotation(diarization_output: Any) -> Any: """Return an object exposing itertracks(yield_label=True).""" if hasattr(diarization_output, "itertracks"): return diarization_output # Newer pyannote pipelines may return wrappers like DiarizeOutput. for attr in ("speaker_diarization", "annotation", "diarization"): candidate = getattr(diarization_output, attr, None) if candidate is not None and hasattr(candidate, "itertracks"): return candidate if isinstance(diarization_output, dict): for key in ("speaker_diarization", "annotation", "diarization"): candidate = diarization_output.get(key) if candidate is not None and hasattr(candidate, "itertracks"): return candidate raise RuntimeError( "Unsupported diarization output type " f"{type(diarization_output).__name__}; expected Annotation-compatible object." ) def run_diarization(wav_path: str, config: VoiceRuntimeConfig, sample_rate: int) -> list[dict[str, Any]]: if not config.diarization_enabled: return [] pipeline = DiarizationRuntime.get_pipeline(config) kwargs: dict[str, Any] = {} if config.diarization_min_speakers > 0: kwargs["min_speakers"] = config.diarization_min_speakers if config.diarization_max_speakers > 0: kwargs["max_speakers"] = config.diarization_max_speakers diarization_output = pipeline(wav_path, **kwargs) if kwargs else pipeline(wav_path) annotation = _resolve_annotation(diarization_output) diarization_segments: list[dict[str, Any]] = [] for turn, _, speaker in annotation.itertracks(yield_label=True): diarization_segments.append( _segment_to_payload( start_sec=float(turn.start), end_sec=float(turn.end), speaker=str(speaker), sample_rate=sample_rate, ) ) diarization_segments.sort(key=lambda item: item["start_sec"]) return diarization_segments def _find_speaker_for_time(timestamp_sec: float, diarization_segments: list[dict[str, Any]]) -> str | None: for segment in diarization_segments: if segment["start_sec"] <= timestamp_sec < segment["end_sec"]: return segment["speaker"] return None def attach_speakers_to_words(words: list[dict[str, Any]], diarization_segments: list[dict[str, Any]]) -> None: if not diarization_segments: return for word in words: midpoint = (float(word["start_sec"]) + float(word["end_sec"])) / 2.0 speaker = _find_speaker_for_time(midpoint, diarization_segments) word["speaker"] = speaker or "unknown" def attach_speakers_to_segments( segments: list[dict[str, Any]], words: list[dict[str, Any]], sample_rate: int, ) -> None: del sample_rate if not words: for segment in segments: segment["speaker"] = "unknown" return words_by_segment: dict[int, list[str]] = {} for word in words: seg_idx = int(word.get("segment_index", -1)) if seg_idx < 0: continue words_by_segment.setdefault(seg_idx, []).append(str(word.get("speaker", "unknown"))) for segment in segments: idx = int(segment.get("index", -1)) labels = words_by_segment.get(idx, []) if not labels: segment["speaker"] = "unknown" continue top = Counter(labels).most_common(1)[0][0] segment["speaker"] = top def build_diarization_summary(diarization_segments: list[dict[str, Any]]) -> dict[str, Any]: speakers = sorted({item["speaker"] for item in diarization_segments}) total_speech_sec = round_sec(sum(float(item["duration_sec"]) for item in diarization_segments)) return { "speaker_count": len(speakers), "speakers": speakers, "segment_count": len(diarization_segments), "total_speech_sec": total_speech_sec, }