# transcribe-diarize / local_api_benchmark.py
# Author: Ratnesh-dev
# Commit 68728b4: Fix Parakeet Call Error And Remove Unused API Parameters
import json
import time
from concurrent.futures import TimeoutError as FutureTimeoutError
from datetime import datetime
from pathlib import Path
from typing import Any
from gradio_client import Client, handle_file
from src.constants import (
PARAKEET_V3,
PYANNOTE_COMMUNITY_1,
WHISPER_LARGE_V3_TURBO,
)
# Maps each model's display label (from src.constants) to its Gradio API route.
MODEL_API_BY_LABEL = {
WHISPER_LARGE_V3_TURBO: "/transcribe_whisper_large_v3_turbo",
PARAKEET_V3: "/transcribe_parakeet_v3",
}
# Routes used directly by the single-call / combined benchmark entry points.
PYANNOTE_API_NAME = "/diarize_pyannote_community_1"
PARAKEET_API_NAME = "/transcribe_parakeet_v3"
def _safe_name(value: str) -> str:
return value.replace("/", "_").replace(" ", "_").replace("(", "").replace(")", "").lower()
def _to_model_options_json(model_options: str | dict[str, Any] | None) -> str | None:
if model_options is None:
return None
if isinstance(model_options, str):
return model_options
return json.dumps(model_options)
def _normalize_model_options_for_model(model_label: str, model_options: str | dict[str, Any] | None) -> str | dict[str, Any] | None:
    """Drop option keys that the target model's endpoint does not accept.

    Only dict options are inspected; strings and ``None`` pass through.
    The input dict is never mutated — a shallow copy is returned.
    """
    if not isinstance(model_options, dict):
        return model_options
    cleaned = dict(model_options)
    # The Parakeet endpoint rejects a batch_size override, so strip it.
    if model_label == PARAKEET_V3:
        cleaned.pop("batch_size", None)
    return cleaned
