| import time |
| import uuid |
| import warnings |
| from collections import OrderedDict |
| from typing import Any |
|
|
| import gradio as gr |
| import spaces |
|
|
| from src.constants import PARAKEET_V3 |
| from src.diarization_service import run_chunked_diarization |
| from src.merge_service import merge_parakeet_pyannote_outputs |
| from src.models.parakeet_model import preload_parakeet_model, run_parakeet |
| from src.models.pyannote_community_model import preload_pyannote_pipeline, run_pyannote_community_chunk |
| from src.utils import get_audio_duration_seconds |
|
|
| |
# Silence a known upstream deprecation warning emitted during torch/NeMo model
# loading; it is noise this app cannot act on.
warnings.filterwarnings(
    "ignore",
    message=r"`torch\.distributed\.reduce_op` is deprecated, please use `torch\.distributed\.ReduceOp` instead",
    category=FutureWarning,
)


# Model label -> error message captured during startup preload; surfaced to the
# caller on the first request via _raise_preload_error_if_any().
_PRELOAD_ERRORS: dict[str, str] = {}
# Insertion-ordered cache of recent pipeline debug payloads, keyed by run_id.
# Acts as a small ring buffer (oldest evicted first).
_DEBUG_RUNS: "OrderedDict[str, dict[str, Any]]" = OrderedDict()
# Maximum number of debug payloads retained in _DEBUG_RUNS.
_MAX_DEBUG_RUNS = 10
# run_id of the most recently completed pipeline run; None until the first run.
_LAST_DEBUG_RUN_ID: str | None = None
|
|
# JSON-Schema-style contract describing the /run_complete_pipeline response.
# Served verbatim by get_run_complete_pipeline_schema() so API clients can
# introspect the output shape without parsing the prose api_description.
RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA: dict[str, Any] = {
    "type": "object",
    "description": "Merged transcript output from Parakeet (word timestamps) and Pyannote diarization.",
    "properties": {
        "summary": {
            "type": "object",
            "properties": {
                "diarization_key_used": {"type": "string", "example": "exclusive_speaker_diarization"},
                "parakeet_word_count": {"type": "integer"},
                "pyannote_segment_count": {"type": "integer"},
                "turn_count": {"type": "integer"},
                "assigned_word_count": {"type": "integer"},
                "unassigned_word_count": {"type": "integer"},
            },
            "required": [
                "diarization_key_used",
                "parakeet_word_count",
                "pyannote_segment_count",
                "turn_count",
                "assigned_word_count",
                "unassigned_word_count",
            ],
        },
        "turns": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "speaker": {"type": "string", "example": "SPEAKER_02"},
                    "start": {"type": "number", "example": 40.72},
                    "end": {"type": "number", "example": 514.0},
                    "text": {"type": "string"},
                },
                "required": ["speaker", "start", "end", "text"],
            },
        },
        "transcript_text": {"type": "string"},
    },
    "required": ["summary", "turns", "transcript_text"],
}
|
|
# Concrete example payload matching RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA.
# Returned alongside the schema by get_run_complete_pipeline_schema().
RUN_COMPLETE_PIPELINE_OUTPUT_EXAMPLE: dict[str, Any] = {
    "summary": {
        "diarization_key_used": "exclusive_speaker_diarization",
        "parakeet_word_count": 1234,
        "pyannote_segment_count": 42,
        "turn_count": 39,
        "assigned_word_count": 1219,
        "unassigned_word_count": 15,
    },
    "turns": [
        {
            "speaker": "SPEAKER_00",
            "start": 0.0,
            "end": 12.34,
            "text": "Good morning and welcome to the earnings call.",
        },
        {
            "speaker": "SPEAKER_01",
            "start": 12.34,
            "end": 19.02,
            "text": "Thank you. Let us begin with quarterly highlights.",
        },
    ],
    "transcript_text": "[0.00 - 12.34] SPEAKER_00: Good morning ...",
}
|
|
# Input contract for /run_complete_pipeline, served by
# get_run_complete_pipeline_schema().
# NOTE(review): "file" is not a standard JSON Schema type; presumably clients
# of this Space treat it as an upload marker — confirm before validating
# against strict JSON Schema tooling.
RUN_COMPLETE_PIPELINE_INPUT_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "audio_file": {"type": "file", "description": "Audio file upload"},
        "huggingface_token": {"type": "string", "description": "HF access token for pyannote model"},
    },
    "required": ["audio_file", "huggingface_token"],
}
|
|
|
|
| def _preload_model(model_label: str, preload_fn) -> None: |
| try: |
| preload_fn() |
| except Exception as exc: |
| _PRELOAD_ERRORS[model_label] = str(exc) |
|
|
|
|
def _raise_preload_error_if_any(model_label: str) -> None:
    """Abort the request with gr.Error if *model_label* failed its startup preload."""
    details = _PRELOAD_ERRORS.get(model_label)
    if not details:
        return
    raise gr.Error(
        f"Model preload failed for {model_label}. "
        "Check startup logs and dependencies. "
        f"Details: {details}"
    )
