"""End-to-end Parakeet transcription + Pyannote diarization Gradio Space.

Runs NVIDIA Parakeet (word-level timestamps) and Pyannote community
diarization on ZeroGPU, merges the two outputs into a speaker-attributed
transcript on CPU, and exposes the pipeline plus debug/schema endpoints
as named Gradio API routes.
"""

import time
import uuid
import warnings
from collections import OrderedDict
from typing import Any

import gradio as gr
import spaces

from src.constants import PARAKEET_V3
from src.diarization_service import run_chunked_diarization
from src.merge_service import merge_parakeet_pyannote_outputs
from src.models.parakeet_model import preload_parakeet_model, run_parakeet
from src.models.pyannote_community_model import (
    preload_pyannote_pipeline,
    run_pyannote_community_chunk,
)
from src.utils import get_audio_duration_seconds

# Suppress a known deprecation warning emitted by a transitive dependency in spaces.
warnings.filterwarnings(
    "ignore",
    message=r"`torch\.distributed\.reduce_op` is deprecated, please use `torch\.distributed\.ReduceOp` instead",
    category=FutureWarning,
)

# Preload failures keyed by model label; surfaced lazily, per request.
_PRELOAD_ERRORS: dict[str, str] = {}
# Insertion-ordered ring buffer of recent debug payloads keyed by run_id.
_DEBUG_RUNS: "OrderedDict[str, dict[str, Any]]" = OrderedDict()
_MAX_DEBUG_RUNS = 10
# run_id of the most recent pipeline run (never evicted from _DEBUG_RUNS,
# because eviction always pops the oldest entry).
_LAST_DEBUG_RUN_ID: str | None = None

RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA: dict[str, Any] = {
    "type": "object",
    "description": "Merged transcript output from Parakeet (word timestamps) and Pyannote diarization.",
    "properties": {
        "summary": {
            "type": "object",
            "properties": {
                "diarization_key_used": {"type": "string", "example": "exclusive_speaker_diarization"},
                "parakeet_word_count": {"type": "integer"},
                "pyannote_segment_count": {"type": "integer"},
                "turn_count": {"type": "integer"},
                "assigned_word_count": {"type": "integer"},
                "unassigned_word_count": {"type": "integer"},
            },
            "required": [
                "diarization_key_used",
                "parakeet_word_count",
                "pyannote_segment_count",
                "turn_count",
                "assigned_word_count",
                "unassigned_word_count",
            ],
        },
        "turns": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "speaker": {"type": "string", "example": "SPEAKER_02"},
                    "start": {"type": "number", "example": 40.72},
                    "end": {"type": "number", "example": 514.0},
                    "text": {"type": "string"},
                },
                "required": ["speaker", "start", "end", "text"],
            },
        },
        "transcript_text": {"type": "string"},
    },
    "required": ["summary", "turns", "transcript_text"],
}

RUN_COMPLETE_PIPELINE_OUTPUT_EXAMPLE: dict[str, Any] = {
    "summary": {
        "diarization_key_used": "exclusive_speaker_diarization",
        "parakeet_word_count": 1234,
        "pyannote_segment_count": 42,
        "turn_count": 39,
        "assigned_word_count": 1219,
        "unassigned_word_count": 15,
    },
    "turns": [
        {
            "speaker": "SPEAKER_00",
            "start": 0.0,
            "end": 12.34,
            "text": "Good morning and welcome to the earnings call.",
        },
        {
            "speaker": "SPEAKER_01",
            "start": 12.34,
            "end": 19.02,
            "text": "Thank you. Let us begin with quarterly highlights.",
        },
    ],
    "transcript_text": "[0.00 - 12.34] SPEAKER_00: Good morning ...",
}

RUN_COMPLETE_PIPELINE_INPUT_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "audio_file": {"type": "file", "description": "Audio file upload"},
        "huggingface_token": {"type": "string", "description": "HF access token for pyannote model"},
    },
    "required": ["audio_file", "huggingface_token"],
}


def _preload_model(model_label: str, preload_fn) -> None:
    """Run ``preload_fn``, recording (not raising) any failure under ``model_label``.

    Failures are deferred so that startup never crashes; the error is raised
    per-request by :func:`_raise_preload_error_if_any`.
    """
    try:
        preload_fn()
    except Exception as exc:
        _PRELOAD_ERRORS[model_label] = str(exc)


def _raise_preload_error_if_any(model_label: str) -> None:
    """Raise a user-facing gr.Error if ``model_label`` failed to preload."""
    message = _PRELOAD_ERRORS.get(model_label)
    if message:
        raise gr.Error(
            f"Model preload failed for {model_label}. "
            "Check startup logs and dependencies. "
            f"Details: {message}"
        )


