"""Chunked speaker-diarization orchestration for (optionally long) audio files.

Splits audio that exceeds a configurable threshold into overlapping chunks,
runs a GPU diarization callable per chunk, shifts the resulting segments back
into absolute time, and stitches them into a single response payload.
"""

import tempfile
from typing import Any

import gradio as gr

from src.constants import PYANNOTE_COMMUNITY_1
from src.utils import (
    build_audio_chunk_plan,
    extract_audio_clip,
    get_audio_duration_seconds,
    parse_model_options,
)


def parse_and_validate_diarization_request(
    audio_file: str,
    model_options_json: str,
) -> dict:
    """Validate the request and parse the JSON model options.

    Args:
        audio_file: Path to the uploaded audio file (``None`` when nothing
            was uploaded).
        model_options_json: JSON string of model options.

    Returns:
        The parsed model-options dict.

    Raises:
        gr.Error: If no audio file was submitted.
    """
    if audio_file is None:
        raise gr.Error("No audio file submitted. Upload an audio file first.")
    return parse_model_options(model_options_json)


def sanitize_model_options_for_response(model_options: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *model_options* with credential fields redacted.

    Only the top-level keys ``hf_token`` and ``huggingface_token`` are masked;
    nested values are not inspected.
    """
    redacted = dict(model_options)
    for key in ("hf_token", "huggingface_token"):
        if key in redacted:
            redacted[key] = "***redacted***"
    return redacted


def normalize_segments_with_offset(
    segments: list[dict[str, Any]],
    offset: float,
) -> list[dict[str, Any]]:
    """Shift each segment's ``start``/``end`` by *offset* seconds.

    Missing ``start``/``end`` keys default to ``0.0``; shifted values are
    rounded to 4 decimal places. New dicts are returned; the inputs are not
    modified.
    """
    normalized = []
    for seg in segments:
        start = float(seg.get("start", 0.0)) + offset
        end = float(seg.get("end", 0.0)) + offset
        normalized.append(
            {
                **seg,
                "start": round(start, 4),
                "end": round(end, 4),
            }
        )
    return normalized


def merge_adjacent_segments(
    segments: list[dict[str, Any]],
    max_gap_s: float = 0.2,
) -> list[dict[str, Any]]:
    """Merge consecutive same-speaker segments separated by <= *max_gap_s* s.

    Segments are sorted by ``(start, end)`` first. When a segment has the
    same ``speaker`` as the previous merged segment and the gap between them
    is at most *max_gap_s*, the previous segment's ``end`` is extended.

    Bug fix vs. the original: the merge used to write ``prev["end"]`` into
    the caller's own dict objects. Those same dicts are aliased from each
    chunk's ``shifted_segments`` entry in the response, so stitching silently
    corrupted the per-chunk raw results. Segments are now copied before any
    mutation, leaving the input list and its dicts untouched.

    Args:
        segments: Segment dicts with at least ``start``, ``end``, ``speaker``.
        max_gap_s: Maximum silence (seconds) bridged between segments.

    Returns:
        A new list of new segment dicts, merged and ordered by start time.
    """
    if not segments:
        return []
    ordered = sorted(segments, key=lambda x: (float(x["start"]), float(x["end"])))
    # Copy every segment placed into the output so merging never mutates
    # the caller's dicts (they may be aliased elsewhere in the response).
    merged = [dict(ordered[0])]
    for seg in ordered[1:]:
        prev = merged[-1]
        same_speaker = str(prev.get("speaker")) == str(seg.get("speaker"))
        gap = float(seg["start"]) - float(prev["end"])
        if same_speaker and gap <= max_gap_s:
            prev["end"] = round(max(float(prev["end"]), float(seg["end"])), 4)
        else:
            merged.append(dict(seg))
    return merged


def build_diarization_response(
    audio_file: str,
    model_options: dict,
    duration_seconds: float | None,
    chunk_plan: list[dict[str, Any]],
    chunk_results: list[dict[str, Any]],
    stitched_segments: dict[str, list[dict[str, Any]]],
    zerogpu_timing: dict,
) -> dict:
    """Assemble the final response payload for a diarization run.

    Model options are redacted before inclusion so credentials never appear
    in the response.
    """
    return {
        "model": PYANNOTE_COMMUNITY_1,
        "audio_file": str(audio_file),
        "duration_seconds": duration_seconds,
        "model_options": sanitize_model_options_for_response(model_options),
        "long_audio": {
            "chunk_count": len(chunk_plan),
            "chunk_plan": chunk_plan,
            "chunked": len(chunk_plan) > 1,
        },
        "zerogpu_timing": zerogpu_timing,
        "raw_output": {
            "chunk_results": chunk_results,
            "stitched": stitched_segments,
        },
    }


def run_chunked_diarization(
    audio_file: str,
    model_options: dict,
    gpu_chunk_runner,
) -> dict:
    """Run diarization over *audio_file*, chunking it when it is long.

    Args:
        audio_file: Path to the source audio file.
        model_options: Options dict; recognized keys include
            ``long_audio_chunk_threshold_s``, ``chunk_duration_s``,
            ``chunk_overlap_s``, and ``merge_gap_s``.
        gpu_chunk_runner: Callable invoked per chunk with
            ``audio_file=...``/``model_options=...``; must return a dict with
            ``zerogpu_timing`` and ``raw_output`` keys.

    Returns:
        The response dict produced by :func:`build_diarization_response`.

    Raises:
        gr.Error: On invalid chunking options, or when a chunk's duration is
            unknown (ffprobe unavailable) while chunking is required.
    """
    duration_seconds = get_audio_duration_seconds(audio_file)

    # Default to single-pass diarization up to 2 hours on ZeroGPU.
    chunk_threshold_s = float(model_options.get("long_audio_chunk_threshold_s", 7200))
    chunk_duration_s = float(model_options.get("chunk_duration_s", 7200))
    chunk_overlap_s = float(model_options.get("chunk_overlap_s", 0))
    merge_gap_s = float(model_options.get("merge_gap_s", 0.2))

    if chunk_threshold_s <= 0:
        raise gr.Error("long_audio_chunk_threshold_s must be > 0")
    if chunk_duration_s <= 0:
        raise gr.Error("chunk_duration_s must be > 0")
    if chunk_overlap_s < 0:
        raise gr.Error("chunk_overlap_s must be >= 0")
    if chunk_overlap_s >= chunk_duration_s:
        raise gr.Error("chunk_overlap_s must be smaller than chunk_duration_s")

    # Unknown duration (ffprobe missing) or short audio -> one full-file pass.
    if duration_seconds is None or duration_seconds <= chunk_threshold_s:
        chunk_plan = [
            {"index": 0, "start": 0.0, "end": duration_seconds, "duration": duration_seconds}
        ]
    else:
        chunk_plan = build_audio_chunk_plan(
            audio_file=audio_file,
            chunk_duration_s=chunk_duration_s,
            chunk_overlap_s=chunk_overlap_s,
        )

    chunk_results: list[dict[str, Any]] = []
    stitched_standard: list[dict[str, Any]] = []
    stitched_exclusive: list[dict[str, Any]] = []
    total_gpu_window_seconds = 0.0
    total_inference_seconds = 0.0

    with tempfile.TemporaryDirectory() as tmpdir:
        for chunk in chunk_plan:
            start = float(chunk["start"])
            duration_raw = chunk.get("duration")
            duration = None if duration_raw is None else float(duration_raw)
            if duration is not None and duration <= 0:
                # Degenerate (empty) chunk: nothing to diarize.
                continue

            if len(chunk_plan) == 1:
                # Single-pass: feed the original file directly, no clipping.
                chunk_audio_file = audio_file
            else:
                if duration is None:
                    raise gr.Error(
                        "Chunk duration is unknown. "
                        "Ensure ffprobe is available to process long audio chunking."
                    )
                chunk_audio_file = extract_audio_clip(
                    source_audio_file=audio_file,
                    start_seconds=start,
                    duration_seconds=duration,
                    tmpdir=tmpdir,
                )

            result = gpu_chunk_runner(
                audio_file=chunk_audio_file,
                model_options=model_options,
            )
            total_gpu_window_seconds += float(
                result["zerogpu_timing"].get("gpu_window_seconds", 0.0)
            )
            total_inference_seconds += float(
                result["zerogpu_timing"].get("inference_seconds", 0.0)
            )

            raw = result["raw_output"]
            standard_segments = raw.get("speaker_diarization", {}).get("segments", [])
            exclusive_segments = raw.get("exclusive_speaker_diarization", {}).get("segments", [])

            # Shift chunk-relative timestamps back into absolute file time.
            standard_shifted = normalize_segments_with_offset(standard_segments, offset=start)
            exclusive_shifted = normalize_segments_with_offset(exclusive_segments, offset=start)

            chunk_results.append(
                {
                    "chunk": chunk,
                    "zerogpu_timing": result["zerogpu_timing"],
                    "raw_output": raw,
                    "shifted_segments": {
                        "speaker_diarization": standard_shifted,
                        "exclusive_speaker_diarization": exclusive_shifted,
                    },
                }
            )
            stitched_standard.extend(standard_shifted)
            stitched_exclusive.extend(exclusive_shifted)

    stitched = {
        "speaker_diarization": merge_adjacent_segments(stitched_standard, max_gap_s=merge_gap_s),
        "exclusive_speaker_diarization": merge_adjacent_segments(
            stitched_exclusive, max_gap_s=merge_gap_s
        ),
    }

    return build_diarization_response(
        audio_file=audio_file,
        model_options=model_options,
        duration_seconds=duration_seconds,
        chunk_plan=chunk_plan,
        chunk_results=chunk_results,
        stitched_segments=stitched,
        zerogpu_timing={
            "gpu_window_seconds": round(total_gpu_window_seconds, 4),
            "inference_seconds": round(total_inference_seconds, 4),
        },
    )