|
|
|
|
def _store_debug_payload(payload: dict[str, Any]) -> str:
    """Cache *payload* under a fresh run_id and return that id.

    The new run becomes the default for get_debug_output(); once the cache
    exceeds _MAX_DEBUG_RUNS entries, the oldest payloads are evicted.
    """
    global _LAST_DEBUG_RUN_ID
    new_id = str(uuid.uuid4())
    _LAST_DEBUG_RUN_ID = new_id
    _DEBUG_RUNS[new_id] = payload
    while len(_DEBUG_RUNS) > _MAX_DEBUG_RUNS:
        _DEBUG_RUNS.popitem(last=False)  # drop oldest entry first
    return new_id
|
|
|
|
| def _parse_main_request( |
| audio_file: str | None, |
| huggingface_token: str | None, |
| ) -> None: |
| if audio_file is None: |
| raise gr.Error("No audio file submitted. Upload an audio file first.") |
| if not huggingface_token or not huggingface_token.strip(): |
| raise gr.Error("huggingface_token is required for pyannote/speaker-diarization-community-1.") |
|
|
|
|
| |
| _preload_model(PARAKEET_V3, preload_parakeet_model) |
| |
| preload_pyannote_pipeline(strict=False) |
|
|
|
|
@spaces.GPU(duration=120)
def _gpu_infer_parakeet(audio_file: str, duration_seconds: float | None):
    """Run Parakeet transcription inside a ZeroGPU window.

    Returns a dict with the model's raw output plus timing that combines the
    measured GPU-window wall time with the model's own timing breakdown.
    """
    window_start = time.perf_counter()
    parakeet_result = run_parakeet(
        audio_file=audio_file,
        language=None,
        model_options={},
        duration_seconds=duration_seconds,
    )
    window_seconds = time.perf_counter() - window_start
    timing = {"gpu_window_seconds": round(window_seconds, 4)}
    timing.update(parakeet_result.get("timing", {}))
    return {
        "raw_output": parakeet_result["raw_output"],
        "zerogpu_timing": timing,
    }
|
|
|
|
@spaces.GPU(duration=120)
def _gpu_infer_pyannote_chunk(audio_file: str, model_options: dict[str, Any]):
    """Diarize one audio chunk with the pyannote community pipeline on ZeroGPU.

    Passed as gpu_chunk_runner to run_chunked_diarization; returns the chunk's
    raw output plus combined GPU-window and model timing.
    """
    t0 = time.perf_counter()
    chunk_result = run_pyannote_community_chunk(
        audio_file=audio_file,
        model_options=model_options,
    )
    elapsed = time.perf_counter() - t0
    zerogpu_timing = {"gpu_window_seconds": round(elapsed, 4)}
    zerogpu_timing.update(chunk_result.get("timing", {}))
    return {
        "raw_output": chunk_result["raw_output"],
        "zerogpu_timing": zerogpu_timing,
    }
