| import tempfile |
| from typing import Any |
|
|
| import gradio as gr |
|
|
| from src.constants import PYANNOTE_COMMUNITY_1 |
| from src.utils import build_audio_chunk_plan, extract_audio_clip, get_audio_duration_seconds, parse_model_options |
|
|
|
|
def parse_and_validate_diarization_request(
    audio_file: str,
    model_options_json: str,
) -> dict:
    """Validate the incoming request and parse its model options.

    Raises gr.Error when no audio file was submitted; otherwise delegates
    JSON parsing/validation of the options to parse_model_options.
    """
    if audio_file is None:
        raise gr.Error("No audio file submitted. Upload an audio file first.")
    options = parse_model_options(model_options_json)
    return options
|
|
|
|
def sanitize_model_options_for_response(model_options: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *model_options* with auth tokens masked.

    Known token-bearing keys are replaced with a redaction marker so the
    response payload never echoes credentials back to the client.
    """
    sensitive_keys = {"hf_token", "huggingface_token"}
    return {
        key: "***redacted***" if key in sensitive_keys else value
        for key, value in model_options.items()
    }
|
|
|
|
| def normalize_segments_with_offset(segments: list[dict[str, Any]], offset: float) -> list[dict[str, Any]]: |
| normalized = [] |
| for seg in segments: |
| start = float(seg.get("start", 0.0)) + offset |
| end = float(seg.get("end", 0.0)) + offset |
| normalized.append({ |
| **seg, |
| "start": round(start, 4), |
| "end": round(end, 4), |
| }) |
| return normalized |
|
|
|
|
def merge_adjacent_segments(
    segments: list[dict[str, Any]],
    max_gap_s: float = 0.2,
) -> list[dict[str, Any]]:
    """Merge consecutive same-speaker segments separated by at most *max_gap_s* seconds.

    Segments are sorted by (start, end) first, so input order does not matter.
    A segment is folded into the previous one when the speaker labels match
    (string comparison) and the gap between them is <= max_gap_s; the merged
    segment's "end" is extended (rounded to 4 decimal places).

    Fix: the previous implementation mutated the caller's segment dicts in
    place (the output aliased input dicts and wrote "end" into them), which
    corrupted per-chunk results that shared the same dict objects. All output
    dicts are now shallow copies, leaving the input untouched.

    Returns a new list; the input list and its dicts are never modified.
    """
    if not segments:
        return []

    ordered = sorted(segments, key=lambda s: (float(s["start"]), float(s["end"])))
    # Copy every segment so extending "end" below never writes into caller data.
    merged: list[dict[str, Any]] = [dict(ordered[0])]
    for seg in ordered[1:]:
        prev = merged[-1]
        same_speaker = str(prev.get("speaker")) == str(seg.get("speaker"))
        gap = float(seg["start"]) - float(prev["end"])
        if same_speaker and gap <= max_gap_s:
            prev["end"] = round(max(float(prev["end"]), float(seg["end"])), 4)
        else:
            merged.append(dict(seg))
    return merged
|
|
|
|
def build_diarization_response(
    audio_file: str,
    model_options: dict,
    duration_seconds: float | None,
    chunk_plan: list[dict[str, Any]],
    chunk_results: list[dict[str, Any]],
    stitched_segments: dict[str, list[dict[str, Any]]],
    zerogpu_timing: dict,
) -> dict:
    """Assemble the final JSON-serializable diarization response payload.

    Model options are passed through the redaction helper so tokens never
    leak into the response. Chunking metadata and raw per-chunk outputs are
    nested under "long_audio" and "raw_output" respectively.
    """
    chunk_count = len(chunk_plan)
    long_audio_info = {
        "chunk_count": chunk_count,
        "chunk_plan": chunk_plan,
        "chunked": chunk_count > 1,
    }
    raw_output = {
        "chunk_results": chunk_results,
        "stitched": stitched_segments,
    }
    return {
        "model": PYANNOTE_COMMUNITY_1,
        "audio_file": str(audio_file),
        "duration_seconds": duration_seconds,
        "model_options": sanitize_model_options_for_response(model_options),
        "long_audio": long_audio_info,
        "zerogpu_timing": zerogpu_timing,
        "raw_output": raw_output,
    }
|
|
|
|
def run_chunked_diarization(
    audio_file: str,
    model_options: dict,
    gpu_chunk_runner,
) -> dict:
    """Diarize *audio_file*, splitting long audio into chunks and stitching the results.

    Parameters:
        audio_file: path to the input audio file.
        model_options: parsed options dict; chunking knobs are read from it
            ("long_audio_chunk_threshold_s", "chunk_duration_s",
            "chunk_overlap_s", "merge_gap_s") with defaults applied below.
        gpu_chunk_runner: callable invoked as
            gpu_chunk_runner(audio_file=..., model_options=...) for each chunk;
            expected to return a dict with "zerogpu_timing" and "raw_output"
            keys (raw_output holding "speaker_diarization" /
            "exclusive_speaker_diarization" segment lists).

    Returns:
        The aggregate response dict from build_diarization_response.

    Raises:
        gr.Error: on invalid chunking options, or when a multi-chunk plan has
            an unknown chunk duration (e.g. ffprobe unavailable).
    """
    # May be None when the duration cannot be probed — handled below.
    duration_seconds = get_audio_duration_seconds(audio_file)

    # Chunking knobs with defaults: 2-hour threshold/chunk size, no overlap,
    # 0.2 s stitch-merge gap.
    chunk_threshold_s = float(model_options.get("long_audio_chunk_threshold_s", 7200))
    chunk_duration_s = float(model_options.get("chunk_duration_s", 7200))
    chunk_overlap_s = float(model_options.get("chunk_overlap_s", 0))
    merge_gap_s = float(model_options.get("merge_gap_s", 0.2))

    # Validate chunking options up front so we fail before any GPU work.
    if chunk_threshold_s <= 0:
        raise gr.Error("long_audio_chunk_threshold_s must be > 0")
    if chunk_duration_s <= 0:
        raise gr.Error("chunk_duration_s must be > 0")
    if chunk_overlap_s < 0:
        raise gr.Error("chunk_overlap_s must be >= 0")
    if chunk_overlap_s >= chunk_duration_s:
        raise gr.Error("chunk_overlap_s must be smaller than chunk_duration_s")

    # Short (or unknown-duration) audio is processed as a single whole-file
    # chunk; only audio longer than the threshold gets a real chunk plan.
    if duration_seconds is None or duration_seconds <= chunk_threshold_s:
        chunk_plan = [{"index": 0, "start": 0.0, "end": duration_seconds, "duration": duration_seconds}]
    else:
        chunk_plan = build_audio_chunk_plan(
            audio_file=audio_file,
            chunk_duration_s=chunk_duration_s,
            chunk_overlap_s=chunk_overlap_s,
        )

    chunk_results: list[dict[str, Any]] = []
    stitched_standard: list[dict[str, Any]] = []
    stitched_exclusive: list[dict[str, Any]] = []
    total_gpu_window_seconds = 0.0
    total_inference_seconds = 0.0

    # Temp dir holds extracted chunk files; cleaned up automatically.
    with tempfile.TemporaryDirectory() as tmpdir:
        for chunk in chunk_plan:
            start = float(chunk["start"])
            duration_raw = chunk.get("duration")
            duration = None if duration_raw is None else float(duration_raw)
            # Skip degenerate zero/negative-length chunks.
            if duration is not None and duration <= 0:
                continue

            if len(chunk_plan) == 1:
                # Single chunk: run on the original file, no extraction needed.
                chunk_audio_file = audio_file
            else:
                if duration is None:
                    raise gr.Error(
                        "Chunk duration is unknown. Ensure ffprobe is available to process long audio chunking."
                    )
                chunk_audio_file = extract_audio_clip(
                    source_audio_file=audio_file,
                    start_seconds=start,
                    duration_seconds=duration,
                    tmpdir=tmpdir,
                )

            result = gpu_chunk_runner(
                audio_file=chunk_audio_file,
                model_options=model_options,
            )

            # Accumulate per-chunk GPU timing into run totals.
            total_gpu_window_seconds += float(result["zerogpu_timing"].get("gpu_window_seconds", 0.0))
            total_inference_seconds += float(result["zerogpu_timing"].get("inference_seconds", 0.0))

            raw = result["raw_output"]
            standard_segments = raw.get("speaker_diarization", {}).get("segments", [])
            exclusive_segments = raw.get("exclusive_speaker_diarization", {}).get("segments", [])

            # Re-base chunk-local timestamps to absolute positions in the
            # original file by shifting with the chunk's start offset.
            standard_shifted = normalize_segments_with_offset(standard_segments, offset=start)
            exclusive_shifted = normalize_segments_with_offset(exclusive_segments, offset=start)

            chunk_results.append(
                {
                    "chunk": chunk,
                    "zerogpu_timing": result["zerogpu_timing"],
                    "raw_output": raw,
                    "shifted_segments": {
                        "speaker_diarization": standard_shifted,
                        "exclusive_speaker_diarization": exclusive_shifted,
                    },
                }
            )

            stitched_standard.extend(standard_shifted)
            stitched_exclusive.extend(exclusive_shifted)

    # Merge near-adjacent same-speaker segments across chunk boundaries.
    stitched = {
        "speaker_diarization": merge_adjacent_segments(stitched_standard, max_gap_s=merge_gap_s),
        "exclusive_speaker_diarization": merge_adjacent_segments(stitched_exclusive, max_gap_s=merge_gap_s),
    }

    return build_diarization_response(
        audio_file=audio_file,
        model_options=model_options,
        duration_seconds=duration_seconds,
        chunk_plan=chunk_plan,
        chunk_results=chunk_results,
        stitched_segments=stitched,
        zerogpu_timing={
            "gpu_window_seconds": round(total_gpu_window_seconds, 4),
            "inference_seconds": round(total_inference_seconds, 4),
        },
    )
|
|