"""Gradio Space: Parakeet transcription + Pyannote diarization pipeline with a schema-documentation API."""
import time
import uuid
import warnings
from collections import OrderedDict
from typing import Any
import gradio as gr
import spaces
from src.constants import PARAKEET_V3
from src.diarization_service import run_chunked_diarization
from src.merge_service import merge_parakeet_pyannote_outputs
from src.models.parakeet_model import preload_parakeet_model, run_parakeet
from src.models.pyannote_community_model import preload_pyannote_pipeline, run_pyannote_community_chunk
from src.utils import get_audio_duration_seconds
# Suppress a known deprecation warning emitted by a transitive dependency in spaces.
warnings.filterwarnings(
"ignore",
message=r"`torch\.distributed\.reduce_op` is deprecated, please use `torch\.distributed\.ReduceOp` instead",
category=FutureWarning,
)
# Maps model label -> error message captured during best-effort preload at startup;
# surfaced to callers later by _raise_preload_error_if_any.
_PRELOAD_ERRORS: dict[str, str] = {}
# Insertion-ordered cache of recent debug payloads keyed by run_id; oldest entries
# are evicted by _store_debug_payload once the cap below is exceeded.
_DEBUG_RUNS: "OrderedDict[str, dict[str, Any]]" = OrderedDict()
# Maximum number of debug payloads retained in _DEBUG_RUNS.
_MAX_DEBUG_RUNS = 10
# run_id of the most recently stored debug payload, or None before the first run.
_LAST_DEBUG_RUN_ID: str | None = None
# JSON-Schema-style description of what run_complete_pipeline returns; served to
# API consumers via /get_run_complete_pipeline_schema.
# NOTE(review): "example" alongside "type" is an OpenAPI extension, not core
# JSON Schema — acceptable here since this is documentation, not a validator input.
RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA: dict[str, Any] = {
    "type": "object",
    "description": "Merged transcript output from Parakeet (word timestamps) and Pyannote diarization.",
    "properties": {
        # Aggregate counters describing how the merge went.
        "summary": {
            "type": "object",
            "properties": {
                "diarization_key_used": {"type": "string", "example": "exclusive_speaker_diarization"},
                "parakeet_word_count": {"type": "integer"},
                "pyannote_segment_count": {"type": "integer"},
                "turn_count": {"type": "integer"},
                "assigned_word_count": {"type": "integer"},
                "unassigned_word_count": {"type": "integer"},
            },
            "required": [
                "diarization_key_used",
                "parakeet_word_count",
                "pyannote_segment_count",
                "turn_count",
                "assigned_word_count",
                "unassigned_word_count",
            ],
        },
        # One entry per speaker turn, with start/end in seconds.
        "turns": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "speaker": {"type": "string", "example": "SPEAKER_02"},
                    "start": {"type": "number", "example": 40.72},
                    "end": {"type": "number", "example": 514.0},
                    "text": {"type": "string"},
                },
                "required": ["speaker", "start", "end", "text"],
            },
        },
        # Human-readable rendering of all turns as one string.
        "transcript_text": {"type": "string"},
    },
    "required": ["summary", "turns", "transcript_text"],
}
# Concrete example payload matching RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA; returned
# to API consumers via /get_run_complete_pipeline_schema as "output_example".
RUN_COMPLETE_PIPELINE_OUTPUT_EXAMPLE: dict[str, Any] = {
    "summary": {
        "diarization_key_used": "exclusive_speaker_diarization",
        "parakeet_word_count": 1234,
        "pyannote_segment_count": 42,
        "turn_count": 39,
        "assigned_word_count": 1219,
        "unassigned_word_count": 15,
    },
    "turns": [
        {
            "speaker": "SPEAKER_00",
            "start": 0.0,
            "end": 12.34,
            "text": "Good morning and welcome to the earnings call.",
        },
        {
            "speaker": "SPEAKER_01",
            "start": 12.34,
            "end": 19.02,
            "text": "Thank you. Let us begin with quarterly highlights.",
        },
    ],
    "transcript_text": "[0.00 - 12.34] SPEAKER_00: Good morning ...",
}
# Describes the inputs accepted by /run_complete_pipeline.
# NOTE(review): `"type": "file"` is not a standard JSON Schema type (uploads are
# usually "string" + "format": "binary"); kept as-is since consumers of this
# endpoint may already rely on it — confirm before changing.
RUN_COMPLETE_PIPELINE_INPUT_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "audio_file": {"type": "file", "description": "Audio file upload"},
        "huggingface_token": {"type": "string", "description": "HF access token for pyannote model"},
    },
    "required": ["audio_file", "huggingface_token"],
}
def _preload_model(model_label: str, preload_fn) -> None:
try:
preload_fn()
except Exception as exc:
_PRELOAD_ERRORS[model_label] = str(exc)
def _raise_preload_error_if_any(model_label: str) -> None:
    """Raise gr.Error if startup preloading recorded a failure for *model_label*."""
    details = _PRELOAD_ERRORS.get(model_label)
    if not details:
        return
    raise gr.Error(
        f"Model preload failed for {model_label}. "
        "Check startup logs and dependencies. "
        f"Details: {details}"
    )
