"""Benchmark helpers for calling a Hugging Face Space's transcription/diarization APIs.

Each ``run_*`` function submits one or more Gradio API calls against a Space,
measures client-side wall-clock time, collects the server-reported ZeroGPU
timing (when present in the response), and optionally saves per-model JSON
outputs plus an aggregated ``benchmark_stats.json`` summary.
"""

import json
import time
from concurrent.futures import TimeoutError as FutureTimeoutError
from datetime import datetime
from pathlib import Path
from typing import Any

from gradio_client import Client, handle_file

from src.constants import (
    PARAKEET_V3,
    PYANNOTE_COMMUNITY_1,
    WHISPER_LARGE_V3_TURBO,
)

# Maps a human-readable model label to its model-specific transcription endpoint.
MODEL_API_BY_LABEL = {
    WHISPER_LARGE_V3_TURBO: "/transcribe_whisper_large_v3_turbo",
    PARAKEET_V3: "/transcribe_parakeet_v3",
}

PYANNOTE_API_NAME = "/diarize_pyannote_community_1"
PARAKEET_API_NAME = "/transcribe_parakeet_v3"


def _safe_name(value: str) -> str:
    """Return *value* lowered and sanitized for use as a filename."""
    return value.replace("/", "_").replace(" ", "_").replace("(", "").replace(")", "").lower()


def _to_model_options_json(model_options: str | dict[str, Any] | None) -> str | None:
    """Normalize model options to a JSON string.

    ``None`` passes through as ``None``; strings are assumed to already be JSON
    and are returned untouched; dicts are serialized with ``json.dumps``.
    """
    if model_options is None:
        return None
    if isinstance(model_options, str):
        return model_options
    return json.dumps(model_options)


def _normalize_model_options_for_model(
    model_label: str,
    model_options: str | dict[str, Any] | None,
) -> str | dict[str, Any] | None:
    """Drop options a given model does not accept (Parakeet rejects batch_size).

    Only dict inputs are adjusted; JSON strings and ``None`` pass through
    unchanged. The input dict is never mutated.
    """
    if not isinstance(model_options, dict):
        return model_options
    normalized = dict(model_options)
    if model_label == PARAKEET_V3:
        normalized.pop("batch_size", None)
    return normalized


def _zerogpu_timing(item: dict[str, Any]) -> dict[str, Any]:
    """Extract the server-reported ``zerogpu_timing`` dict from a result item.

    Returns ``{}`` when the item has no result or the result has no timing.
    """
    return (item.get("result") or {}).get("zerogpu_timing") or {}