def _leaderboard_rows(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
ok_items = [r for r in results if r.get("status") == "ok"]
def key_fn(item: dict[str, Any]) -> float:
payload = item.get("result") or {}
timing = payload.get("zerogpu_timing") or {}
value = timing.get("gpu_window_seconds")
return float("inf") if value is None else float(value)
ranked = sorted(ok_items, key=key_fn)
return [
{
"model": item["model"],
"api_name": item["api_name"],
"gpu_window_seconds": ((item.get("result") or {}).get("zerogpu_timing") or {}).get("gpu_window_seconds"),
"inference_seconds": ((item.get("result") or {}).get("zerogpu_timing") or {}).get("inference_seconds"),
"client_wall_clock_seconds": item.get("client_wall_clock_seconds"),
}
for item in ranked
]
def _save_benchmark_outputs(result_obj: dict[str, Any], output_dir: str | Path | None) -> dict[str, Any]:
    """Persist per-model results and a run summary under a timestamped dir.

    Writes one JSON file per result item (the API payload on success, the
    whole item — including the error — on failure) plus a
    ``benchmark_stats.json`` summary.

    Args:
        result_obj: A benchmark payload as built by the ``run_*`` functions.
        output_dir: Root directory; defaults to ``benchmark_outputs``.

    Returns:
        Dict with ``run_dir``, ``per_model_files`` and
        ``benchmark_stats_file`` paths (as strings).
    """
    root = Path(output_dir) if output_dir else Path("benchmark_outputs")
    run_dir = root / datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir.mkdir(parents=True, exist_ok=True)

    def _timing(item: dict[str, Any]) -> dict[str, Any]:
        # Any level of the nesting may be absent or explicitly None.
        return (item.get("result") or {}).get("zerogpu_timing") or {}

    per_model_files: dict[str, str] = {}
    for item in result_obj.get("results", []):
        model = item.get("model", "unknown")
        file_path = run_dir / f"{_safe_name(model)}.json"
        # On success store only the API payload; on failure keep the whole
        # item so the error context is preserved alongside the metadata.
        payload = item.get("result") if item.get("status") == "ok" else item
        # encoding pinned to UTF-8: ensure_ascii=False may emit non-ASCII
        # characters, which the platform's default locale encoding (e.g.
        # cp1252 on Windows) could fail to encode.
        file_path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
        per_model_files[model] = str(file_path)

    summary = {
        "space": result_obj.get("space"),
        "audio_file": result_obj.get("audio_file"),
        "task": result_obj.get("task"),
        "language": result_obj.get("language"),
        "models": result_obj.get("models"),
        "benchmark_timing": result_obj.get("benchmark_timing"),
        "leaderboard_by_gpu_window_seconds": result_obj.get("leaderboard_by_gpu_window_seconds"),
        "summary": [
            {
                "model": r.get("model"),
                "status": r.get("status"),
                "api_name": r.get("api_name"),
                "client_wall_clock_seconds": r.get("client_wall_clock_seconds"),
                "gpu_window_seconds": _timing(r).get("gpu_window_seconds"),
                "inference_seconds": _timing(r).get("inference_seconds"),
                "error": r.get("error"),
            }
            for r in result_obj.get("results", [])
        ],
    }
    stats_path = run_dir / "benchmark_stats.json"
    stats_path.write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
    return {
        "run_dir": str(run_dir),
        "per_model_files": per_model_files,
        "benchmark_stats_file": str(stats_path),
    }
def run_pyannote_api(
    space: str,
    audio_file: str,
    model_options_by_model: dict[str, str | dict[str, Any]] | None = None,
    model_options: str | dict[str, Any] | None = None,
    hf_token: str | None = None,
    request_timeout_s: float = 1800.0,
    result_timeout_s: float | None = 7200.0,
    save_outputs: bool = True,
    output_dir: str | Path | None = None,
) -> dict[str, Any]:
    """Benchmark-style wrapper for a single pyannote diarization API call.

    Args:
        space: Hugging Face Space id/URL hosting the diarization endpoint.
        audio_file: Path or URL of the audio to diarize.
        model_options_by_model: Per-model options; an entry keyed by
            PYANNOTE_COMMUNITY_1 overrides ``model_options``.
        model_options: Options for the pyannote endpoint (dict or JSON string).
        hf_token: Token used to open the gradio client.
        request_timeout_s: HTTP timeout (seconds) for each request.
        result_timeout_s: Max seconds to wait for the job result; None waits
            indefinitely.
        save_outputs: When True, write JSON outputs to disk.
        output_dir: Root directory for saved outputs.

    Returns:
        Benchmark payload with the single result item and a leaderboard.
    """
    client = Client(space, token=hf_token, httpx_kwargs={"timeout": request_timeout_s})
    options_json = _to_model_options_json(model_options)
    # A per-model override takes precedence over the generic options.
    if model_options_by_model and PYANNOTE_COMMUNITY_1 in model_options_by_model:
        options_json = _to_model_options_json(model_options_by_model[PYANNOTE_COMMUNITY_1])
    started_at = time.perf_counter()
    call_start = time.perf_counter()
    # Fields common to the success, timeout, and generic-error outcomes,
    # so they are declared once instead of three times.
    base_item: dict[str, Any] = {
        "model": PYANNOTE_COMMUNITY_1,
        "api_name": PYANNOTE_API_NAME,
        "effective_model_options_json": options_json,
        "timeouts": {
            "request_timeout_s": request_timeout_s,
            "result_timeout_s": result_timeout_s,
        },
    }
    try:
        job = client.submit(
            audio_file=handle_file(audio_file),
            model_options_json=options_json,
            api_name=PYANNOTE_API_NAME,
        )
        response = job.result(timeout=result_timeout_s)
        outcome: dict[str, Any] = {"status": "ok", "result": response}
    except FutureTimeoutError:
        outcome = {
            "status": "error",
            "error": f"Pyannote call timed out after {result_timeout_s}s. Increase result_timeout_s for long audio.",
        }
    except Exception as exc:
        # Network/remote failures are reported in the payload, not raised.
        outcome = {"status": "error", "error": str(exc)}
    call_end = time.perf_counter()
    result_item: dict[str, Any] = {
        **base_item,
        "client_wall_clock_seconds": round(call_end - call_start, 4),
        **outcome,
    }
    finished_at = time.perf_counter()
    payload: dict[str, Any] = {
        "space": space,
        "audio_file": audio_file,
        "task": "diarize",
        "model_options_json": options_json,
        "model_options_by_model": model_options_by_model,
        "models": [PYANNOTE_COMMUNITY_1],
        "benchmark_timing": {
            "total_client_wall_clock_seconds": round(finished_at - started_at, 4),
        },
        "timeouts": {
            "request_timeout_s": request_timeout_s,
            "result_timeout_s": result_timeout_s,
        },
        "results": [result_item],
        "leaderboard_by_gpu_window_seconds": _leaderboard_rows([result_item]),
    }
    if save_outputs:
        payload["saved_outputs"] = _save_benchmark_outputs(payload, output_dir=output_dir)
    return payload
def run_parakeet_then_pyannote(
    space: str,
    audio_file: str,
    parakeet_model_options: str | dict[str, Any] | None = None,
    pyannote_model_options: str | dict[str, Any] | None = None,
    pyannote_hf_token: str | None = None,
    hf_token: str | None = None,
    request_timeout_s: float = 1800.0,
    result_timeout_s: float | None = 7200.0,
    save_outputs: bool = True,
    output_dir: str | Path | None = None,
) -> dict[str, Any]:
    """Run Parakeet transcription then Pyannote diarization and aggregate ZeroGPU timing.

    Args:
        space: Hugging Face Space id/URL hosting both endpoints.
        audio_file: Path or URL of the audio to process.
        parakeet_model_options: Options for the Parakeet endpoint (dict or
            JSON string); an unsupported ``batch_size`` key is stripped.
        pyannote_model_options: Options for the Pyannote endpoint (dict or
            JSON string).
        pyannote_hf_token: Optional token injected into the pyannote options
            as ``hf_token`` (an existing value in the options wins).
        hf_token: Token used to open the gradio client.
        request_timeout_s: HTTP timeout (seconds) for each request.
        result_timeout_s: Max seconds to wait for each job result.
        save_outputs: When True, write JSON outputs to disk.
        output_dir: Root directory for saved outputs.

    Returns:
        Benchmark payload with both results, summed ZeroGPU timing, and a
        leaderboard sorted by GPU window seconds.

    Raises:
        ValueError: If ``pyannote_model_options`` is a JSON string that does
            not decode to an object while ``pyannote_hf_token`` is given.
    """
    client = Client(space, token=hf_token, httpx_kwargs={"timeout": request_timeout_s})
    parakeet_model_options = _normalize_model_options_for_model(PARAKEET_V3, parakeet_model_options)
    pyannote_model_options = _normalize_model_options_for_model(PYANNOTE_COMMUNITY_1, pyannote_model_options)
    if pyannote_hf_token:
        # Fold the token into the pyannote options dict, parsing a JSON
        # string first; setdefault keeps any token already in the options.
        if pyannote_model_options is None:
            pyannote_model_options = {}
        elif isinstance(pyannote_model_options, str):
            parsed = json.loads(pyannote_model_options)
            if not isinstance(parsed, dict):
                raise ValueError("pyannote_model_options string must decode to a JSON object.")
            pyannote_model_options = parsed
        if isinstance(pyannote_model_options, dict):
            pyannote_model_options = dict(pyannote_model_options)
            pyannote_model_options.setdefault("hf_token", pyannote_hf_token)
    parakeet_options_json = _to_model_options_json(parakeet_model_options)
    pyannote_options_json = _to_model_options_json(pyannote_model_options)

    def _call_api(api_name: str, call_kwargs: dict[str, Any]) -> dict[str, Any]:
        # Submit one endpoint call and capture status/timing uniformly;
        # failures are recorded in the item rather than raised.
        call_start = time.perf_counter()
        try:
            job = client.submit(api_name=api_name, **call_kwargs)
            response = job.result(timeout=result_timeout_s)
            outcome: dict[str, Any] = {"status": "ok", "result": response}
        except FutureTimeoutError:
            outcome = {"status": "error", "error": f"Call timed out after {result_timeout_s}s."}
        except Exception as exc:
            outcome = {"status": "error", "error": str(exc)}
        call_end = time.perf_counter()
        return {
            **outcome,
            "client_wall_clock_seconds": round(call_end - call_start, 4),
            "timeouts": {
                "request_timeout_s": request_timeout_s,
                "result_timeout_s": result_timeout_s,
            },
        }

    started_at = time.perf_counter()
    parakeet_call = _call_api(
        api_name=PARAKEET_API_NAME,
        call_kwargs={
            "audio_file": handle_file(audio_file),
            "task": "transcribe",
            "language": None,
            "initial_prompt": None,
            "postprocess_prompt": None,
            "model_options_json": parakeet_options_json,
        },
    )
    parakeet_call.update(
        {
            "model": PARAKEET_V3,
            "api_name": PARAKEET_API_NAME,
            "effective_model_options_json": parakeet_options_json,
        }
    )
    pyannote_call = _call_api(
        api_name=PYANNOTE_API_NAME,
        call_kwargs={
            "audio_file": handle_file(audio_file),
            "model_options_json": pyannote_options_json,
        },
    )
    pyannote_call.update(
        {
            "model": PYANNOTE_COMMUNITY_1,
            "api_name": PYANNOTE_API_NAME,
            "effective_model_options_json": pyannote_options_json,
        }
    )
    results: list[dict[str, Any]] = [parakeet_call, pyannote_call]
    finished_at = time.perf_counter()
    total_zerogpu_gpu_window_seconds = 0.0
    total_zerogpu_inference_seconds = 0.0
    for item in results:
        if item.get("status") != "ok":
            continue
        timing = (item.get("result") or {}).get("zerogpu_timing") or {}
        # Keys may be present with value None (as _leaderboard_rows assumes);
        # `or 0.0` avoids float(None) raising TypeError in that case.
        total_zerogpu_gpu_window_seconds += float(timing.get("gpu_window_seconds") or 0.0)
        total_zerogpu_inference_seconds += float(timing.get("inference_seconds") or 0.0)
    payload: dict[str, Any] = {
        "space": space,
        "audio_file": audio_file,
        "task": "transcribe+diarize",
        "models": [PARAKEET_V3, PYANNOTE_COMMUNITY_1],
        "model_options_by_model": {
            PARAKEET_V3: parakeet_options_json,
            PYANNOTE_COMMUNITY_1: pyannote_options_json,
        },
        "pyannote_hf_token_provided": bool(pyannote_hf_token),
        "timeouts": {
            "request_timeout_s": request_timeout_s,
            "result_timeout_s": result_timeout_s,
        },
        "benchmark_timing": {
            "total_client_wall_clock_seconds": round(finished_at - started_at, 4),
        },
        "total_zerogpu_timing": {
            "gpu_window_seconds": round(total_zerogpu_gpu_window_seconds, 4),
            "inference_seconds": round(total_zerogpu_inference_seconds, 4),
        },
        "results": results,
        "leaderboard_by_gpu_window_seconds": _leaderboard_rows(results),
    }
    if save_outputs:
        payload["saved_outputs"] = _save_benchmark_outputs(payload, output_dir=output_dir)
    return payload
def run_all_model_apis(
    space: str,
    audio_file: str,
    task: str = "transcribe",
    language: str | None = None,
    initial_prompt: str | None = None,
    postprocess_prompt: str | None = None,
    model_options: str | dict[str, Any] | None = None,
    model_options_by_model: dict[str, str | dict[str, Any]] | None = None,
    models: list[str] | None = None,
    hf_token: str | None = None,
    save_outputs: bool = True,
    output_dir: str | Path | None = None,
    request_timeout_s: float = 1800.0,
) -> dict[str, Any]:
    """Run each model-specific API endpoint one by one and collect full outputs.

    Designed for use from IPython notebooks/scripts.
    Use model_options_by_model for per-model tuning in a single benchmark run.

    Args:
        space: Hugging Face Space id/URL hosting the endpoints.
        audio_file: Path or URL of the audio to transcribe.
        task: Task name forwarded to each endpoint.
        language: Optional language hint forwarded to each endpoint.
        initial_prompt: Optional prompt forwarded to each endpoint.
        postprocess_prompt: Optional prompt forwarded to each endpoint.
        model_options: Default options applied to every model.
        model_options_by_model: Per-model options overriding ``model_options``.
        models: Subset of MODEL_API_BY_LABEL keys to run; all when None.
        hf_token: Token used to open the gradio client.
        save_outputs: When True, write JSON outputs to disk.
        output_dir: Root directory for saved outputs.
        request_timeout_s: HTTP timeout (seconds) per request; added for
            parity with the other entry points (same default).

    Returns:
        Benchmark payload with one result item per model and a leaderboard.

    Raises:
        ValueError: If ``models`` contains an unknown model label.
    """
    if models is None:
        model_sequence = list(MODEL_API_BY_LABEL.keys())
    else:
        invalid = [m for m in models if m not in MODEL_API_BY_LABEL]
        if invalid:
            raise ValueError(f"Unsupported models requested: {invalid}")
        model_sequence = models
    # Pass the HTTP timeout like the other run_* helpers do, so long audio
    # does not hit the gradio client's default request timeout.
    client = Client(space, token=hf_token, httpx_kwargs={"timeout": request_timeout_s})
    options_json = _to_model_options_json(model_options)
    started_at = time.perf_counter()
    results: list[dict[str, Any]] = []
    for model in model_sequence:
        api_name = MODEL_API_BY_LABEL[model]
        # Per-model options win over the shared defaults.
        effective_options_source: str | dict[str, Any] | None = model_options
        if model_options_by_model and model in model_options_by_model:
            effective_options_source = model_options_by_model[model]
        normalized_options = _normalize_model_options_for_model(model, effective_options_source)
        effective_options_json = _to_model_options_json(normalized_options)
        call_start = time.perf_counter()
        try:
            response = client.predict(
                audio_file=handle_file(audio_file),
                task=task,
                language=language,
                initial_prompt=initial_prompt,
                postprocess_prompt=postprocess_prompt,
                model_options_json=effective_options_json,
                api_name=api_name,
            )
            call_end = time.perf_counter()
            results.append(
                {
                    "model": model,
                    "api_name": api_name,
                    "status": "ok",
                    "client_wall_clock_seconds": round(call_end - call_start, 4),
                    "effective_model_options_json": effective_options_json,
                    "result": response,
                }
            )
        except Exception as exc:
            # A failing model should not abort the rest of the benchmark.
            call_end = time.perf_counter()
            results.append(
                {
                    "model": model,
                    "api_name": api_name,
                    "status": "error",
                    "client_wall_clock_seconds": round(call_end - call_start, 4),
                    "effective_model_options_json": effective_options_json,
                    "error": str(exc),
                }
            )
    finished_at = time.perf_counter()
    payload = {
        "space": space,
        "audio_file": audio_file,
        "task": task,
        "language": language,
        "initial_prompt": initial_prompt,
        "postprocess_prompt": postprocess_prompt,
        "model_options_json": options_json,
        "model_options_by_model": model_options_by_model,
        "models": model_sequence,
        "benchmark_timing": {
            "total_client_wall_clock_seconds": round(finished_at - started_at, 4),
        },
        "results": results,
        "leaderboard_by_gpu_window_seconds": _leaderboard_rows(results),
    }
    if save_outputs:
        payload["saved_outputs"] = _save_benchmark_outputs(payload, output_dir=output_dir)
    return payload