j0eBee/khaya-tts-benchmark / test_tensorrt.py
j0eBee's picture
download
raw
9.3 kB
import asyncio
import os
from argparse import ArgumentParser
from pathlib import Path
from statistics import mean
from statistics import median
import subprocess
import sys
import time
from typing import Any
import torch
from loguru import logger
def parse_args() -> ArgumentParser:
    """Build the command-line parser for the Khaya TTS backend benchmark.

    Note that, despite the name, this returns the configured
    ``ArgumentParser`` itself; the caller is expected to invoke
    ``.parse_args()`` on the result.

    The defaults for ``--environment`` and ``--ray-address`` are read from
    the ``ENVIRONMENT`` and ``RAY_ADDRESS`` environment variables at the
    moment the parser is built.
    """
    cli = ArgumentParser(description="Benchmark Khaya TTS backends on GPU.")

    # (flag, add_argument keyword arguments), listed in registration order so
    # the generated --help output keeps the same option ordering.
    option_table = [
        (
            "--backends",
            {
                "nargs": "+",
                "default": ["eager", "inductor", "tensorrt"],
                "help": "Backends to benchmark in sequence.",
            },
        ),
        (
            "--iterations",
            {
                "type": int,
                "default": 10,
                "help": "Number of timed iterations per backend.",
            },
        ),
        (
            "--warmup",
            {
                "type": int,
                "default": 2,
                "help": "Number of warmup requests per backend.",
            },
        ),
        (
            "--text",
            {
                "default": "Wo ho te sen? This is a benchmark request for Khaya TTS.",
                "help": "Benchmark text payload.",
            },
        ),
        (
            "--language",
            {
                "default": None,
                "help": "Optional model key to benchmark. Defaults to the first loaded model.",
            },
        ),
        (
            "--speaker-id",
            {
                "default": None,
                "help": "Optional speaker id. Defaults to the first configured speaker.",
            },
        ),
        (
            "--environment",
            {
                "default": os.getenv("ENVIRONMENT", "production"),
                "choices": ["test", "production"],
                "help": "Settings environment to use while loading the service.",
            },
        ),
        (
            "--compile-vocoder",
            {
                "action": "store_true",
                "help": "Also compile the vocoder for non-eager backends.",
            },
        ),
        (
            "--use-ray-worker",
            {
                "action": "store_true",
                "help": "Run the benchmark in a Ray task that requests a GPU worker.",
            },
        ),
        (
            "--ray-address",
            {
                "default": os.getenv("RAY_ADDRESS", "auto"),
                "help": "Ray cluster address to connect to when using --use-ray-worker.",
            },
        ),
    ]
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli
def percentile(values: list[float], fraction: float) -> float:
    """Pick the element of *values* sitting at *fraction* of the sorted order.

    The target index is ``round((n - 1) * fraction)`` using Python's built-in
    ``round``, clamped into ``[0, n - 1]`` so that fractions outside ``[0, 1]``
    still select a valid element. No interpolation is performed. An empty
    input yields ``0.0``.
    """
    if not values:
        return 0.0

    ranked = sorted(values)
    last_index = len(ranked) - 1
    position = round(last_index * fraction)

    # Clamp the rounded position to the valid index range.
    if position < 0:
        position = 0
    elif position > last_index:
        position = last_index
    return ranked[position]
def namespace_to_config(args: Any) -> dict[str, Any]:
    """Convert parsed CLI arguments into a plain benchmark config dict.

    Backend names are lower-cased; every other field is copied from the
    attribute of the same name on *args*. The ``use_ray_worker`` flag is not
    part of the resulting config.
    """
    copied_fields = (
        "iterations",
        "warmup",
        "text",
        "language",
        "speaker_id",
        "environment",
        "compile_vocoder",
        "ray_address",
    )
    config: dict[str, Any] = {
        "backends": [name.lower() for name in args.backends],
    }
    for field in copied_fields:
        config[field] = getattr(args, field)
    return config
