Buckets:
| import asyncio | |
| import os | |
| from argparse import ArgumentParser | |
| from pathlib import Path | |
| from statistics import mean | |
| from statistics import median | |
| import subprocess | |
| import sys | |
| import time | |
| from typing import Any | |
| import torch | |
| from loguru import logger | |
def parse_args() -> ArgumentParser:
    """Build the command-line parser for the Khaya TTS GPU benchmark.

    Despite its name, this returns the configured :class:`ArgumentParser`;
    callers invoke ``.parse_args()`` on the result.

    ``--iterations`` must be an integer >= 1 and ``--warmup`` an integer >= 0.
    Both are validated at parse time. The benchmark averages the timed runs,
    and an empty sample would otherwise only fail after the slow model load.
    """
    from argparse import ArgumentTypeError

    def positive_int(raw: str) -> int:
        # Zero iterations would leave the latency lists empty and make
        # statistics.mean() raise long after the models were loaded.
        value = int(raw)
        if value < 1:
            raise ArgumentTypeError(f"expected an integer >= 1, got {value}")
        return value

    def non_negative_int(raw: str) -> int:
        value = int(raw)
        if value < 0:
            raise ArgumentTypeError(f"expected an integer >= 0, got {value}")
        return value

    parser = ArgumentParser(description="Benchmark Khaya TTS backends on GPU.")
    parser.add_argument(
        "--backends",
        nargs="+",
        default=["eager", "inductor", "tensorrt"],
        help="Backends to benchmark in sequence.",
    )
    parser.add_argument(
        "--iterations",
        type=positive_int,
        default=10,
        help="Number of timed iterations per backend.",
    )
    parser.add_argument(
        "--warmup",
        type=non_negative_int,
        default=2,
        help="Number of warmup requests per backend.",
    )
    parser.add_argument(
        "--text",
        default="Wo ho te sen? This is a benchmark request for Khaya TTS.",
        help="Benchmark text payload.",
    )
    parser.add_argument(
        "--language",
        default=None,
        help="Optional model key to benchmark. Defaults to the first loaded model.",
    )
    parser.add_argument(
        "--speaker-id",
        default=None,
        help="Optional speaker id. Defaults to the first configured speaker.",
    )
    parser.add_argument(
        "--environment",
        # NOTE: argparse does not check defaults against ``choices``, so an
        # unexpected $ENVIRONMENT value is passed through unvalidated.
        default=os.getenv("ENVIRONMENT", "production"),
        choices=["test", "production"],
        help="Settings environment to use while loading the service.",
    )
    parser.add_argument(
        "--compile-vocoder",
        action="store_true",
        help="Also compile the vocoder for non-eager backends.",
    )
    parser.add_argument(
        "--use-ray-worker",
        action="store_true",
        help="Run the benchmark in a Ray task that requests a GPU worker.",
    )
    parser.add_argument(
        "--ray-address",
        default=os.getenv("RAY_ADDRESS", "auto"),
        help="Ray cluster address to connect to when using --use-ray-worker.",
    )
    return parser
def percentile(values: list[float], fraction: float) -> float:
    """Return the sample at the nearest rank for ``fraction`` of ``values``.

    The rank is ``round((n - 1) * fraction)``, using Python's round-half-to-even,
    and is clamped into ``[0, n - 1]``. An empty input yields ``0.0``.
    """
    if not values:
        return 0.0
    ranked = sorted(values)
    last = len(ranked) - 1
    position = round(last * fraction)
    # Clamp so fractions outside [0, 1] still pick the min/max sample.
    return ranked[max(0, min(last, position))]
def namespace_to_config(args: Any) -> dict[str, Any]:
    """Flatten parsed CLI arguments into the plain config dict used by the suite.

    Backend names are lower-cased. ``use_ray_worker`` is intentionally not
    copied, because the dict describes what to benchmark, not where. The
    result is what gets handed to the Ray task in the remote path.
    """
    copied_fields = (
        "iterations",
        "warmup",
        "text",
        "language",
        "speaker_id",
        "environment",
        "compile_vocoder",
        "ray_address",
    )
    config: dict[str, Any] = {"backends": [name.lower() for name in args.backends]}
    for field in copied_fields:
        config[field] = getattr(args, field)
    return config
