Spaces: Running on Zero
| import json | |
| import time | |
| from concurrent.futures import TimeoutError as FutureTimeoutError | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any | |
| from gradio_client import Client, handle_file | |
| from src.constants import ( | |
| PARAKEET_V3, | |
| PYANNOTE_COMMUNITY_1, | |
| WHISPER_LARGE_V3_TURBO, | |
| ) | |
# Maps a human-readable model label to its transcription API endpoint name
# on the target Gradio Space.
MODEL_API_BY_LABEL = {
    WHISPER_LARGE_V3_TURBO: "/transcribe_whisper_large_v3_turbo",
    PARAKEET_V3: "/transcribe_parakeet_v3",
}

# Endpoint for pyannote community-1 speaker diarization.
PYANNOTE_API_NAME = "/diarize_pyannote_community_1"
# Endpoint for Parakeet v3 transcription (same value as
# MODEL_API_BY_LABEL[PARAKEET_V3]; kept as a named constant for direct use).
PARAKEET_API_NAME = "/transcribe_parakeet_v3"
| def _safe_name(value: str) -> str: | |
| return value.replace("/", "_").replace(" ", "_").replace("(", "").replace(")", "").lower() | |
| def _to_model_options_json(model_options: str | dict[str, Any] | None) -> str | None: | |
| if model_options is None: | |
| return None | |
| if isinstance(model_options, str): | |
| return model_options | |
| return json.dumps(model_options) | |
def _normalize_model_options_for_model(model_label: str, model_options: str | dict[str, Any] | None) -> str | dict[str, Any] | None:
    """Return a copy of dict options with keys the model does not support removed.

    Non-dict options (JSON strings, None) are returned untouched. Parakeet v3
    does not accept "batch_size", so that key is stripped for it.
    """
    if not isinstance(model_options, dict):
        return model_options
    cleaned = dict(model_options)
    if model_label == PARAKEET_V3 and "batch_size" in cleaned:
        del cleaned["batch_size"]
    return cleaned
| def _leaderboard_rows(results: list[dict[str, Any]]) -> list[dict[str, Any]]: | |
| ok_items = [r for r in results if r.get("status") == "ok"] | |
| def key_fn(item: dict[str, Any]) -> float: | |
| payload = item.get("result") or {} | |
| timing = payload.get("zerogpu_timing") or {} | |
| value = timing.get("gpu_window_seconds") | |
| return float("inf") if value is None else float(value) | |
| ranked = sorted(ok_items, key=key_fn) | |
| return [ | |
| { | |
| "model": item["model"], | |
| "api_name": item["api_name"], | |
| "gpu_window_seconds": ((item.get("result") or {}).get("zerogpu_timing") or {}).get("gpu_window_seconds"), | |
| "inference_seconds": ((item.get("result") or {}).get("zerogpu_timing") or {}).get("inference_seconds"), | |
| "client_wall_clock_seconds": item.get("client_wall_clock_seconds"), | |
| } | |
| for item in ranked | |
| ] | |
def _save_benchmark_outputs(result_obj: dict[str, Any], output_dir: str | Path | None) -> dict[str, Any]:
    """Persist per-model results and a summary JSON under a timestamped run dir.

    Args:
        result_obj: Aggregate payload produced by one of the run_* functions.
        output_dir: Root directory for runs; defaults to "benchmark_outputs".

    Returns:
        Paths of the run directory, each per-model file, and the stats file.
    """
    base = Path(output_dir) if output_dir else Path("benchmark_outputs")
    run_dir = base / datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir.mkdir(parents=True, exist_ok=True)

    def _dump(path: Path, obj: Any) -> None:
        path.write_text(json.dumps(obj, indent=2, ensure_ascii=False))

    def _timing_value(entry: dict[str, Any], field: str) -> Any:
        # Both "result" and "zerogpu_timing" may be absent or None.
        return ((entry.get("result") or {}).get("zerogpu_timing") or {}).get(field)

    per_model_files: dict[str, str] = {}
    for entry in result_obj.get("results", []):
        model = entry.get("model", "unknown")
        target = run_dir / f"{_safe_name(model)}.json"
        # Successful calls store only the raw API result; failed calls keep the
        # whole entry so the error context survives next to the timings.
        _dump(target, entry.get("result") if entry.get("status") == "ok" else entry)
        per_model_files[model] = str(target)

    summary = {
        "space": result_obj.get("space"),
        "audio_file": result_obj.get("audio_file"),
        "task": result_obj.get("task"),
        "language": result_obj.get("language"),
        "models": result_obj.get("models"),
        "benchmark_timing": result_obj.get("benchmark_timing"),
        "leaderboard_by_gpu_window_seconds": result_obj.get("leaderboard_by_gpu_window_seconds"),
        "summary": [
            {
                "model": entry.get("model"),
                "status": entry.get("status"),
                "api_name": entry.get("api_name"),
                "client_wall_clock_seconds": entry.get("client_wall_clock_seconds"),
                "gpu_window_seconds": _timing_value(entry, "gpu_window_seconds"),
                "inference_seconds": _timing_value(entry, "inference_seconds"),
                "error": entry.get("error"),
            }
            for entry in result_obj.get("results", [])
        ],
    }
    stats_path = run_dir / "benchmark_stats.json"
    _dump(stats_path, summary)

    return {
        "run_dir": str(run_dir),
        "per_model_files": per_model_files,
        "benchmark_stats_file": str(stats_path),
    }