|
|
|
|
def run_complete_pipeline(
    audio_file: str,
    huggingface_token: str,
):
    """Transcribe and diarize *audio_file*, returning the merged transcript JSON.

    Stages:
      1. Validate inputs; fail fast if the Parakeet startup preload failed.
      2. Parakeet transcription with word-level timestamps (ZeroGPU call).
      3. Pyannote community diarization, chunked for long audio (ZeroGPU per chunk).
      4. Merge word timestamps into per-speaker turns.

    The raw per-model payloads and timing are cached for /get_debug_output;
    only the merged transcript is returned to the caller.

    Args:
        audio_file: Filepath of the uploaded audio.
        huggingface_token: HF access token required by the pyannote pipeline.

    Raises:
        gr.Error: If inputs are missing or the Parakeet preload failed.
    """
    _parse_main_request(audio_file, huggingface_token)
    _raise_preload_error_if_any(PARAKEET_V3)

    started_at = time.perf_counter()
    duration_seconds = get_audio_duration_seconds(audio_file)

    # Stage 1: Parakeet transcription (word-level timestamps).
    parakeet_gpu_result = _gpu_infer_parakeet(
        audio_file=audio_file,
        duration_seconds=duration_seconds,
    )
    # Response envelope consumed by merge_parakeet_pyannote_outputs below.
    parakeet_response = {
        "model": PARAKEET_V3,
        "task": "transcribe",
        "audio_file": str(audio_file),
        "postprocess_prompt": None,
        "model_options": {},
        "zerogpu_timing": parakeet_gpu_result["zerogpu_timing"],
        "raw_output": parakeet_gpu_result["raw_output"],
        "timestamp_granularity": "word",
    }

    # Stage 2: pyannote diarization. NOTE(review): threshold == chunk size ==
    # 7200 s with no overlap presumably means audio up to two hours runs as a
    # single chunk — confirm against run_chunked_diarization.
    pyannote_model_options = {
        "hf_token": huggingface_token,
        "long_audio_chunk_threshold_s": 7200,
        "chunk_duration_s": 7200,
        "chunk_overlap_s": 0,
    }
    # strict=True so token/permission problems surface now, before GPU work.
    preload_pyannote_pipeline(model_options=pyannote_model_options, strict=True)
    pyannote_response = run_chunked_diarization(
        audio_file=audio_file,
        model_options=pyannote_model_options,
        gpu_chunk_runner=_gpu_infer_pyannote_chunk,
    )

    # Stage 3: merge word timestamps with exclusive speaker segments.
    merged_transcript = merge_parakeet_pyannote_outputs(
        parakeet_response=parakeet_response,
        pyannote_response=pyannote_response,
        diarization_key="exclusive_speaker_diarization",
    )

    # Aggregate timing across both models; absent keys count as 0.0.
    total_gpu_window_seconds = float(parakeet_response["zerogpu_timing"].get("gpu_window_seconds", 0.0)) + float(
        pyannote_response.get("zerogpu_timing", {}).get("gpu_window_seconds", 0.0)
    )
    total_inference_seconds = float(parakeet_response["zerogpu_timing"].get("inference_seconds", 0.0)) + float(
        pyannote_response.get("zerogpu_timing", {}).get("inference_seconds", 0.0)
    )

    finished_at = time.perf_counter()
    # Cache the complete run (raw outputs + timing) for /get_debug_output.
    debug_payload = {
        "pipeline_timing": {
            "total_wall_clock_seconds": round(finished_at - started_at, 4),
            "zerogpu_gpu_window_seconds_total": round(total_gpu_window_seconds, 4),
            "zerogpu_inference_seconds_total": round(total_inference_seconds, 4),
        },
        "inputs": {
            "audio_file": str(audio_file),
            # Never store the token itself — only whether one was provided.
            "huggingface_token_provided": bool(huggingface_token),
        },
        "parakeet_response": parakeet_response,
        "pyannote_response": pyannote_response,
        "merged_transcript": merged_transcript,
    }
    _store_debug_payload(debug_payload)

    return merged_transcript
|
|
|
|
def get_debug_output(run_id: str | None):
    """Return the cached debug payload for *run_id*, or the latest run if blank.

    Raises gr.Error for an unknown run_id, or when no pipeline run has
    completed yet.
    """
    requested = run_id.strip() if run_id else ""
    if requested:
        stored = _DEBUG_RUNS.get(requested)
        if stored is None:
            raise gr.Error(f"Unknown run_id: {run_id}")
        return {"run_id": requested, "debug": stored}

    if _LAST_DEBUG_RUN_ID is None:
        raise gr.Error("No debug payload available yet. Run /run_complete_pipeline first.")
    return {"run_id": _LAST_DEBUG_RUN_ID, "debug": _DEBUG_RUNS[_LAST_DEBUG_RUN_ID]}