def ensure_repo_package_installed() -> None:
    """Editable-install the directory containing this script with the current pip.

    Runs ``<python> -m pip install --no-cache-dir -e <dir>``. Raises
    :class:`subprocess.CalledProcessError` if pip exits non-zero.
    """
    # The "repo root" is taken to be the directory that holds this file.
    # NOTE(review): when invoked on a Ray worker, confirm that this path
    # exists there (shared filesystem / runtime_env working_dir).
    package_dir = str(Path(__file__).resolve().parent)
    pip_install = [sys.executable, "-m", "pip", "install", "--no-cache-dir", "-e"]
    subprocess.run([*pip_install, package_dir], check=True)
| async def benchmark_backend( | |
| backend: str, | |
| config: dict[str, Any], | |
| device: torch.device, | |
| ) -> dict[str, float | str]: | |
| from tts_backend.adapters.fastapi.v0.requests import SpeechRequest | |
| from tts_backend.application.use_cases.tts import get_speaker_repository | |
| from tts_backend.application.use_cases.tts import handle_tts_request | |
| from tts_backend.application.use_cases.tts import TTSService | |
| from tts_backend.infrastructure.config.settings import Settings | |
| os.environ["TTS_ACCELERATION_BACKEND"] = backend | |
| os.environ["TTS_COMPILE_VOCODER"] = ( | |
| "true" if config["compile_vocoder"] else "false" | |
| ) | |
| start_load = time.perf_counter() | |
| settings = Settings(environment=config["environment"]) | |
| speaker_repo = get_speaker_repository(settings, device) | |
| if hasattr(speaker_repo, "pre_cache_all"): | |
| await speaker_repo.pre_cache_all() | |
| service = TTSService(settings=settings, device=device, speaker_repository=speaker_repo) | |
| load_time = time.perf_counter() - start_load | |
| langs = list(service.models.keys()) | |
| if not langs: | |
| raise RuntimeError("No models were loaded for benchmarking.") | |
| lang = config["language"] or langs[0] | |
| speakers = service.get_speakers(lang) | |
| speaker_id = config["speaker_id"] or ( | |
| speakers[0] if isinstance(speakers, list) and speakers else None | |
| ) | |
| if speaker_id is None: | |
| raise RuntimeError(f"No speakers found for language '{lang}'.") | |
| request = SpeechRequest( | |
| text=config["text"], | |
| language=lang, | |
| speaker_id=speaker_id, | |
| stream=False, | |
| ) | |
| logger.info( | |
| f"Benchmarking backend='{backend}' on language='{lang}' speaker='{speaker_id}' " | |
| f"(load={load_time:.2f}s)" | |
| ) | |
| for warmup_idx in range(config["warmup"]): | |
| warmup_start = time.perf_counter() | |
| await handle_tts_request(request, service) | |
| logger.info( | |
| f"[{backend}] warmup {warmup_idx + 1}/{config['warmup']} " | |
| f"took {time.perf_counter() - warmup_start:.3f}s" | |
| ) | |
| synth_model = service.models[lang] | |
| model_latencies: list[float] = [] | |
| end_to_end_latencies: list[float] = [] | |
| audio_duration_sec = 0.0 | |
| for idx in range(config["iterations"]): | |
| model_start = time.perf_counter() | |
| audio, sample_rate = synth_model.run_tts( | |
| config["text"], lang=lang, speaker_id=speaker_id | |
| ) | |
| model_latency = time.perf_counter() - model_start | |
| model_latencies.append(model_latency) | |
| if idx == 0: | |
| audio_duration_sec = len(audio) / sample_rate | |
| logger.info( | |
| f"[{backend}] model iter {idx + 1}/{config['iterations']}: " | |
| f"{model_latency:.3f}s" | |
| ) | |
| del audio | |
| for idx in range(config["iterations"]): | |
| end_to_end_start = time.perf_counter() | |
| result = await handle_tts_request(request, service) | |
| end_to_end_latency = time.perf_counter() - end_to_end_start | |
| end_to_end_latencies.append(end_to_end_latency) | |
| logger.info( | |
| f"[{backend}] service iter {idx + 1}/{config['iterations']}: " | |
| f"{end_to_end_latency:.3f}s" | |
| ) | |
| del result | |
| return { | |
| "backend": backend, | |
| "active_backend": getattr(synth_model, "active_backend", backend), | |
| "load_time_s": load_time, | |
| "model_avg_s": mean(model_latencies), | |
| "model_p50_s": median(model_latencies), | |
| "model_p95_s": percentile(model_latencies, 0.95), | |
| "service_avg_s": mean(end_to_end_latencies), | |
| "service_p50_s": median(end_to_end_latencies), | |
| "service_p95_s": percentile(end_to_end_latencies, 0.95), | |
| "audio_duration_s": audio_duration_sec, | |
| "model_rtf": audio_duration_sec / mean(model_latencies), | |
| "service_rtf": audio_duration_sec / mean(end_to_end_latencies), | |
| } | |
async def run_benchmark_suite(config: dict[str, Any]) -> list[dict[str, float | str]]:
    """Log the runtime environment, then benchmark each configured backend in turn.

    When CUDA is present, cuDNN autotuning and TF32 math are enabled
    process-wide before any backend runs. Returns one summary dict per entry
    in ``config["backends"]``, in the same order.
    """
    has_cuda = torch.cuda.is_available()
    logger.info(f"CUDA Available: {has_cuda}")
    logger.info(f"PyTorch Version: {torch.__version__}")
    if has_cuda:
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Informational only: a missing torch_tensorrt is reported, not fatal.
    try:
        import torch_tensorrt
        logger.info(f"Torch-TensorRT Version: {torch_tensorrt.__version__}")
    except ImportError:
        logger.warning("torch_tensorrt is not installed in the environment.")

    target = torch.device("cuda" if has_cuda else "cpu")
    # Backends are benchmarked sequentially, never concurrently.
    return [
        await benchmark_backend(name, config, target)
        for name in config["backends"]
    ]