def run_pyannote_api(
    space: str,
    audio_file: str,
    model_options_by_model: dict[str, str | dict[str, Any]] | None = None,
    model_options: str | dict[str, Any] | None = None,
    hf_token: str | None = None,
    request_timeout_s: float = 1800.0,
    result_timeout_s: float | None = 7200.0,
    save_outputs: bool = True,
    output_dir: str | Path | None = None,
) -> dict[str, Any]:
    """Benchmark-style wrapper for a single pyannote diarization API call.

    Args:
        space: Hugging Face Space id (or URL) hosting the diarization endpoint.
        audio_file: Local path or URL of the audio to diarize.
        model_options_by_model: Per-model options; an entry for
            PYANNOTE_COMMUNITY_1 overrides ``model_options``.
        model_options: Fallback options (dict or JSON string) for the call.
        hf_token: Token used to connect to the Space.
        request_timeout_s: Per-HTTP-request timeout for the client.
        result_timeout_s: Max seconds to wait for the job result (None = forever).
        save_outputs: Whether to write results via _save_benchmark_outputs.
        output_dir: Root directory for saved outputs.

    Returns:
        Aggregate payload with the single result item and a leaderboard. The
        call never raises on API failure; errors are recorded in the item.
    """
    client = Client(space, token=hf_token, httpx_kwargs={"timeout": request_timeout_s})

    options_json = _to_model_options_json(model_options)
    if model_options_by_model and PYANNOTE_COMMUNITY_1 in model_options_by_model:
        # Per-model options take precedence over the generic model_options.
        options_json = _to_model_options_json(model_options_by_model[PYANNOTE_COMMUNITY_1])

    timeouts = {
        "request_timeout_s": request_timeout_s,
        "result_timeout_s": result_timeout_s,
    }
    started_at = time.perf_counter()

    # Shared fields for every outcome; status/result/error are filled below.
    result_item: dict[str, Any] = {
        "model": PYANNOTE_COMMUNITY_1,
        "api_name": PYANNOTE_API_NAME,
        "effective_model_options_json": options_json,
        "timeouts": dict(timeouts),
    }
    call_start = time.perf_counter()
    try:
        job = client.submit(
            audio_file=handle_file(audio_file),
            model_options_json=options_json,
            api_name=PYANNOTE_API_NAME,
        )
        response = job.result(timeout=result_timeout_s)
        result_item["status"] = "ok"
        result_item["result"] = response
    except FutureTimeoutError:
        result_item["status"] = "error"
        result_item["error"] = (
            f"Pyannote call timed out after {result_timeout_s}s. Increase result_timeout_s for long audio."
        )
    except Exception as exc:
        result_item["status"] = "error"
        result_item["error"] = str(exc)
    finally:
        # Wall-clock includes upload, queueing, and compute on the Space side.
        result_item["client_wall_clock_seconds"] = round(time.perf_counter() - call_start, 4)

    finished_at = time.perf_counter()
    payload: dict[str, Any] = {
        "space": space,
        "audio_file": audio_file,
        "task": "diarize",
        "model_options_json": options_json,
        "model_options_by_model": model_options_by_model,
        "models": [PYANNOTE_COMMUNITY_1],
        "benchmark_timing": {
            "total_client_wall_clock_seconds": round(finished_at - started_at, 4),
        },
        "timeouts": timeouts,
        "results": [result_item],
        "leaderboard_by_gpu_window_seconds": _leaderboard_rows([result_item]),
    }
    if save_outputs:
        payload["saved_outputs"] = _save_benchmark_outputs(payload, output_dir=output_dir)
    return payload