def _store_debug_payload(payload: dict[str, Any]) -> str:
    """Store ``payload`` under a fresh run_id, evicting oldest beyond the cap.

    Returns the generated run_id and records it as the latest run.
    """
    global _LAST_DEBUG_RUN_ID
    run_id = str(uuid.uuid4())
    _DEBUG_RUNS[run_id] = payload
    _LAST_DEBUG_RUN_ID = run_id
    # Evict oldest entries first (FIFO) to bound memory.
    while len(_DEBUG_RUNS) > _MAX_DEBUG_RUNS:
        _DEBUG_RUNS.popitem(last=False)
    return run_id


def _parse_main_request(
    audio_file: str | None,
    huggingface_token: str | None,
) -> None:
    """Validate pipeline inputs, raising gr.Error on missing/blank values."""
    if audio_file is None:
        raise gr.Error("No audio file submitted. Upload an audio file first.")
    if not huggingface_token or not huggingface_token.strip():
        raise gr.Error("huggingface_token is required for pyannote/speaker-diarization-community-1.")


# Global setup (outside @spaces.GPU) so setup cost is not charged to ZeroGPU inference window.
_preload_model(PARAKEET_V3, preload_parakeet_model)
# Pyannote preload is best-effort at startup because token is provided per request.
preload_pyannote_pipeline(strict=False)


@spaces.GPU(duration=120)
def _gpu_infer_parakeet(audio_file: str, duration_seconds: float | None):
    """Run Parakeet transcription inside the ZeroGPU window; return raw output + timing."""
    gpu_start = time.perf_counter()
    result = run_parakeet(
        audio_file=audio_file,
        language=None,
        model_options={},
        duration_seconds=duration_seconds,
    )
    gpu_end = time.perf_counter()
    return {
        "raw_output": result["raw_output"],
        "zerogpu_timing": {
            "gpu_window_seconds": round(gpu_end - gpu_start, 4),
            **result.get("timing", {}),
        },
    }


@spaces.GPU(duration=120)
def _gpu_infer_pyannote_chunk(audio_file: str, model_options: dict[str, Any]):
    """Run one Pyannote diarization chunk inside the ZeroGPU window."""
    gpu_start = time.perf_counter()
    result = run_pyannote_community_chunk(
        audio_file=audio_file,
        model_options=model_options,
    )
    gpu_end = time.perf_counter()
    return {
        "raw_output": result["raw_output"],
        "zerogpu_timing": {
            "gpu_window_seconds": round(gpu_end - gpu_start, 4),
            **result.get("timing", {}),
        },
    }


def run_complete_pipeline(
    audio_file: str,
    huggingface_token: str,
):
    """Run transcription + diarization + merge; return the merged transcript JSON.

    GPU work is limited to the two @spaces.GPU helpers; everything else
    (validation, chunk orchestration, merging) happens on CPU. A full debug
    payload (raw model outputs + timing) is stored for /get_debug_output.
    """
    _parse_main_request(audio_file, huggingface_token)
    _raise_preload_error_if_any(PARAKEET_V3)

    started_at = time.perf_counter()
    duration_seconds = get_audio_duration_seconds(audio_file)

    # 1) Parakeet transcription on ZeroGPU.
    parakeet_gpu_result = _gpu_infer_parakeet(
        audio_file=audio_file,
        duration_seconds=duration_seconds,
    )
    parakeet_response = {
        "model": PARAKEET_V3,
        "task": "transcribe",
        "audio_file": str(audio_file),
        "postprocess_prompt": None,
        "model_options": {},
        "zerogpu_timing": parakeet_gpu_result["zerogpu_timing"],
        "raw_output": parakeet_gpu_result["raw_output"],
        "timestamp_granularity": "word",
    }

    # 2) Pyannote diarization on ZeroGPU (chunked only when needed).
    pyannote_model_options = {
        "hf_token": huggingface_token,
        "long_audio_chunk_threshold_s": 7200,
        "chunk_duration_s": 7200,
        "chunk_overlap_s": 0,
    }
    # Strict here: the per-request token must actually load the pipeline.
    preload_pyannote_pipeline(model_options=pyannote_model_options, strict=True)
    pyannote_response = run_chunked_diarization(
        audio_file=audio_file,
        model_options=pyannote_model_options,
        gpu_chunk_runner=_gpu_infer_pyannote_chunk,
    )

    # 3) CPU-side postprocessing outside ZeroGPU.
    merged_transcript = merge_parakeet_pyannote_outputs(
        parakeet_response=parakeet_response,
        pyannote_response=pyannote_response,
        diarization_key="exclusive_speaker_diarization",
    )

    total_gpu_window_seconds = float(parakeet_response["zerogpu_timing"].get("gpu_window_seconds", 0.0)) + float(
        pyannote_response.get("zerogpu_timing", {}).get("gpu_window_seconds", 0.0)
    )
    total_inference_seconds = float(parakeet_response["zerogpu_timing"].get("inference_seconds", 0.0)) + float(
        pyannote_response.get("zerogpu_timing", {}).get("inference_seconds", 0.0)
    )
    finished_at = time.perf_counter()

    debug_payload = {
        "pipeline_timing": {
            "total_wall_clock_seconds": round(finished_at - started_at, 4),
            "zerogpu_gpu_window_seconds_total": round(total_gpu_window_seconds, 4),
            "zerogpu_inference_seconds_total": round(total_inference_seconds, 4),
        },
        "inputs": {
            "audio_file": str(audio_file),
            "huggingface_token_provided": bool(huggingface_token),
        },
        "parakeet_response": parakeet_response,
        "pyannote_response": pyannote_response,
        "merged_transcript": merged_transcript,
    }
    _store_debug_payload(debug_payload)

    # Return merged transcript JSON (OpenAI cleanup is intentionally local/off-space).
    return merged_transcript