def _leaderboard_rows(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Rank successful results by ascending GPU window time (missing timing last)."""
    ok_items = [r for r in results if r.get("status") == "ok"]

    def key_fn(item: dict[str, Any]) -> float:
        value = _zerogpu_timing(item).get("gpu_window_seconds")
        # Items without timing sort to the end of the leaderboard.
        return float("inf") if value is None else float(value)

    return [
        {
            "model": item["model"],
            "api_name": item["api_name"],
            "gpu_window_seconds": _zerogpu_timing(item).get("gpu_window_seconds"),
            "inference_seconds": _zerogpu_timing(item).get("inference_seconds"),
            "client_wall_clock_seconds": item.get("client_wall_clock_seconds"),
        }
        for item in sorted(ok_items, key=key_fn)
    ]


def _save_benchmark_outputs(
    result_obj: dict[str, Any],
    output_dir: str | Path | None,
) -> dict[str, Any]:
    """Persist per-model results and a summary JSON under a timestamped run dir.

    Creates ``<output_dir or benchmark_outputs>/<YYYYmmdd_HHMMSS>/`` and writes
    one ``<model>.json`` per result plus ``benchmark_stats.json``. Returns the
    run directory, the per-model file paths, and the stats file path.
    """
    root = Path(output_dir) if output_dir else Path("benchmark_outputs")
    run_dir = root / datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir.mkdir(parents=True, exist_ok=True)

    per_model_files: dict[str, str] = {}
    for item in result_obj.get("results", []):
        model = item.get("model", "unknown")
        file_path = run_dir / f"{_safe_name(model)}.json"
        # Successful calls save just the API response; failures save the whole item.
        payload = item.get("result") if item.get("status") == "ok" else item
        # encoding pinned to UTF-8: ensure_ascii=False may emit non-ASCII text,
        # which would crash on platforms whose locale encoding is not UTF-8.
        file_path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
        per_model_files[model] = str(file_path)

    summary = {
        "space": result_obj.get("space"),
        "audio_file": result_obj.get("audio_file"),
        "task": result_obj.get("task"),
        "language": result_obj.get("language"),
        "models": result_obj.get("models"),
        "benchmark_timing": result_obj.get("benchmark_timing"),
        "leaderboard_by_gpu_window_seconds": result_obj.get("leaderboard_by_gpu_window_seconds"),
        "summary": [
            {
                "model": r.get("model"),
                "status": r.get("status"),
                "api_name": r.get("api_name"),
                "client_wall_clock_seconds": r.get("client_wall_clock_seconds"),
                "gpu_window_seconds": _zerogpu_timing(r).get("gpu_window_seconds"),
                "inference_seconds": _zerogpu_timing(r).get("inference_seconds"),
                "error": r.get("error"),
            }
            for r in result_obj.get("results", [])
        ],
    }
    stats_path = run_dir / "benchmark_stats.json"
    stats_path.write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
    return {
        "run_dir": str(run_dir),
        "per_model_files": per_model_files,
        "benchmark_stats_file": str(stats_path),
    }


def run_pyannote_api(
    space: str,
    audio_file: str,
    model_options_by_model: dict[str, str | dict[str, Any]] | None = None,
    model_options: str | dict[str, Any] | None = None,
    hf_token: str | None = None,
    request_timeout_s: float = 1800.0,
    result_timeout_s: float | None = 7200.0,
    save_outputs: bool = True,
    output_dir: str | Path | None = None,
) -> dict[str, Any]:
    """Benchmark-style wrapper for a single pyannote diarization API call.

    Args:
        space: Hugging Face Space id to call.
        audio_file: Path/URL of the audio to diarize.
        model_options_by_model: Per-model options; a pyannote entry here
            overrides ``model_options``.
        model_options: Generic options (JSON string or dict).
        hf_token: Token passed to the gradio client.
        request_timeout_s: HTTP timeout for individual requests.
        result_timeout_s: Max seconds to wait for the job result (None = wait
            forever).
        save_outputs: When True, persist results via ``_save_benchmark_outputs``.
        output_dir: Root directory for saved outputs.

    Returns:
        A payload dict containing the single result item, timing, and a
        one-row leaderboard; never raises on API failure/timeout (the error is
        captured in the result item instead).
    """
    client = Client(space, token=hf_token, httpx_kwargs={"timeout": request_timeout_s})

    options_json = _to_model_options_json(model_options)
    # A per-model override, when present, wins over the generic options.
    if model_options_by_model and PYANNOTE_COMMUNITY_1 in model_options_by_model:
        options_json = _to_model_options_json(model_options_by_model[PYANNOTE_COMMUNITY_1])

    # Shared scaffold for the ok/timeout/error outcomes; outcome-specific keys
    # ("result"/"error") are filled in below, preserving key order.
    result_item: dict[str, Any] = {
        "model": PYANNOTE_COMMUNITY_1,
        "api_name": PYANNOTE_API_NAME,
        "status": "error",
        "client_wall_clock_seconds": 0.0,
        "effective_model_options_json": options_json,
        "timeouts": {
            "request_timeout_s": request_timeout_s,
            "result_timeout_s": result_timeout_s,
        },
    }
    started_at = time.perf_counter()
    call_start = time.perf_counter()
    try:
        job = client.submit(
            audio_file=handle_file(audio_file),
            model_options_json=options_json,
            api_name=PYANNOTE_API_NAME,
        )
        response = job.result(timeout=result_timeout_s)
        result_item["status"] = "ok"
        result_item["result"] = response
    except FutureTimeoutError:
        result_item["error"] = (
            f"Pyannote call timed out after {result_timeout_s}s. "
            "Increase result_timeout_s for long audio."
        )
    except Exception as exc:
        result_item["error"] = str(exc)
    finally:
        result_item["client_wall_clock_seconds"] = round(time.perf_counter() - call_start, 4)
    finished_at = time.perf_counter()

    payload: dict[str, Any] = {
        "space": space,
        "audio_file": audio_file,
        "task": "diarize",
        "model_options_json": options_json,
        "model_options_by_model": model_options_by_model,
        "models": [PYANNOTE_COMMUNITY_1],
        "benchmark_timing": {
            "total_client_wall_clock_seconds": round(finished_at - started_at, 4),
        },
        "timeouts": {
            "request_timeout_s": request_timeout_s,
            "result_timeout_s": result_timeout_s,
        },
        "results": [result_item],
        "leaderboard_by_gpu_window_seconds": _leaderboard_rows([result_item]),
    }
    if save_outputs:
        payload["saved_outputs"] = _save_benchmark_outputs(payload, output_dir=output_dir)
    return payload


def run_parakeet_then_pyannote(
    space: str,
    audio_file: str,
    parakeet_model_options: str | dict[str, Any] | None = None,
    pyannote_model_options: str | dict[str, Any] | None = None,
    pyannote_hf_token: str | None = None,
    hf_token: str | None = None,
    request_timeout_s: float = 1800.0,
    result_timeout_s: float | None = 7200.0,
    save_outputs: bool = True,
    output_dir: str | Path | None = None,
) -> dict[str, Any]:
    """Run Parakeet transcription then Pyannote diarization and aggregate ZeroGPU timing.

    The two calls run sequentially against the same Space client. API failures
    and timeouts are captured per-call (status "error") rather than raised.

    Args:
        space: Hugging Face Space id to call.
        audio_file: Path/URL of the audio file.
        parakeet_model_options: Options for the Parakeet call (dict or JSON string).
        pyannote_model_options: Options for the Pyannote call (dict or JSON string).
        pyannote_hf_token: Token injected into the pyannote options as
            ``hf_token`` unless the caller already set one.
        hf_token: Token passed to the gradio client itself.
        request_timeout_s: HTTP timeout for individual requests.
        result_timeout_s: Max seconds to wait per job result (None = forever).
        save_outputs: When True, persist results via ``_save_benchmark_outputs``.
        output_dir: Root directory for saved outputs.

    Raises:
        ValueError: if ``pyannote_model_options`` is a JSON string that does not
            decode to an object while ``pyannote_hf_token`` is provided.
    """
    client = Client(space, token=hf_token, httpx_kwargs={"timeout": request_timeout_s})
    parakeet_model_options = _normalize_model_options_for_model(PARAKEET_V3, parakeet_model_options)
    pyannote_model_options = _normalize_model_options_for_model(
        PYANNOTE_COMMUNITY_1, pyannote_model_options
    )

    # Merge the gated-model token into the pyannote options without clobbering
    # an hf_token the caller already supplied.
    if pyannote_hf_token:
        if pyannote_model_options is None:
            pyannote_model_options = {}
        elif isinstance(pyannote_model_options, str):
            parsed = json.loads(pyannote_model_options)
            if not isinstance(parsed, dict):
                raise ValueError("pyannote_model_options string must decode to a JSON object.")
            pyannote_model_options = parsed
        if isinstance(pyannote_model_options, dict):
            pyannote_model_options = dict(pyannote_model_options)
            pyannote_model_options.setdefault("hf_token", pyannote_hf_token)

    parakeet_options_json = _to_model_options_json(parakeet_model_options)
    pyannote_options_json = _to_model_options_json(pyannote_model_options)

    def _call_api(api_name: str, call_kwargs: dict[str, Any]) -> dict[str, Any]:
        """Submit one API call and wrap the ok/timeout/error outcome uniformly."""
        timeouts = {
            "request_timeout_s": request_timeout_s,
            "result_timeout_s": result_timeout_s,
        }
        call_start = time.perf_counter()
        try:
            job = client.submit(api_name=api_name, **call_kwargs)
            response = job.result(timeout=result_timeout_s)
            return {
                "status": "ok",
                "client_wall_clock_seconds": round(time.perf_counter() - call_start, 4),
                "timeouts": timeouts,
                "result": response,
            }
        except FutureTimeoutError:
            return {
                "status": "error",
                "client_wall_clock_seconds": round(time.perf_counter() - call_start, 4),
                "timeouts": timeouts,
                "error": f"Call timed out after {result_timeout_s}s.",
            }
        except Exception as exc:
            return {
                "status": "error",
                "client_wall_clock_seconds": round(time.perf_counter() - call_start, 4),
                "timeouts": timeouts,
                "error": str(exc),
            }

    started_at = time.perf_counter()
    parakeet_call = _call_api(
        api_name=PARAKEET_API_NAME,
        call_kwargs={
            "audio_file": handle_file(audio_file),
            "task": "transcribe",
            "language": None,
            "initial_prompt": None,
            "postprocess_prompt": None,
            "model_options_json": parakeet_options_json,
        },
    )
    parakeet_call.update(
        {
            "model": PARAKEET_V3,
            "api_name": PARAKEET_API_NAME,
            "effective_model_options_json": parakeet_options_json,
        }
    )
    pyannote_call = _call_api(
        api_name=PYANNOTE_API_NAME,
        call_kwargs={
            "audio_file": handle_file(audio_file),
            "model_options_json": pyannote_options_json,
        },
    )
    pyannote_call.update(
        {
            "model": PYANNOTE_COMMUNITY_1,
            "api_name": PYANNOTE_API_NAME,
            "effective_model_options_json": pyannote_options_json,
        }
    )
    results: list[dict[str, Any]] = [parakeet_call, pyannote_call]
    finished_at = time.perf_counter()

    # Sum server-reported ZeroGPU timing across the successful calls only.
    total_zerogpu_gpu_window_seconds = 0.0
    total_zerogpu_inference_seconds = 0.0
    for item in results:
        if item.get("status") != "ok":
            continue
        timing = _zerogpu_timing(item)
        total_zerogpu_gpu_window_seconds += float(timing.get("gpu_window_seconds", 0.0))
        total_zerogpu_inference_seconds += float(timing.get("inference_seconds", 0.0))

    payload: dict[str, Any] = {
        "space": space,
        "audio_file": audio_file,
        "task": "transcribe+diarize",
        "models": [PARAKEET_V3, PYANNOTE_COMMUNITY_1],
        "model_options_by_model": {
            PARAKEET_V3: parakeet_options_json,
            PYANNOTE_COMMUNITY_1: pyannote_options_json,
        },
        "pyannote_hf_token_provided": bool(pyannote_hf_token),
        "timeouts": {
            "request_timeout_s": request_timeout_s,
            "result_timeout_s": result_timeout_s,
        },
        "benchmark_timing": {
            "total_client_wall_clock_seconds": round(finished_at - started_at, 4),
        },
        "total_zerogpu_timing": {
            "gpu_window_seconds": round(total_zerogpu_gpu_window_seconds, 4),
            "inference_seconds": round(total_zerogpu_inference_seconds, 4),
        },
        "results": results,
        "leaderboard_by_gpu_window_seconds": _leaderboard_rows(results),
    }
    if save_outputs:
        payload["saved_outputs"] = _save_benchmark_outputs(payload, output_dir=output_dir)
    return payload


def run_all_model_apis(
    space: str,
    audio_file: str,
    task: str = "transcribe",
    language: str | None = None,
    initial_prompt: str | None = None,
    postprocess_prompt: str | None = None,
    model_options: str | dict[str, Any] | None = None,
    model_options_by_model: dict[str, str | dict[str, Any]] | None = None,
    models: list[str] | None = None,
    hf_token: str | None = None,
    save_outputs: bool = True,
    output_dir: str | Path | None = None,
    request_timeout_s: float | None = None,
) -> dict[str, Any]:
    """Run each model-specific API endpoint one by one and collect full outputs.

    Designed for use from IPython notebooks/scripts. Use model_options_by_model
    for per-model tuning in a single benchmark run.

    Args:
        space: Hugging Face Space id to call.
        audio_file: Path/URL of the audio file.
        task: Task name forwarded to each endpoint.
        language: Optional language hint forwarded to each endpoint.
        initial_prompt: Optional prompt forwarded to each endpoint.
        postprocess_prompt: Optional prompt forwarded to each endpoint.
        model_options: Generic options applied to every model unless overridden.
        model_options_by_model: Per-model option overrides.
        models: Subset of MODEL_API_BY_LABEL keys to run; None means all.
        hf_token: Token passed to the gradio client.
        save_outputs: When True, persist results via ``_save_benchmark_outputs``.
        output_dir: Root directory for saved outputs.
        request_timeout_s: Optional HTTP timeout, matching the other ``run_*``
            helpers; ``None`` keeps the gradio_client default (prior behavior).

    Raises:
        ValueError: if ``models`` contains a label with no known endpoint.
    """
    if models is None:
        model_sequence = list(MODEL_API_BY_LABEL.keys())
    else:
        invalid = [m for m in models if m not in MODEL_API_BY_LABEL]
        if invalid:
            raise ValueError(f"Unsupported models requested: {invalid}")
        model_sequence = models

    # Only forward httpx_kwargs when a timeout was requested, so existing
    # callers keep the exact previous client construction.
    client_kwargs: dict[str, Any] = {"token": hf_token}
    if request_timeout_s is not None:
        client_kwargs["httpx_kwargs"] = {"timeout": request_timeout_s}
    client = Client(space, **client_kwargs)

    options_json = _to_model_options_json(model_options)
    started_at = time.perf_counter()
    results: list[dict[str, Any]] = []
    for model in model_sequence:
        api_name = MODEL_API_BY_LABEL[model]
        # Per-model override wins over the generic options.
        effective_options_source: str | dict[str, Any] | None = model_options
        if model_options_by_model and model in model_options_by_model:
            effective_options_source = model_options_by_model[model]
        normalized_options = _normalize_model_options_for_model(model, effective_options_source)
        effective_options_json = _to_model_options_json(normalized_options)

        call_start = time.perf_counter()
        try:
            response = client.predict(
                audio_file=handle_file(audio_file),
                task=task,
                language=language,
                initial_prompt=initial_prompt,
                postprocess_prompt=postprocess_prompt,
                model_options_json=effective_options_json,
                api_name=api_name,
            )
            results.append(
                {
                    "model": model,
                    "api_name": api_name,
                    "status": "ok",
                    "client_wall_clock_seconds": round(time.perf_counter() - call_start, 4),
                    "effective_model_options_json": effective_options_json,
                    "result": response,
                }
            )
        except Exception as exc:
            # Capture the failure and keep benchmarking the remaining models.
            results.append(
                {
                    "model": model,
                    "api_name": api_name,
                    "status": "error",
                    "client_wall_clock_seconds": round(time.perf_counter() - call_start, 4),
                    "effective_model_options_json": effective_options_json,
                    "error": str(exc),
                }
            )
    finished_at = time.perf_counter()

    payload = {
        "space": space,
        "audio_file": audio_file,
        "task": task,
        "language": language,
        "initial_prompt": initial_prompt,
        "postprocess_prompt": postprocess_prompt,
        "model_options_json": options_json,
        "model_options_by_model": model_options_by_model,
        "models": model_sequence,
        "benchmark_timing": {
            "total_client_wall_clock_seconds": round(finished_at - started_at, 4),
        },
        "results": results,
        "leaderboard_by_gpu_window_seconds": _leaderboard_rows(results),
    }
    if save_outputs:
        payload["saved_outputs"] = _save_benchmark_outputs(payload, output_dir=output_dir)
    return payload