def run_parakeet_then_pyannote(
    space: str,
    audio_file: str,
    parakeet_model_options: str | dict[str, Any] | None = None,
    pyannote_model_options: str | dict[str, Any] | None = None,
    pyannote_hf_token: str | None = None,
    hf_token: str | None = None,
    request_timeout_s: float = 1800.0,
    result_timeout_s: float | None = 7200.0,
    save_outputs: bool = True,
    output_dir: str | Path | None = None,
) -> dict[str, Any]:
    """Run Parakeet transcription then Pyannote diarization and aggregate ZeroGPU timing.

    Args:
        space: Hugging Face Space id (or URL) hosting both endpoints.
        audio_file: Local path or URL of the audio to process.
        parakeet_model_options: Options for the Parakeet call (dict or JSON string).
        pyannote_model_options: Options for the Pyannote call (dict or JSON string).
        pyannote_hf_token: Token injected into the pyannote options as "hf_token"
            unless one is already present.
        hf_token: Token used to connect to the Space itself.
        request_timeout_s: Per-HTTP-request timeout for the client.
        result_timeout_s: Max seconds to wait per job result (None = forever).
        save_outputs: Whether to write results via _save_benchmark_outputs.
        output_dir: Root directory for saved outputs.

    Returns:
        Aggregate payload with both call results, summed ZeroGPU timing, and a
        leaderboard. API failures never raise; they are recorded per item.

    Raises:
        ValueError: If pyannote_model_options is a string that does not decode
            to a JSON object while pyannote_hf_token is provided.
    """
    client = Client(space, token=hf_token, httpx_kwargs={"timeout": request_timeout_s})

    parakeet_model_options = _normalize_model_options_for_model(PARAKEET_V3, parakeet_model_options)
    pyannote_model_options = _normalize_model_options_for_model(PYANNOTE_COMMUNITY_1, pyannote_model_options)

    if pyannote_hf_token:
        # Fold the token into the pyannote options dict without clobbering an
        # explicitly supplied "hf_token".
        if pyannote_model_options is None:
            pyannote_model_options = {}
        elif isinstance(pyannote_model_options, str):
            parsed = json.loads(pyannote_model_options)
            if not isinstance(parsed, dict):
                raise ValueError("pyannote_model_options string must decode to a JSON object.")
            pyannote_model_options = parsed
        if isinstance(pyannote_model_options, dict):
            pyannote_model_options = dict(pyannote_model_options)
            pyannote_model_options.setdefault("hf_token", pyannote_hf_token)

    parakeet_options_json = _to_model_options_json(parakeet_model_options)
    pyannote_options_json = _to_model_options_json(pyannote_model_options)
    timeouts = {
        "request_timeout_s": request_timeout_s,
        "result_timeout_s": result_timeout_s,
    }

    def _call_api(api_name: str, call_kwargs: dict[str, Any]) -> dict[str, Any]:
        """Submit one endpoint call; capture status, wall clock, and result/error."""
        call_start = time.perf_counter()
        try:
            job = client.submit(api_name=api_name, **call_kwargs)
            response = job.result(timeout=result_timeout_s)
            outcome: dict[str, Any] = {"status": "ok", "result": response}
        except FutureTimeoutError:
            outcome = {"status": "error", "error": f"Call timed out after {result_timeout_s}s."}
        except Exception as exc:
            outcome = {"status": "error", "error": str(exc)}
        outcome["client_wall_clock_seconds"] = round(time.perf_counter() - call_start, 4)
        outcome["timeouts"] = dict(timeouts)
        return outcome

    started_at = time.perf_counter()
    parakeet_call = _call_api(
        api_name=PARAKEET_API_NAME,
        call_kwargs={
            "audio_file": handle_file(audio_file),
            "task": "transcribe",
            "language": None,
            "initial_prompt": None,
            "postprocess_prompt": None,
            "model_options_json": parakeet_options_json,
        },
    )
    parakeet_call.update(
        {
            "model": PARAKEET_V3,
            "api_name": PARAKEET_API_NAME,
            "effective_model_options_json": parakeet_options_json,
        }
    )
    pyannote_call = _call_api(
        api_name=PYANNOTE_API_NAME,
        call_kwargs={
            "audio_file": handle_file(audio_file),
            "model_options_json": pyannote_options_json,
        },
    )
    pyannote_call.update(
        {
            "model": PYANNOTE_COMMUNITY_1,
            "api_name": PYANNOTE_API_NAME,
            "effective_model_options_json": pyannote_options_json,
        }
    )
    results: list[dict[str, Any]] = [parakeet_call, pyannote_call]
    finished_at = time.perf_counter()

    total_zerogpu_gpu_window_seconds = 0.0
    total_zerogpu_inference_seconds = 0.0
    for item in results:
        if item.get("status") != "ok":
            continue
        timing = (item.get("result") or {}).get("zerogpu_timing") or {}
        # "or 0.0" also covers keys explicitly present with value None, which a
        # .get(key, 0.0) default would pass through to float() and crash on.
        total_zerogpu_gpu_window_seconds += float(timing.get("gpu_window_seconds") or 0.0)
        total_zerogpu_inference_seconds += float(timing.get("inference_seconds") or 0.0)

    payload: dict[str, Any] = {
        "space": space,
        "audio_file": audio_file,
        "task": "transcribe+diarize",
        "models": [PARAKEET_V3, PYANNOTE_COMMUNITY_1],
        "model_options_by_model": {
            PARAKEET_V3: parakeet_options_json,
            PYANNOTE_COMMUNITY_1: pyannote_options_json,
        },
        "pyannote_hf_token_provided": bool(pyannote_hf_token),
        "timeouts": timeouts,
        "benchmark_timing": {
            "total_client_wall_clock_seconds": round(finished_at - started_at, 4),
        },
        "total_zerogpu_timing": {
            "gpu_window_seconds": round(total_zerogpu_gpu_window_seconds, 4),
            "inference_seconds": round(total_zerogpu_inference_seconds, 4),
        },
        "results": results,
        "leaderboard_by_gpu_window_seconds": _leaderboard_rows(results),
    }
    if save_outputs:
        payload["saved_outputs"] = _save_benchmark_outputs(payload, output_dir=output_dir)
    return payload