| def log_summary(summaries: list[dict[str, float | str]]) -> None: | |
| baseline = next((item for item in summaries if item["backend"] == "eager"), None) | |
| logger.info("=== Benchmark Summary ===") | |
| for summary in summaries: | |
| speedup = ( | |
| baseline["service_avg_s"] / summary["service_avg_s"] | |
| if baseline is not None | |
| else 1.0 | |
| ) | |
| logger.info( | |
| f"{summary['backend']} -> active={summary['active_backend']}, " | |
| f"service_avg={summary['service_avg_s']:.3f}s, " | |
| f"service_p95={summary['service_p95_s']:.3f}s, " | |
| f"service_rtf={summary['service_rtf']:.2f}x, " | |
| f"speedup_vs_eager={speedup:.2f}x" | |
| ) | |
def run_benchmark_via_ray(config: dict[str, Any]) -> list[dict[str, float | str]]:
    """Run the full benchmark suite inside a single Ray task that requests one GPU.

    Connects to ``config["ray_address"]`` and always shuts the Ray connection
    down afterwards, even if the remote task fails.

    Returns:
        The per-backend summaries produced on the worker.
    """
    import ray

    # ``.remote()`` only exists on functions wrapped by ``ray.remote``; calling
    # it on a plain nested function raises AttributeError. ``num_gpus=1`` makes
    # the scheduler place the task on a worker with a free GPU, as the
    # --use-ray-worker help text promises.
    @ray.remote(num_gpus=1)
    def remote_benchmark(remote_config: dict[str, Any]) -> list[dict[str, float | str]]:
        # NOTE(review): the editable install resolves the repo from __file__;
        # confirm that path is present on the worker.
        ensure_repo_package_installed()
        return asyncio.run(run_benchmark_suite(remote_config))

    ray.init(address=config["ray_address"], log_to_driver=True)
    try:
        return ray.get(remote_benchmark.remote(config))
    finally:
        ray.shutdown()
def main() -> None:
    """CLI entry point: parse flags, run the benchmark locally or on Ray, log a summary."""
    cli_args = parse_args().parse_args()
    config = namespace_to_config(cli_args)
    if not cli_args.use_ray_worker:
        summaries = asyncio.run(run_benchmark_suite(config))
    else:
        summaries = run_benchmark_via_ray(config)
    log_summary(summaries)


# Only run the benchmark when executed as a script, not on import.
if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 9.3 kB
- Xet hash:
- 0a7e8c867be512da4f64a6ed21bc90abf4243e7c6475f8cb061d32a4bd91e424
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.