|
|
|
|
def get_run_complete_pipeline_schema() -> dict[str, Any]:
    """Expose the machine-readable I/O contract for /run_complete_pipeline."""
    contract: dict[str, Any] = {"api_name": "/run_complete_pipeline"}
    contract["input_schema"] = RUN_COMPLETE_PIPELINE_INPUT_SCHEMA
    contract["output_schema"] = RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA
    contract["output_example"] = RUN_COMPLETE_PIPELINE_OUTPUT_EXAMPLE
    contract["notes"] = [
        "Use /get_debug_output to fetch raw model payloads and timing.",
        "The production route returns only merged transcript JSON.",
    ]
    return contract
|
|
|
|
# UI + API surface. Three routes are exposed: the main pipeline, a debug
# inspector, and a hidden schema endpoint for API clients.
with gr.Blocks(title="Parakeet + Pyannote Pipeline") as demo:
    gr.Markdown(
        "# End-to-end transcript pipeline\n"
        "Runs Parakeet transcription, Pyannote diarization, then merges into a combined transcript JSON."
    )

    # Inputs/outputs for the main pipeline route.
    audio_file = gr.Audio(
        sources=["upload"],
        type="filepath",  # run_complete_pipeline expects a filepath string
        label="Audio file",
    )
    huggingface_token = gr.Textbox(
        label="HuggingFace token",
        type="password",
    )
    run_btn = gr.Button("Run full pipeline")
    output = gr.JSON(label="Combined transcript JSON")

    # Main API route. The api_description embeds the response shape so API
    # clients can see it without a separate schema call.
    run_btn.click(
        fn=run_complete_pipeline,
        inputs=[audio_file, huggingface_token],
        outputs=output,
        api_name="run_complete_pipeline",
        api_description=(
            "Run Parakeet + Pyannote and return merged transcript JSON.\n"
            "Response shape:\n"
            "{\n"
            '  "summary": {\n'
            '    "diarization_key_used": str,\n'
            '    "parakeet_word_count": int,\n'
            '    "pyannote_segment_count": int,\n'
            '    "turn_count": int,\n'
            '    "assigned_word_count": int,\n'
            '    "unassigned_word_count": int\n'
            "  },\n"
            '  "turns": [{"speaker": str, "start": float, "end": float, "text": str}],\n'
            '  "transcript_text": str\n'
            "}\n"
            "For full machine-readable schema + example, call /get_run_complete_pipeline_schema."
        ),
    )

    # Debug route: inspect raw model payloads and timing for recent runs.
    with gr.Row():
        debug_run_id = gr.Textbox(label="Debug run_id (optional)")
        debug_btn = gr.Button("Get debug output")
    debug_output = gr.JSON(label="Debug output (raw + benchmark)")

    debug_btn.click(
        fn=get_debug_output,
        inputs=[debug_run_id],
        outputs=debug_output,
        api_name="get_debug_output",
        api_description=(
            "Return latest (or selected) debug payload including raw Parakeet/Pyannote outputs "
            "and aggregated pipeline timing."
        ),
    )

    # Hidden schema route: invisible in the UI but still registered as an API
    # endpoint so clients can fetch the I/O contract programmatically.
    with gr.Row(visible=False):
        schema_btn = gr.Button("get_run_complete_pipeline_schema")
    schema_output = gr.JSON(label="run_complete_pipeline schema", visible=False)
    schema_btn.click(
        fn=get_run_complete_pipeline_schema,
        inputs=None,
        outputs=schema_output,
        api_name="get_run_complete_pipeline_schema",
        api_description="Return input/output schema contract for /run_complete_pipeline.",
    )
|
|
|
|
# Serialize requests (one concurrent worker) since each run holds a ZeroGPU
# window. NOTE(review): `theme=` is normally a gr.Blocks() argument, not a
# launch() one — confirm the installed gradio version accepts it here.
demo.queue(default_concurrency_limit=1).launch(ssr_mode=False, theme=gr.themes.Ocean())
|
|