def run_all_model_apis(
    space: str,
    audio_file: str,
    task: str = "transcribe",
    language: str | None = None,
    initial_prompt: str | None = None,
    postprocess_prompt: str | None = None,
    model_options: str | dict[str, Any] | None = None,
    model_options_by_model: dict[str, str | dict[str, Any]] | None = None,
    models: list[str] | None = None,
    hf_token: str | None = None,
    save_outputs: bool = True,
    output_dir: str | Path | None = None,
    request_timeout_s: float | None = None,
) -> dict[str, Any]:
    """Run each model-specific API endpoint one by one and collect full outputs.

    Designed for use from IPython notebooks/scripts.
    Use model_options_by_model for per-model tuning in a single benchmark run.

    Args:
        space: Hugging Face Space id (or URL) hosting the endpoints.
        audio_file: Local path or URL of the audio to process.
        task: Task name forwarded to each endpoint.
        language: Optional language hint forwarded to each endpoint.
        initial_prompt: Optional prompt forwarded to each endpoint.
        postprocess_prompt: Optional post-processing prompt.
        model_options: Shared options (dict or JSON string) for all models.
        model_options_by_model: Per-model overrides of model_options.
        models: Labels to run; defaults to all of MODEL_API_BY_LABEL.
        hf_token: Token used to connect to the Space.
        save_outputs: Whether to write results via _save_benchmark_outputs.
        output_dir: Root directory for saved outputs.
        request_timeout_s: Optional per-HTTP-request timeout, matching the
            other run_* helpers; None keeps the client's default behavior.

    Returns:
        Aggregate payload with one result item per model and a leaderboard.
        Per-model API failures never raise; they are recorded per item.

    Raises:
        ValueError: If ``models`` contains labels not in MODEL_API_BY_LABEL.
    """
    if models is None:
        model_sequence = list(MODEL_API_BY_LABEL.keys())
    else:
        invalid = [m for m in models if m not in MODEL_API_BY_LABEL]
        if invalid:
            raise ValueError(f"Unsupported models requested: {invalid}")
        model_sequence = models

    # Only forward httpx_kwargs when a timeout was requested so the default
    # behavior (client library defaults) is unchanged for existing callers.
    client_kwargs: dict[str, Any] = {}
    if request_timeout_s is not None:
        client_kwargs["httpx_kwargs"] = {"timeout": request_timeout_s}
    client = Client(space, token=hf_token, **client_kwargs)
    options_json = _to_model_options_json(model_options)

    started_at = time.perf_counter()
    results: list[dict[str, Any]] = []
    for model in model_sequence:
        api_name = MODEL_API_BY_LABEL[model]
        # Per-model options take precedence over the shared model_options.
        effective_options_source: str | dict[str, Any] | None = model_options
        if model_options_by_model and model in model_options_by_model:
            effective_options_source = model_options_by_model[model]
        normalized_options = _normalize_model_options_for_model(model, effective_options_source)
        effective_options_json = _to_model_options_json(normalized_options)

        item: dict[str, Any] = {
            "model": model,
            "api_name": api_name,
            "effective_model_options_json": effective_options_json,
        }
        call_start = time.perf_counter()
        try:
            response = client.predict(
                audio_file=handle_file(audio_file),
                task=task,
                language=language,
                initial_prompt=initial_prompt,
                postprocess_prompt=postprocess_prompt,
                model_options_json=effective_options_json,
                api_name=api_name,
            )
            item["status"] = "ok"
            item["result"] = response
        except Exception as exc:
            # Record the failure and continue so one broken model does not
            # abort the whole benchmark run.
            item["status"] = "error"
            item["error"] = str(exc)
        item["client_wall_clock_seconds"] = round(time.perf_counter() - call_start, 4)
        results.append(item)
    finished_at = time.perf_counter()

    payload = {
        "space": space,
        "audio_file": audio_file,
        "task": task,
        "language": language,
        "initial_prompt": initial_prompt,
        "postprocess_prompt": postprocess_prompt,
        "model_options_json": options_json,
        "model_options_by_model": model_options_by_model,
        "models": model_sequence,
        "benchmark_timing": {
            "total_client_wall_clock_seconds": round(finished_at - started_at, 4),
        },
        "results": results,
        "leaderboard_by_gpu_window_seconds": _leaderboard_rows(results),
    }
    if save_outputs:
        payload["saved_outputs"] = _save_benchmark_outputs(payload, output_dir=output_dir)
    return payload