def _store_debug_payload(payload: dict[str, Any]) -> str:
    """Store *payload* under a fresh run_id and return that id.

    Also records the id in _LAST_DEBUG_RUN_ID so get_debug_output can default to
    the most recent run; the cache is trimmed FIFO to _MAX_DEBUG_RUNS entries.
    """
    global _LAST_DEBUG_RUN_ID
    new_id = str(uuid.uuid4())
    _DEBUG_RUNS[new_id] = payload
    _LAST_DEBUG_RUN_ID = new_id
    # Evict oldest entries first until we are back under the cap.
    while len(_DEBUG_RUNS) > _MAX_DEBUG_RUNS:
        _DEBUG_RUNS.popitem(last=False)
    return new_id
def _parse_main_request(
audio_file: str | None,
huggingface_token: str | None,
) -> None:
if audio_file is None:
raise gr.Error("No audio file submitted. Upload an audio file first.")
if not huggingface_token or not huggingface_token.strip():
raise gr.Error("huggingface_token is required for pyannote/speaker-diarization-community-1.")
# Global setup (outside @spaces.GPU) so setup cost is not charged to the ZeroGPU
# inference window. Parakeet failures are recorded (not raised) and re-checked
# per request by _raise_preload_error_if_any.
_preload_model(PARAKEET_V3, preload_parakeet_model)
# Pyannote preload is best-effort (strict=False) at startup because the HF token
# is only provided per request.
preload_pyannote_pipeline(strict=False)
@spaces.GPU(duration=120)
def _gpu_infer_parakeet(audio_file: str, duration_seconds: float | None):
    """Run Parakeet transcription inside a ZeroGPU window.

    Returns the model's raw output plus a timing dict: gpu_window_seconds covers
    the whole GPU allocation (what ZeroGPU bills), merged with the model's own
    reported timing (model keys win on collision, as before).
    """
    window_start = time.perf_counter()
    parakeet_result = run_parakeet(
        audio_file=audio_file,
        language=None,
        model_options={},
        duration_seconds=duration_seconds,
    )
    elapsed = time.perf_counter() - window_start
    timing = {"gpu_window_seconds": round(elapsed, 4)}
    timing.update(parakeet_result.get("timing", {}))
    return {"raw_output": parakeet_result["raw_output"], "zerogpu_timing": timing}
@spaces.GPU(duration=120)
def _gpu_infer_pyannote_chunk(audio_file: str, model_options: dict[str, Any]):
    """Run Pyannote diarization for one chunk inside a ZeroGPU window.

    Returns the raw diarization output plus a timing dict: gpu_window_seconds
    covers the whole GPU allocation, merged with the model's own reported timing
    (model keys win on collision, as before).
    """
    window_start = time.perf_counter()
    pyannote_result = run_pyannote_community_chunk(
        audio_file=audio_file,
        model_options=model_options,
    )
    elapsed = time.perf_counter() - window_start
    timing = {"gpu_window_seconds": round(elapsed, 4)}
    timing.update(pyannote_result.get("timing", {}))
    return {"raw_output": pyannote_result["raw_output"], "zerogpu_timing": timing}
