# transcribe-diarize/src/diarization_service.py
# Commit 68728b4 (Ratnesh-dev): "Fix Parakeet Call Error And Remove Unused API Parameters"
import tempfile
from typing import Any
import gradio as gr
from src.constants import PYANNOTE_COMMUNITY_1
from src.utils import build_audio_chunk_plan, extract_audio_clip, get_audio_duration_seconds, parse_model_options
def parse_and_validate_diarization_request(
    audio_file: str,
    model_options_json: str,
) -> dict:
    """Validate that an audio file was supplied, then parse the options JSON.

    Raises gr.Error when no audio file is present; otherwise delegates to
    parse_model_options and returns the parsed options dict.
    """
    if audio_file is not None:
        return parse_model_options(model_options_json)
    raise gr.Error("No audio file submitted. Upload an audio file first.")
def sanitize_model_options_for_response(model_options: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *model_options* with credential values masked.

    Keys that carry Hugging Face tokens are replaced with a redaction marker
    so the response payload never echoes secrets; the input dict is untouched.
    """
    secret_keys = {"hf_token", "huggingface_token"}
    return {
        key: "***redacted***" if key in secret_keys else value
        for key, value in model_options.items()
    }
def normalize_segments_with_offset(segments: list[dict[str, Any]], offset: float) -> list[dict[str, Any]]:
    """Shift every segment's start/end by *offset* seconds.

    Each returned dict is a copy of the input segment with "start" and "end"
    coerced to float, offset, and rounded to 4 decimal places. Missing
    timestamps default to 0.0. The input segments are not modified.
    """
    def _shifted(seg: dict[str, Any]) -> dict[str, Any]:
        return {
            **seg,
            "start": round(float(seg.get("start", 0.0)) + offset, 4),
            "end": round(float(seg.get("end", 0.0)) + offset, 4),
        }

    return [_shifted(seg) for seg in segments]
def merge_adjacent_segments(
    segments: list[dict[str, Any]],
    max_gap_s: float = 0.2,
) -> list[dict[str, Any]]:
    """Merge same-speaker segments separated by at most *max_gap_s* seconds.

    Segments are ordered by (start, end); consecutive segments with the same
    speaker label whose gap is <= max_gap_s are collapsed into one segment
    whose end is the max of the two ends (rounded to 4 decimals).

    Bug fix: the previous implementation mutated the caller's segment dicts
    in place (the merged list held references into the input). Segments are
    now copied before merging, so inputs are left untouched.
    """
    if not segments:
        return []
    ordered = sorted(segments, key=lambda x: (float(x["start"]), float(x["end"])))
    # Copy the first segment so in-place "end" updates never touch the input.
    merged = [dict(ordered[0])]
    for seg in ordered[1:]:
        prev = merged[-1]
        same_speaker = str(prev.get("speaker")) == str(seg.get("speaker"))
        gap = float(seg["start"]) - float(prev["end"])
        if same_speaker and gap <= max_gap_s:
            # Extend the previous (copied) segment; max() guards against
            # overlapping segments where seg ends before prev does.
            prev["end"] = round(max(float(prev["end"]), float(seg["end"])), 4)
        else:
            merged.append(dict(seg))
    return merged
def build_diarization_response(
    audio_file: str,
    model_options: dict,
    duration_seconds: float | None,
    chunk_plan: list[dict[str, Any]],
    chunk_results: list[dict[str, Any]],
    stitched_segments: dict[str, list[dict[str, Any]]],
    zerogpu_timing: dict,
) -> dict:
    """Assemble the final diarization API payload.

    Redacts secrets from the echoed model options, summarizes the chunking
    plan, and nests both per-chunk results and the stitched segments under
    "raw_output".
    """
    chunk_count = len(chunk_plan)
    long_audio_summary = {
        "chunk_count": chunk_count,
        "chunk_plan": chunk_plan,
        "chunked": chunk_count > 1,
    }
    raw_payload = {
        "chunk_results": chunk_results,
        "stitched": stitched_segments,
    }
    return {
        "model": PYANNOTE_COMMUNITY_1,
        "audio_file": str(audio_file),
        "duration_seconds": duration_seconds,
        "model_options": sanitize_model_options_for_response(model_options),
        "long_audio": long_audio_summary,
        "zerogpu_timing": zerogpu_timing,
        "raw_output": raw_payload,
    }
def run_chunked_diarization(
    audio_file: str,
    model_options: dict,
    gpu_chunk_runner,
) -> dict:
    """Diarize *audio_file*, splitting it into chunks when it exceeds a threshold.

    Args:
        audio_file: Path to the input audio file.
        model_options: User-supplied options; chunking knobs read here are
            long_audio_chunk_threshold_s, chunk_duration_s, chunk_overlap_s,
            and merge_gap_s (defaults: 7200 s, 7200 s, 0 s, 0.2 s).
        gpu_chunk_runner: Callable invoked per chunk with keyword arguments
            ``audio_file`` and ``model_options``; expected to return a dict
            with "zerogpu_timing" and "raw_output" keys (performs the actual
            GPU inference — contract defined by the caller).

    Returns:
        The full response dict produced by build_diarization_response,
        including stitched (offset-shifted and merged) segments and summed
        GPU timing across chunks.

    Raises:
        gr.Error: On invalid chunking parameters, or when a multi-chunk plan
            has a chunk with unknown duration (e.g. ffprobe unavailable).
    """
    duration_seconds = get_audio_duration_seconds(audio_file)
    # Default to single-pass diarization up to 2 hours on ZeroGPU.
    chunk_threshold_s = float(model_options.get("long_audio_chunk_threshold_s", 7200))
    chunk_duration_s = float(model_options.get("chunk_duration_s", 7200))
    chunk_overlap_s = float(model_options.get("chunk_overlap_s", 0))
    merge_gap_s = float(model_options.get("merge_gap_s", 0.2))
    # Validate chunking parameters up front so failures surface before any GPU work.
    if chunk_threshold_s <= 0:
        raise gr.Error("long_audio_chunk_threshold_s must be > 0")
    if chunk_duration_s <= 0:
        raise gr.Error("chunk_duration_s must be > 0")
    if chunk_overlap_s < 0:
        raise gr.Error("chunk_overlap_s must be >= 0")
    if chunk_overlap_s >= chunk_duration_s:
        raise gr.Error("chunk_overlap_s must be smaller than chunk_duration_s")
    # Short audio (or unknown duration) gets a single full-file "chunk";
    # duration may be None here when ffprobe could not determine it.
    if duration_seconds is None or duration_seconds <= chunk_threshold_s:
        chunk_plan = [{"index": 0, "start": 0.0, "end": duration_seconds, "duration": duration_seconds}]
    else:
        chunk_plan = build_audio_chunk_plan(
            audio_file=audio_file,
            chunk_duration_s=chunk_duration_s,
            chunk_overlap_s=chunk_overlap_s,
        )
    chunk_results: list[dict[str, Any]] = []
    stitched_standard: list[dict[str, Any]] = []
    stitched_exclusive: list[dict[str, Any]] = []
    total_gpu_window_seconds = 0.0
    total_inference_seconds = 0.0
    # Temp dir holds extracted chunk clips; cleaned up automatically.
    with tempfile.TemporaryDirectory() as tmpdir:
        for chunk in chunk_plan:
            start = float(chunk["start"])
            duration_raw = chunk.get("duration")
            duration = None if duration_raw is None else float(duration_raw)
            # Skip degenerate (zero/negative length) chunks.
            if duration is not None and duration <= 0:
                continue
            if len(chunk_plan) == 1:
                # Single-pass: run on the original file, no extraction needed.
                chunk_audio_file = audio_file
            else:
                if duration is None:
                    raise gr.Error(
                        "Chunk duration is unknown. Ensure ffprobe is available to process long audio chunking."
                    )
                chunk_audio_file = extract_audio_clip(
                    source_audio_file=audio_file,
                    start_seconds=start,
                    duration_seconds=duration,
                    tmpdir=tmpdir,
                )
            result = gpu_chunk_runner(
                audio_file=chunk_audio_file,
                model_options=model_options,
            )
            # Accumulate GPU timing across chunks for the response summary.
            total_gpu_window_seconds += float(result["zerogpu_timing"].get("gpu_window_seconds", 0.0))
            total_inference_seconds += float(result["zerogpu_timing"].get("inference_seconds", 0.0))
            raw = result["raw_output"]
            standard_segments = raw.get("speaker_diarization", {}).get("segments", [])
            exclusive_segments = raw.get("exclusive_speaker_diarization", {}).get("segments", [])
            # Shift chunk-local timestamps back into whole-file time.
            standard_shifted = normalize_segments_with_offset(standard_segments, offset=start)
            exclusive_shifted = normalize_segments_with_offset(exclusive_segments, offset=start)
            chunk_results.append(
                {
                    "chunk": chunk,
                    "zerogpu_timing": result["zerogpu_timing"],
                    "raw_output": raw,
                    "shifted_segments": {
                        "speaker_diarization": standard_shifted,
                        "exclusive_speaker_diarization": exclusive_shifted,
                    },
                }
            )
            stitched_standard.extend(standard_shifted)
            stitched_exclusive.extend(exclusive_shifted)
    # Collapse near-adjacent same-speaker segments across chunk boundaries.
    stitched = {
        "speaker_diarization": merge_adjacent_segments(stitched_standard, max_gap_s=merge_gap_s),
        "exclusive_speaker_diarization": merge_adjacent_segments(stitched_exclusive, max_gap_s=merge_gap_s),
    }
    return build_diarization_response(
        audio_file=audio_file,
        model_options=model_options,
        duration_seconds=duration_seconds,
        chunk_plan=chunk_plan,
        chunk_results=chunk_results,
        stitched_segments=stitched,
        zerogpu_timing={
            "gpu_window_seconds": round(total_gpu_window_seconds, 4),
            "inference_seconds": round(total_inference_seconds, 4),
        },
    )