def ensure_repo_package_installed() -> None:
    """Pip-install the directory containing this file in editable mode.

    Runs ``<current interpreter> -m pip install --no-cache-dir -e <dir>``.
    Because ``check=True`` is passed, a non-zero pip exit status raises
    ``subprocess.CalledProcessError``.
    """
    package_dir = Path(__file__).resolve().parent
    pip_command = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--no-cache-dir",
        "-e",
        str(package_dir),
    ]
    subprocess.run(pip_command, check=True)
async def benchmark_backend(
backend: str,
config: dict[str, Any],
device: torch.device,
) -> dict[str, float | str]:
from tts_backend.adapters.fastapi.v0.requests import SpeechRequest
from tts_backend.application.use_cases.tts import get_speaker_repository
from tts_backend.application.use_cases.tts import handle_tts_request
from tts_backend.application.use_cases.tts import TTSService
from tts_backend.infrastructure.config.settings import Settings
os.environ["TTS_ACCELERATION_BACKEND"] = backend
os.environ["TTS_COMPILE_VOCODER"] = (
"true" if config["compile_vocoder"] else "false"
)
start_load = time.perf_counter()
settings = Settings(environment=config["environment"])
speaker_repo = get_speaker_repository(settings, device)
if hasattr(speaker_repo, "pre_cache_all"):
await speaker_repo.pre_cache_all()
service = TTSService(settings=settings, device=device, speaker_repository=speaker_repo)
load_time = time.perf_counter() - start_load
langs = list(service.models.keys())
if not langs:
raise RuntimeError("No models were loaded for benchmarking.")
lang = config["language"] or langs[0]
speakers = service.get_speakers(lang)
speaker_id = config["speaker_id"] or (
speakers[0] if isinstance(speakers, list) and speakers else None
)
if speaker_id is None:
raise RuntimeError(f"No speakers found for language '{lang}'.")
request = SpeechRequest(
text=config["text"],
language=lang,
speaker_id=speaker_id,
stream=False,
)
logger.info(
f"Benchmarking backend='{backend}' on language='{lang}' speaker='{speaker_id}' "
f"(load={load_time:.2f}s)"
)
for warmup_idx in range(config["warmup"]):
warmup_start = time.perf_counter()
await handle_tts_request(request, service)
logger.info(
f"[{backend}] warmup {warmup_idx + 1}/{config['warmup']} "
f"took {time.perf_counter() - warmup_start:.3f}s"
)
synth_model = service.models[lang]
model_latencies: list[float] = []
end_to_end_latencies: list[float] = []
audio_duration_sec = 0.0
for idx in range(config["iterations"]):
model_start = time.perf_counter()
audio, sample_rate = synth_model.run_tts(
config["text"], lang=lang, speaker_id=speaker_id
)
model_latency = time.perf_counter() - model_start
model_latencies.append(model_latency)
if idx == 0:
audio_duration_sec = len(audio) / sample_rate
logger.info(
f"[{backend}] model iter {idx + 1}/{config['iterations']}: "
f"{model_latency:.3f}s"
)
del audio
for idx in range(config["iterations"]):
end_to_end_start = time.perf_counter()
result = await handle_tts_request(request, service)
end_to_end_latency = time.perf_counter() - end_to_end_start
end_to_end_latencies.append(end_to_end_latency)
logger.info(
f"[{backend}] service iter {idx + 1}/{config['iterations']}: "
f"{end_to_end_latency:.3f}s"
)
del result
return {
"backend": backend,
"active_backend": getattr(synth_model, "active_backend", backend),
"load_time_s": load_time,
"model_avg_s": mean(model_latencies),
"model_p50_s": median(model_latencies),
"model_p95_s": percentile(model_latencies, 0.95),
"service_avg_s": mean(end_to_end_latencies),
"service_p50_s": median(end_to_end_latencies),
"service_p95_s": percentile(end_to_end_latencies, 0.95),
"audio_duration_s": audio_duration_sec,
"model_rtf": audio_duration_sec / mean(model_latencies),
"service_rtf": audio_duration_sec / mean(end_to_end_latencies),
}
async def run_benchmark_suite(config: dict[str, Any]) -> list[dict[str, float | str]]:
    """Log the runtime environment, then benchmark each configured backend in order.

    When CUDA is available the cuDNN autotuner and TF32 matmul/cuDNN paths are
    switched on globally before any backend runs. A missing ``torch_tensorrt``
    package only produces a warning. Backends listed in ``config["backends"]``
    are benchmarked one after another and their summaries returned in the
    same order.
    """
    cuda_ready = torch.cuda.is_available()
    logger.info(f"CUDA Available: {cuda_ready}")
    logger.info(f"PyTorch Version: {torch.__version__}")

    if cuda_ready:
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Purely informational: report the Torch-TensorRT version when present.
    try:
        import torch_tensorrt
    except ImportError:
        logger.warning("torch_tensorrt is not installed in the environment.")
    else:
        logger.info(f"Torch-TensorRT Version: {torch_tensorrt.__version__}")

    target = torch.device("cuda") if cuda_ready else torch.device("cpu")

    # Awaited one at a time so backends never run concurrently.
    return [
        await benchmark_backend(name, config, target)
        for name in config["backends"]
    ]