def run_complete_pipeline(
    audio_file: str,
    huggingface_token: str,
):
    """Run the full Parakeet + Pyannote pipeline and return the merged transcript.

    Validates inputs, transcribes with Parakeet on ZeroGPU, diarizes with
    Pyannote (chunked) on ZeroGPU, merges the two on CPU, stores a debug payload
    retrievable via get_debug_output, and returns the merged transcript dict.

    Raises:
        gr.Error: on missing inputs or a failed Parakeet preload.
    """
    _parse_main_request(audio_file, huggingface_token)
    _raise_preload_error_if_any(PARAKEET_V3)
    started_at = time.perf_counter()
    duration_seconds = get_audio_duration_seconds(audio_file)
    # 1) Parakeet transcription on ZeroGPU.
    parakeet_gpu_result = _gpu_infer_parakeet(
        audio_file=audio_file,
        duration_seconds=duration_seconds,
    )
    # Request/response envelope consumed by the merge step and the debug payload.
    parakeet_response = {
        "model": PARAKEET_V3,
        "task": "transcribe",
        "audio_file": str(audio_file),
        "postprocess_prompt": None,
        "model_options": {},
        "zerogpu_timing": parakeet_gpu_result["zerogpu_timing"],
        "raw_output": parakeet_gpu_result["raw_output"],
        "timestamp_granularity": "word",
    }
    # 2) Pyannote diarization on ZeroGPU (chunked only when needed).
    # 7200 s threshold/chunk size means audio up to 2 h runs as a single chunk;
    # overlap 0 presumably relies on the merge step for boundary handling — confirm.
    pyannote_model_options = {
        "hf_token": huggingface_token,
        "long_audio_chunk_threshold_s": 7200,
        "chunk_duration_s": 7200,
        "chunk_overlap_s": 0,
    }
    # strict=True here (unlike startup): the per-request token must actually work.
    preload_pyannote_pipeline(model_options=pyannote_model_options, strict=True)
    pyannote_response = run_chunked_diarization(
        audio_file=audio_file,
        model_options=pyannote_model_options,
        gpu_chunk_runner=_gpu_infer_pyannote_chunk,
    )
    # 3) CPU-side postprocessing outside ZeroGPU.
    merged_transcript = merge_parakeet_pyannote_outputs(
        parakeet_response=parakeet_response,
        pyannote_response=pyannote_response,
        diarization_key="exclusive_speaker_diarization",
    )
    # Aggregate GPU-window and pure-inference time across both models; missing
    # timing keys default to 0.0 so partial timing data never breaks the run.
    total_gpu_window_seconds = float(parakeet_response["zerogpu_timing"].get("gpu_window_seconds", 0.0)) + float(
        pyannote_response.get("zerogpu_timing", {}).get("gpu_window_seconds", 0.0)
    )
    total_inference_seconds = float(parakeet_response["zerogpu_timing"].get("inference_seconds", 0.0)) + float(
        pyannote_response.get("zerogpu_timing", {}).get("inference_seconds", 0.0)
    )
    finished_at = time.perf_counter()
    # Full raw payload kept only in the in-memory debug cache, not returned to callers.
    debug_payload = {
        "pipeline_timing": {
            "total_wall_clock_seconds": round(finished_at - started_at, 4),
            "zerogpu_gpu_window_seconds_total": round(total_gpu_window_seconds, 4),
            "zerogpu_inference_seconds_total": round(total_inference_seconds, 4),
        },
        "inputs": {
            "audio_file": str(audio_file),
            # Only record presence of the token, never the token itself.
            "huggingface_token_provided": bool(huggingface_token),
        },
        "parakeet_response": parakeet_response,
        "pyannote_response": pyannote_response,
        "merged_transcript": merged_transcript,
    }
    _store_debug_payload(debug_payload)
    # Return merged transcript JSON (OpenAI cleanup is intentionally local/off-space).
    return merged_transcript
def get_debug_output(run_id: str | None):
    """Return the stored debug payload for *run_id*, or the latest run when omitted.

    Raises:
        gr.Error: for an unknown run_id, or when no run has completed yet.
    """
    if run_id and run_id.strip():
        wanted = run_id.strip()
        payload = _DEBUG_RUNS.get(wanted)
        if payload is None:
            raise gr.Error(f"Unknown run_id: {run_id}")
        return {"run_id": wanted, "debug": payload}
    if _LAST_DEBUG_RUN_ID is None:
        raise gr.Error("No debug payload available yet. Run /run_complete_pipeline first.")
    return {"run_id": _LAST_DEBUG_RUN_ID, "debug": _DEBUG_RUNS[_LAST_DEBUG_RUN_ID]}