def get_debug_output(run_id: str | None):
    """Return the debug payload for ``run_id``, or the latest run when omitted."""
    if run_id and run_id.strip():
        payload = _DEBUG_RUNS.get(run_id.strip())
        if payload is None:
            raise gr.Error(f"Unknown run_id: {run_id}")
        return {"run_id": run_id.strip(), "debug": payload}
    if _LAST_DEBUG_RUN_ID is None:
        raise gr.Error("No debug payload available yet. Run /run_complete_pipeline first.")
    return {"run_id": _LAST_DEBUG_RUN_ID, "debug": _DEBUG_RUNS[_LAST_DEBUG_RUN_ID]}


def get_run_complete_pipeline_schema() -> dict[str, Any]:
    """Return the machine-readable input/output contract for /run_complete_pipeline."""
    return {
        "api_name": "/run_complete_pipeline",
        "input_schema": RUN_COMPLETE_PIPELINE_INPUT_SCHEMA,
        "output_schema": RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA,
        "output_example": RUN_COMPLETE_PIPELINE_OUTPUT_EXAMPLE,
        "notes": [
            "Use /get_debug_output to fetch raw model payloads and timing.",
            "The production route returns only merged transcript JSON.",
        ],
    }


# NOTE(fix): `theme` is a gr.Blocks constructor argument, not a launch() argument;
# passing it to launch() raises TypeError at startup.
with gr.Blocks(title="Parakeet + Pyannote Pipeline", theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        "# End-to-end transcript pipeline\n"
        "Runs Parakeet transcription, Pyannote diarization, then merges into a combined transcript JSON."
    )
    audio_file = gr.Audio(
        sources=["upload"],
        type="filepath",
        label="Audio file",
    )
    huggingface_token = gr.Textbox(
        label="HuggingFace token",
        type="password",
    )
    run_btn = gr.Button("Run full pipeline")
    output = gr.JSON(label="Combined transcript JSON")
    run_btn.click(
        fn=run_complete_pipeline,
        inputs=[audio_file, huggingface_token],
        outputs=output,
        api_name="run_complete_pipeline",
        api_description=(
            "Run Parakeet + Pyannote and return merged transcript JSON.\n"
            "Response shape:\n"
            "{\n"
            '  "summary": {\n'
            '    "diarization_key_used": str,\n'
            '    "parakeet_word_count": int,\n'
            '    "pyannote_segment_count": int,\n'
            '    "turn_count": int,\n'
            '    "assigned_word_count": int,\n'
            '    "unassigned_word_count": int\n'
            "  },\n"
            '  "turns": [{"speaker": str, "start": float, "end": float, "text": str}],\n'
            '  "transcript_text": str\n'
            "}\n"
            "For full machine-readable schema + example, call /get_run_complete_pipeline_schema."
        ),
    )

    with gr.Row():
        debug_run_id = gr.Textbox(label="Debug run_id (optional)")
        debug_btn = gr.Button("Get debug output")
    debug_output = gr.JSON(label="Debug output (raw + benchmark)")
    debug_btn.click(
        fn=get_debug_output,
        inputs=[debug_run_id],
        outputs=debug_output,
        api_name="get_debug_output",
        api_description=(
            "Return latest (or selected) debug payload including raw Parakeet/Pyannote outputs "
            "and aggregated pipeline timing."
        ),
    )

    # Hidden UI: exists only to register the schema endpoint as a named API route.
    with gr.Row(visible=False):
        schema_btn = gr.Button("get_run_complete_pipeline_schema")
    schema_output = gr.JSON(label="run_complete_pipeline schema", visible=False)
    schema_btn.click(
        fn=get_run_complete_pipeline_schema,
        inputs=None,
        outputs=schema_output,
        api_name="get_run_complete_pipeline_schema",
        api_description="Return input/output schema contract for /run_complete_pipeline.",
    )

demo.queue(default_concurrency_limit=1).launch(ssr_mode=False)