def log_summary(summaries: list[dict[str, float | str]]) -> None:
    """Log one line per backend summary, with speedup relative to ``eager``.

    The first summary whose ``backend`` is ``"eager"`` serves as the baseline;
    speedup is the baseline's average service latency divided by the
    backend's. If no eager summary is present, the speedup is reported as 1.0
    for every backend.
    """
    eager_run = None
    for candidate in summaries:
        if candidate["backend"] == "eager":
            eager_run = candidate
            break

    logger.info("=== Benchmark Summary ===")
    for summary in summaries:
        if eager_run is None:
            speedup = 1.0
        else:
            speedup = eager_run["service_avg_s"] / summary["service_avg_s"]
        logger.info(
            f"{summary['backend']} -> active={summary['active_backend']}, "
            f"service_avg={summary['service_avg_s']:.3f}s, "
            f"service_p95={summary['service_p95_s']:.3f}s, "
            f"service_rtf={summary['service_rtf']:.2f}x, "
            f"speedup_vs_eager={speedup:.2f}x"
        )
def run_benchmark_via_ray(config: dict[str, Any]) -> list[dict[str, float | str]]:
    """Run the whole benchmark suite inside a single Ray task.

    Connects to the cluster at ``config["ray_address"]``, submits one task
    that requests 1 CPU and 1 GPU, blocks until it returns, and always shuts
    the Ray connection down afterwards. Inside the task the repository is
    pip-installed first and the suite is then driven with ``asyncio.run``.
    """
    import ray

    @ray.remote(num_cpus=1, num_gpus=1)
    def remote_benchmark(remote_config: dict[str, Any]) -> list[dict[str, float | str]]:
        # Install the package in the task's environment before the suite
        # imports from it.
        ensure_repo_package_installed()
        suite = run_benchmark_suite(remote_config)
        return asyncio.run(suite)

    ray.init(address=config["ray_address"], log_to_driver=True)
    try:
        pending = remote_benchmark.remote(config)
        return ray.get(pending)
    finally:
        ray.shutdown()
def main() -> None:
    """Script entry point: parse the CLI, run the benchmark, log the summary.

    The suite runs in a Ray task when ``--use-ray-worker`` is given and in
    the current process otherwise.
    """
    cli_args = parse_args().parse_args()
    config = namespace_to_config(cli_args)
    summaries = (
        run_benchmark_via_ray(config)
        if cli_args.use_ray_worker
        else asyncio.run(run_benchmark_suite(config))
    )
    log_summary(summaries)


# Only run the benchmark when executed as a script, not on import.
if __name__ == "__main__":
    main()

Xet Storage Details

Size:
9.3 kB
·
Xet hash:
0a7e8c867be512da4f64a6ed21bc90abf4243e7c6475f8cb061d32a4bd91e424

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.