# Scraped from a Hugging Face Spaces file view (Space status: Running; file size: 6,074 bytes).
from __future__ import annotations
import threading
from collections import Counter
from typing import Any
try:
from .audio import round_sec, seconds_from_samples
from .config import VoiceRuntimeConfig
except ImportError: # HF flat-root execution fallback
from audio import round_sec, seconds_from_samples
from config import VoiceRuntimeConfig
class DiarizationRuntime:
    """Class-level cache holding one loaded pyannote diarization pipeline.

    The pipeline is stored on the class together with the model id it was
    loaded from; asking for a different model id replaces the cached one.
    """

    # Held for the whole lookup/load so concurrent callers do not load twice.
    _lock = threading.Lock()
    # Most recently loaded pipeline; None until a load has succeeded.
    _pipeline = None
    # Model id that `_pipeline` was loaded from.
    _loaded_id: str | None = None

    @classmethod
    def get_pipeline(cls, config: VoiceRuntimeConfig):
        """Return the diarization pipeline for ``config.diarization_model_id``.

        The cached pipeline is reused when its model id matches; otherwise a
        new one is loaded with ``config.hf_token`` and cached.

        Raises:
            RuntimeError: if ``config.hf_token`` is empty, if ``pyannote.audio``
                cannot be imported, or if the loader returns no pipeline.
        """
        with cls._lock:
            if cls._pipeline is not None and cls._loaded_id == config.diarization_model_id:
                return cls._pipeline
            if not config.hf_token:
                raise RuntimeError("HF_TOKEN is required for diarization model download/use.")
            try:
                from pyannote.audio import Pipeline
            except Exception as exc:
                raise RuntimeError("pyannote.audio is not installed; install it to enable diarization.") from exc
            pipeline = Pipeline.from_pretrained(config.diarization_model_id, token=config.hf_token)
            if pipeline is None:
                # from_pretrained can hand back None instead of raising (for
                # example when the model cannot be downloaded). Fail here
                # rather than caching None and crashing at call time.
                raise RuntimeError(
                    f"Could not load diarization pipeline '{config.diarization_model_id}'; "
                    "check the model id and that HF_TOKEN has access to it."
                )
            try:
                # Pipeline.to expects a torch.device rather than a plain
                # string; torch is imported lazily because it is only needed
                # once pyannote itself is available.
                import torch

                pipeline.to(torch.device("cpu"))
            except Exception as exc:
                # Best effort: keep using the pipeline on its default device,
                # but leave a trace instead of hiding the failure.
                import logging

                logging.getLogger(__name__).warning(
                    "Could not move diarization pipeline to CPU: %s", exc
                )
            cls._pipeline = pipeline
            cls._loaded_id = config.diarization_model_id
            return cls._pipeline
def _segment_to_payload(start_sec: float, end_sec: float, speaker: str, sample_rate: int) -> dict[str, Any]:
    """Build the dictionary describing one speaker turn.

    Times are reported both in seconds (passed through ``round_sec``) and as
    sample offsets at ``sample_rate``. The sample range always spans at least
    one sample, and the reported duration is never negative.
    """
    first_sample = int(round(start_sec * sample_rate))
    last_sample = int(round(end_sec * sample_rate))
    if last_sample <= first_sample:
        # Keep the sample range non-empty even for zero-length turns.
        last_sample = first_sample + 1

    span_sec = end_sec - start_sec
    payload: dict[str, Any] = {"speaker": speaker}
    payload["start_sec"] = round_sec(start_sec)
    payload["end_sec"] = round_sec(end_sec)
    payload["duration_sec"] = round_sec(span_sec if span_sec > 0.0 else 0.0)
    payload["start_sample"] = first_sample
    payload["end_sample"] = last_sample
    return payload
def _resolve_annotation(diarization_output: Any) -> Any:
    """Return an object exposing itertracks(yield_label=True).

    The output itself is returned when it has ``itertracks``. Otherwise the
    names ``speaker_diarization``, ``annotation`` and ``diarization`` are
    tried in that order, first as attributes and then (for dicts) as keys.

    Raises:
        RuntimeError: if no candidate exposes ``itertracks``.
    """
    if hasattr(diarization_output, "itertracks"):
        return diarization_output

    wrapper_names = ("speaker_diarization", "annotation", "diarization")

    # Newer pyannote pipelines may return wrappers like DiarizeOutput.
    lookups = [lambda name: getattr(diarization_output, name, None)]
    if isinstance(diarization_output, dict):
        lookups.append(diarization_output.get)

    for lookup in lookups:
        for name in wrapper_names:
            found = lookup(name)
            if found is None:
                continue
            if hasattr(found, "itertracks"):
                return found

    type_name = type(diarization_output).__name__
    raise RuntimeError(
        "Unsupported diarization output type "
        f"{type_name}; expected Annotation-compatible object."
    )
