voice-intelligence / diarization.py
unknownfriend00007's picture
Upload 10 files
540cd4c verified
from __future__ import annotations
import threading
from collections import Counter
from typing import Any
try:
from .audio import round_sec, seconds_from_samples
from .config import VoiceRuntimeConfig
except ImportError: # HF flat-root execution fallback
from audio import round_sec, seconds_from_samples
from config import VoiceRuntimeConfig
class DiarizationRuntime:
    """Shared, lazily loaded cache for the pyannote diarization pipeline.

    One pipeline instance is kept on the class and reused until a different
    ``config.diarization_model_id`` is requested. Loading happens while holding
    a class-level lock so concurrent first calls do not load the model twice.
    """

    _lock = threading.Lock()
    _pipeline = None  # cached pyannote pipeline; None until a successful load
    _loaded_id: str | None = None  # model id the cached pipeline was loaded from

    @classmethod
    def get_pipeline(cls, config: VoiceRuntimeConfig):
        """Return the diarization pipeline for ``config``, loading it on first use.

        Args:
            config: runtime configuration; ``diarization_model_id`` and
                ``hf_token`` are read.

        Returns:
            The cached pipeline object for ``config.diarization_model_id``.

        Raises:
            RuntimeError: if no HF token is configured, ``pyannote.audio`` cannot
                be imported, or ``Pipeline.from_pretrained`` yields no pipeline.
        """
        with cls._lock:
            if cls._pipeline is not None and cls._loaded_id == config.diarization_model_id:
                return cls._pipeline
            if not config.hf_token:
                raise RuntimeError("HF_TOKEN is required for diarization model download/use.")
            try:
                from pyannote.audio import Pipeline
            except Exception as exc:
                raise RuntimeError("pyannote.audio is not installed; install it to enable diarization.") from exc
            pipeline = Pipeline.from_pretrained(config.diarization_model_id, token=config.hf_token)
            if pipeline is None:
                # Some pyannote versions report an inaccessible/gated model by
                # returning None instead of raising. Fail here with a clear message
                # rather than caching None and crashing later at inference time.
                raise RuntimeError(
                    f"Could not load diarization pipeline {config.diarization_model_id!r}; "
                    "check that HF_TOKEN has access to this model."
                )
            cls._move_to_cpu(pipeline)
            cls._pipeline = pipeline
            cls._loaded_id = config.diarization_model_id
            return cls._pipeline

    @staticmethod
    def _move_to_cpu(pipeline: Any) -> None:
        """Best-effort placement of ``pipeline`` on the CPU; warns instead of raising."""
        import warnings

        # pyannote's Pipeline.to() type-checks for a torch.device, so a bare
        # "cpu" string is rejected; fall back to the string only if torch is absent.
        try:
            import torch
        except ImportError:
            device: Any = "cpu"
        else:
            device = torch.device("cpu")
        try:
            pipeline.to(device)
        except Exception as exc:  # deliberate best-effort: keep the loaded pipeline usable
            warnings.warn(
                f"Could not move diarization pipeline to CPU: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
def _segment_to_payload(start_sec: float, end_sec: float, speaker: str, sample_rate: int) -> dict[str, Any]:
    """Build the serialisable record for a single diarization turn.

    Sample offsets come from rounding ``seconds * sample_rate``; the end offset
    is clamped so the turn always covers at least one sample. The duration is
    taken from the unrounded bounds, floored at zero, and then rounded.
    """
    first_sample = int(round(start_sec * sample_rate))
    last_sample = max(first_sample + 1, int(round(end_sec * sample_rate)))
    span_sec = end_sec - start_sec
    return dict(
        speaker=speaker,
        start_sec=round_sec(start_sec),
        end_sec=round_sec(end_sec),
        duration_sec=round_sec(max(0.0, span_sec)),
        start_sample=first_sample,
        end_sample=last_sample,
    )
def _resolve_annotation(diarization_output: Any) -> Any:
"""Return an object exposing itertracks(yield_label=True)."""
if hasattr(diarization_output, "itertracks"):
return diarization_output
# Newer pyannote pipelines may return wrappers like DiarizeOutput.
for attr in ("speaker_diarization", "annotation", "diarization"):
candidate = getattr(diarization_output, attr, None)
if candidate is not None and hasattr(candidate, "itertracks"):
return candidate
if isinstance(diarization_output, dict):
for key in ("speaker_diarization", "annotation", "diarization"):
candidate = diarization_output.get(key)
if candidate is not None and hasattr(candidate, "itertracks"):
return candidate
raise RuntimeError(
"Unsupported diarization output type "
f"{type(diarization_output).__name__}; expected Annotation-compatible object."
)
def run_diarization(wav_path: str, config: VoiceRuntimeConfig, sample_rate: int) -> list[dict[str, Any]]:
    """Diarize ``wav_path`` and return the speaker turns ordered by start time.

    When diarization is disabled in ``config`` an empty list is returned and no
    pipeline is loaded. Speaker-count bounds are forwarded to the pipeline only
    when they are positive.
    """
    if not config.diarization_enabled:
        return []

    pipeline = DiarizationRuntime.get_pipeline(config)
    speaker_bounds = (
        ("min_speakers", config.diarization_min_speakers),
        ("max_speakers", config.diarization_max_speakers),
    )
    # An empty mapping expands to no keyword arguments, i.e. pipeline(wav_path).
    call_kwargs: dict[str, Any] = {name: value for name, value in speaker_bounds if value > 0}
    annotation = _resolve_annotation(pipeline(wav_path, **call_kwargs))

    turns = [
        _segment_to_payload(
            start_sec=float(turn.start),
            end_sec=float(turn.end),
            speaker=str(label),
            sample_rate=sample_rate,
        )
        for turn, _track, label in annotation.itertracks(yield_label=True)
    ]
    return sorted(turns, key=lambda record: record["start_sec"])
def _find_speaker_for_time(timestamp_sec: float, diarization_segments: list[dict[str, Any]]) -> str | None:
for segment in diarization_segments:
if segment["start_sec"] <= timestamp_sec < segment["end_sec"]:
return segment["speaker"]
return None
def attach_speakers_to_words(words: list[dict[str, Any]], diarization_segments: list[dict[str, Any]]) -> None:
    """Tag each word dict in place with the speaker active at the word's midpoint.

    A word whose midpoint lies outside every segment is tagged ``"unknown"``.
    With no diarization segments the words are left untouched (no ``"speaker"``
    key is added).
    """
    if diarization_segments:
        for item in words:
            mid_sec = (float(item["start_sec"]) + float(item["end_sec"])) / 2.0
            label = _find_speaker_for_time(mid_sec, diarization_segments)
            item["speaker"] = label if label else "unknown"
def _coerce_index(value: Any) -> int:
    """Convert ``value`` to an int index, returning -1 when it is missing or not integral."""
    if value is None:
        return -1
    try:
        return int(value)
    except (TypeError, ValueError):
        return -1


def attach_speakers_to_segments(
    segments: list[dict[str, Any]],
    words: list[dict[str, Any]],
    sample_rate: int,
) -> None:
    """Label each transcript segment in place with the majority speaker of its words.

    Words are grouped by their ``"segment_index"`` and matched to a segment via
    the segment's ``"index"``; missing, ``None`` or negative indices are ignored.
    A word with no ``"speaker"`` (or a ``None`` speaker) counts as ``"unknown"``.
    The ``"unknown"`` placeholder only wins when a segment has no word carrying a
    real speaker label; ties between labels go to the label seen first. Segments
    with no matching words are labelled ``"unknown"``.

    Args:
        segments: transcript segment dicts; each gets a ``"speaker"`` key.
        words: word dicts, typically labelled by ``attach_speakers_to_words``.
        sample_rate: unused; kept so existing callers do not break.
    """
    del sample_rate
    unknown = "unknown"

    votes_by_segment: dict[int, Counter] = {}
    for word in words:
        seg_idx = _coerce_index(word.get("segment_index"))
        if seg_idx < 0:
            continue
        raw_label = word.get("speaker")
        label = unknown if raw_label is None else str(raw_label)
        votes_by_segment.setdefault(seg_idx, Counter())[label] += 1

    for segment in segments:
        votes = votes_by_segment.get(_coerce_index(segment.get("index")))
        if not votes:
            segment["speaker"] = unknown
            continue
        # Do not let words that fell into diarization gaps outvote a real speaker.
        known = Counter({label: n for label, n in votes.items() if label != unknown})
        # most_common is stable, so equal counts resolve to the first-seen label.
        segment["speaker"] = (known or votes).most_common(1)[0][0]
def build_diarization_summary(diarization_segments: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarise diarization turns: distinct speakers, turn count and summed duration.

    ``speakers`` is the sorted list of distinct labels. ``total_speech_sec`` is
    the rounded sum of every turn's ``duration_sec``, so stretches where turns
    overlap are counted once per overlapping turn.
    """
    speaker_names = sorted(set(seg["speaker"] for seg in diarization_segments))
    summed_duration = sum(float(seg["duration_sec"]) for seg in diarization_segments)
    summary: dict[str, Any] = {}
    summary["speaker_count"] = len(speaker_names)
    summary["speakers"] = speaker_names
    summary["segment_count"] = len(diarization_segments)
    summary["total_speech_sec"] = round_sec(summed_duration)
    return summary