Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import threading | |
| from collections import Counter | |
| from typing import Any | |
| try: | |
| from .audio import round_sec, seconds_from_samples | |
| from .config import VoiceRuntimeConfig | |
| except ImportError: # HF flat-root execution fallback | |
| from audio import round_sec, seconds_from_samples | |
| from config import VoiceRuntimeConfig | |
class DiarizationRuntime:
    """Process-wide lazy holder for a single pyannote diarization pipeline.

    The pipeline is loaded once per model id and cached at class level; a
    lock serializes concurrent first use so the model is only loaded once.
    """

    _lock = threading.Lock()
    _pipeline = None  # cached pyannote Pipeline instance (None until first load)
    _loaded_id: str | None = None  # model id the cached pipeline was built from

    @classmethod  # fix: callers invoke DiarizationRuntime.get_pipeline(config) on the class
    def get_pipeline(cls, config: VoiceRuntimeConfig):
        """Return the cached diarization pipeline, loading it on first use.

        A cached pipeline is reused only when it was loaded from the same
        ``config.diarization_model_id``; changing the model id reloads.

        Raises:
            RuntimeError: if no HF token is configured, or pyannote.audio
                cannot be imported.
        """
        with cls._lock:
            if cls._pipeline is not None and cls._loaded_id == config.diarization_model_id:
                return cls._pipeline
            if not config.hf_token:
                raise RuntimeError("HF_TOKEN is required for diarization model download/use.")
            try:
                from pyannote.audio import Pipeline
            except Exception as exc:
                raise RuntimeError("pyannote.audio is not installed; install it to enable diarization.") from exc
            pipeline = Pipeline.from_pretrained(config.diarization_model_id, token=config.hf_token)
            try:
                # Best effort: pin inference to CPU; older/newer pipeline
                # versions may not support .to(), so failures are ignored.
                pipeline.to("cpu")
            except Exception:
                pass
            cls._pipeline = pipeline
            cls._loaded_id = config.diarization_model_id
            return cls._pipeline
def _segment_to_payload(start_sec: float, end_sec: float, speaker: str, sample_rate: int) -> dict[str, Any]:
    """Build the JSON-ready dict describing one diarization turn.

    Sample indices are derived from the second offsets; the end sample is
    forced past the start sample so a turn never collapses to zero width.
    """
    first_sample = int(round(start_sec * sample_rate))
    last_sample = max(first_sample + 1, int(round(end_sec * sample_rate)))
    return {
        "speaker": speaker,
        "start_sec": round_sec(start_sec),
        "end_sec": round_sec(end_sec),
        "duration_sec": round_sec(max(0.0, end_sec - start_sec)),
        "start_sample": first_sample,
        "end_sample": last_sample,
    }
| def _resolve_annotation(diarization_output: Any) -> Any: | |
| """Return an object exposing itertracks(yield_label=True).""" | |
| if hasattr(diarization_output, "itertracks"): | |
| return diarization_output | |
| # Newer pyannote pipelines may return wrappers like DiarizeOutput. | |
| for attr in ("speaker_diarization", "annotation", "diarization"): | |
| candidate = getattr(diarization_output, attr, None) | |
| if candidate is not None and hasattr(candidate, "itertracks"): | |
| return candidate | |
| if isinstance(diarization_output, dict): | |
| for key in ("speaker_diarization", "annotation", "diarization"): | |
| candidate = diarization_output.get(key) | |
| if candidate is not None and hasattr(candidate, "itertracks"): | |
| return candidate | |
| raise RuntimeError( | |
| "Unsupported diarization output type " | |
| f"{type(diarization_output).__name__}; expected Annotation-compatible object." | |
| ) | |
def run_diarization(wav_path: str, config: VoiceRuntimeConfig, sample_rate: int) -> list[dict[str, Any]]:
    """Diarize the audio at wav_path and return speaker turns sorted by start time.

    Returns an empty list when diarization is disabled in config.
    """
    if not config.diarization_enabled:
        return []
    pipeline = DiarizationRuntime.get_pipeline(config)
    # Only forward speaker-count bounds that are actually configured (> 0).
    bounds: dict[str, Any] = {}
    if config.diarization_min_speakers > 0:
        bounds["min_speakers"] = config.diarization_min_speakers
    if config.diarization_max_speakers > 0:
        bounds["max_speakers"] = config.diarization_max_speakers
    raw_output = pipeline(wav_path, **bounds) if bounds else pipeline(wav_path)
    annotation = _resolve_annotation(raw_output)
    turns = [
        _segment_to_payload(
            start_sec=float(turn.start),
            end_sec=float(turn.end),
            speaker=str(label),
            sample_rate=sample_rate,
        )
        for turn, _, label in annotation.itertracks(yield_label=True)
    ]
    turns.sort(key=lambda payload: payload["start_sec"])
    return turns
| def _find_speaker_for_time(timestamp_sec: float, diarization_segments: list[dict[str, Any]]) -> str | None: | |
| for segment in diarization_segments: | |
| if segment["start_sec"] <= timestamp_sec < segment["end_sec"]: | |
| return segment["speaker"] | |
| return None | |
def attach_speakers_to_words(words: list[dict[str, Any]], diarization_segments: list[dict[str, Any]]) -> None:
    """Label each word dict in place with the speaker active at its midpoint.

    Words whose midpoint falls outside every diarization turn get "unknown".
    No-op when there are no diarization segments.
    """
    if not diarization_segments:
        return
    for entry in words:
        center = (float(entry["start_sec"]) + float(entry["end_sec"])) / 2.0
        entry["speaker"] = _find_speaker_for_time(center, diarization_segments) or "unknown"
def attach_speakers_to_segments(
    segments: list[dict[str, Any]],
    words: list[dict[str, Any]],
    sample_rate: int,
) -> None:
    """Assign each segment the majority speaker among its words, in place.

    Segments with no labelled words — or every segment when *words* is
    empty — are labelled "unknown". ``sample_rate`` is accepted for
    interface compatibility and unused.
    """
    del sample_rate
    if not words:
        for seg in segments:
            seg["speaker"] = "unknown"
        return
    # Group word speaker labels by the segment index each word points at;
    # words without a valid segment_index are skipped.
    labels_by_index: dict[int, list[str]] = {}
    for entry in words:
        owner = int(entry.get("segment_index", -1))
        if owner >= 0:
            labels_by_index.setdefault(owner, []).append(str(entry.get("speaker", "unknown")))
    for seg in segments:
        votes = labels_by_index.get(int(seg.get("index", -1)), [])
        seg["speaker"] = Counter(votes).most_common(1)[0][0] if votes else "unknown"
def build_diarization_summary(diarization_segments: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarize diarization turns: distinct speakers, counts, total speech time."""
    distinct_speakers = sorted({seg["speaker"] for seg in diarization_segments})
    speech_total = round_sec(sum(float(seg["duration_sec"]) for seg in diarization_segments))
    return {
        "speaker_count": len(distinct_speakers),
        "speakers": distinct_speakers,
        "segment_count": len(diarization_segments),
        "total_speech_sec": speech_total,
    }