def run_diarization(wav_path: str, config: VoiceRuntimeConfig, sample_rate: int) -> list[dict[str, Any]]:
    """Run the diarization pipeline on ``wav_path`` and return speaker turns.

    Returns an empty list when ``config.diarization_enabled`` is falsy.
    Otherwise the pipeline is called with ``min_speakers`` / ``max_speakers``
    for whichever of the matching config values is greater than zero, and
    each track is converted with ``_segment_to_payload``. The result is
    ordered by ``start_sec``.
    """
    if not config.diarization_enabled:
        return []

    pipeline = DiarizationRuntime.get_pipeline(config)

    # Only forward speaker-count bounds that were actually configured (> 0).
    speaker_bounds: dict[str, Any] = {}
    lower = config.diarization_min_speakers
    if lower > 0:
        speaker_bounds["min_speakers"] = lower
    upper = config.diarization_max_speakers
    if upper > 0:
        speaker_bounds["max_speakers"] = upper

    raw_output = pipeline(wav_path, **speaker_bounds)
    tracks = _resolve_annotation(raw_output).itertracks(yield_label=True)

    payloads = [
        _segment_to_payload(
            start_sec=float(turn.start),
            end_sec=float(turn.end),
            speaker=str(label),
            sample_rate=sample_rate,
        )
        for turn, _, label in tracks
    ]
    return sorted(payloads, key=lambda payload: payload["start_sec"])
def _find_speaker_for_time(timestamp_sec: float, diarization_segments: list[dict[str, Any]]) -> str | None:
    """Return the speaker of the first segment containing ``timestamp_sec``.

    A segment contains the timestamp when ``start_sec <= timestamp_sec <
    end_sec`` (the end is exclusive). Returns ``None`` when nothing matches.
    """
    matching_speakers = (
        candidate["speaker"]
        for candidate in diarization_segments
        if candidate["start_sec"] <= timestamp_sec < candidate["end_sec"]
    )
    return next(matching_speakers, None)
def attach_speakers_to_words(words: list[dict[str, Any]], diarization_segments: list[dict[str, Any]]) -> None:
    """Set ``word["speaker"]`` in place from the word's temporal midpoint.

    Each word is assigned the speaker found by ``_find_speaker_for_time`` at
    the midpoint of its ``start_sec`` / ``end_sec``; a missing or empty result
    becomes ``"unknown"``. Words are left untouched when there are no
    diarization segments.
    """
    if diarization_segments:
        for entry in words:
            begin = float(entry["start_sec"])
            finish = float(entry["end_sec"])
            centre = (begin + finish) / 2.0
            label = _find_speaker_for_time(centre, diarization_segments)
            entry["speaker"] = label if label else "unknown"
def attach_speakers_to_segments(
    segments: list[dict[str, Any]],
    words: list[dict[str, Any]],
    sample_rate: int,
) -> None:
    """Set ``segment["speaker"]`` in place to its words' most frequent speaker.

    Words are matched to segments through ``word["segment_index"]`` and
    ``segment["index"]``; words with a missing or negative index are ignored.
    A word without a ``speaker`` counts as ``"unknown"``, and a segment with
    no matching words is labelled ``"unknown"``. ``sample_rate`` is unused.
    """
    del sample_rate

    if not words:
        for segment in segments:
            segment["speaker"] = "unknown"
        return

    # Per-segment tally of speaker labels, in first-seen order.
    votes: dict[int, Counter] = {}
    for entry in words:
        position = int(entry.get("segment_index", -1))
        if position >= 0:
            label = str(entry.get("speaker", "unknown"))
            votes.setdefault(position, Counter())[label] += 1

    for segment in segments:
        tally = votes.get(int(segment.get("index", -1)))
        if tally:
            (winner, _count), = tally.most_common(1)
            segment["speaker"] = winner
        else:
            segment["speaker"] = "unknown"
def build_diarization_summary(diarization_segments: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarize diarization segments.

    Returns a dict with the sorted distinct speaker labels, how many there
    are, the number of segments, and the sum of the segments'
    ``duration_sec`` values passed through ``round_sec``.
    """
    distinct_speakers = sorted(set(segment["speaker"] for segment in diarization_segments))
    durations = [float(segment["duration_sec"]) for segment in diarization_segments]

    summary: dict[str, Any] = {}
    summary["speaker_count"] = len(distinct_speakers)
    summary["speakers"] = distinct_speakers
    summary["segment_count"] = len(diarization_segments)
    summary["total_speech_sec"] = round_sec(sum(durations))
    return summary
|