def get_run_complete_pipeline_schema() -> dict[str, Any]:
    """Expose the machine-readable input/output contract of /run_complete_pipeline."""
    usage_notes = [
        "Use /get_debug_output to fetch raw model payloads and timing.",
        "The production route returns only merged transcript JSON.",
    ]
    return {
        "api_name": "/run_complete_pipeline",
        "input_schema": RUN_COMPLETE_PIPELINE_INPUT_SCHEMA,
        "output_schema": RUN_COMPLETE_PIPELINE_OUTPUT_SCHEMA,
        "output_example": RUN_COMPLETE_PIPELINE_OUTPUT_EXAMPLE,
        "notes": usage_notes,
    }
# Gradio UI + API wiring. Each .click() below also registers a named API route
# (api_name) so the app is callable programmatically.
with gr.Blocks(title="Parakeet + Pyannote Pipeline") as demo:
    gr.Markdown(
        "# End-to-end transcript pipeline\n"
        "Runs Parakeet transcription, Pyannote diarization, then merges into a combined transcript JSON."
    )
    # Inputs: uploaded audio file path and per-request HF token (masked).
    audio_file = gr.Audio(
        sources=["upload"],
        type="filepath",
        label="Audio file",
    )
    huggingface_token = gr.Textbox(
        label="HuggingFace token",
        type="password",
    )
    run_btn = gr.Button("Run full pipeline")
    output = gr.JSON(label="Combined transcript JSON")
    # Main route: /run_complete_pipeline. The long api_description embeds the
    # response shape for API consumers who never open the UI.
    run_btn.click(
        fn=run_complete_pipeline,
        inputs=[audio_file, huggingface_token],
        outputs=output,
        api_name="run_complete_pipeline",
        api_description=(
            "Run Parakeet + Pyannote and return merged transcript JSON.\n"
            "Response shape:\n"
            "{\n"
            '  "summary": {\n'
            '    "diarization_key_used": str,\n'
            '    "parakeet_word_count": int,\n'
            '    "pyannote_segment_count": int,\n'
            '    "turn_count": int,\n'
            '    "assigned_word_count": int,\n'
            '    "unassigned_word_count": int\n'
            "  },\n"
            '  "turns": [{"speaker": str, "start": float, "end": float, "text": str}],\n'
            '  "transcript_text": str\n'
            "}\n"
            "For full machine-readable schema + example, call /get_run_complete_pipeline_schema."
        ),
    )
    # Debug route: /get_debug_output — raw model payloads + timing for a run_id
    # (or the most recent run when the textbox is left empty).
    with gr.Row():
        debug_run_id = gr.Textbox(label="Debug run_id (optional)")
        debug_btn = gr.Button("Get debug output")
    debug_output = gr.JSON(label="Debug output (raw + benchmark)")
    debug_btn.click(
        fn=get_debug_output,
        inputs=[debug_run_id],
        outputs=debug_output,
        api_name="get_debug_output",
        api_description=(
            "Return latest (or selected) debug payload including raw Parakeet/Pyannote outputs "
            "and aggregated pipeline timing."
        ),
    )
    # Hidden row: exists only to register the /get_run_complete_pipeline_schema
    # API route; the widgets are never shown in the UI (visible=False).
    with gr.Row(visible=False):
        schema_btn = gr.Button("get_run_complete_pipeline_schema")
        schema_output = gr.JSON(label="run_complete_pipeline schema", visible=False)
    schema_btn.click(
        fn=get_run_complete_pipeline_schema,
        inputs=None,
        outputs=schema_output,
        api_name="get_run_complete_pipeline_schema",
        api_description="Return input/output schema contract for /run_complete_pipeline.",
    )
# Serialize requests (concurrency 1) to avoid overlapping ZeroGPU windows;
# ssr_mode=False presumably needed for this Space's Gradio setup — confirm.
demo.queue(default_concurrency_limit=1).launch(ssr_mode=False, theme=gr.themes.Ocean())