""" Batch inference for TTS evaluation benchmarks (Seed-TTS-eval & CV3-eval). Uses the llama.cpp backend (MOSS-TTS-Delay). Expected benchmark layout (per case):: {benchmark_dir}/{task}/{case_id}/prompt.wav {benchmark_dir}/{task}/{case_id}/label.txt Output layout:: {result_dir}/{task}/{case_id}/pred.wav Usage:: python scripts/batch_eval_llama_cpp.py \\ --config configs/llama_cpp/default.yaml \\ --benchmark-dir /path/to/eval/tts \\ --result-dir results/my_run \\ --tasks seed-tts-zeroshot-zh seed-tts-zeroshot-en """ from __future__ import annotations import argparse import json import logging import sys import time from dataclasses import dataclass from pathlib import Path import numpy as np import soundfile as sf from tqdm import tqdm from moss_tts_delay.llama_cpp import LlamaCppPipeline, PipelineConfig from moss_tts_delay.llama_cpp._constants import SAMPLE_RATE log = logging.getLogger(__name__) SEED_TTS_TASKS = [ "seed-tts-zeroshot-zh", "seed-tts-zeroshot-en", "seed-tts-zeroshot-hard-zh", ] CV3_TASKS = [ "cv3-crosslingual-en", "cv3-crosslingual-hard-en", "cv3-zeroshot-en", "cv3-zeroshot-hard-en", "cv3-crosslingual-zh", "cv3-crosslingual-hard-zh", "cv3-zeroshot-zh", "cv3-zeroshot-hard-zh", ] ALL_TASKS = SEED_TTS_TASKS + CV3_TASKS + ["demo-zh", "demo-en"] TASK_LANGUAGE = { "seed-tts-zeroshot-zh": "zh", "seed-tts-zeroshot-en": "en", "seed-tts-zeroshot-hard-zh": "zh", "cv3-crosslingual-en": "en", "cv3-crosslingual-hard-en": "en", "cv3-zeroshot-en": "en", "cv3-zeroshot-hard-en": "en", "cv3-crosslingual-zh": "zh", "cv3-crosslingual-hard-zh": "zh", "cv3-zeroshot-zh": "zh", "cv3-zeroshot-hard-zh": "zh", "demo-zh": "zh", "demo-en": "en", } @dataclass class CaseResult: task: str case_id: str success: bool audio_duration: float = 0.0 generation_time: float = 0.0 error: str = "" def discover_cases(benchmark_dir: Path, tasks: list[str]) -> list[tuple[str, str, Path, str]]: cases = [] for task in tasks: task_dir = benchmark_dir / task if not task_dir.is_dir(): log.warning("Task directory not found: %s", task_dir) continue for case_dir in sorted(task_dir.iterdir()): if not case_dir.is_dir(): continue prompt_wav = case_dir / "prompt.wav" label_txt = case_dir / "label.txt" if not label_txt.exists(): log.warning("Missing label.txt: %s", case_dir) continue text = label_txt.read_text().strip() cases.append((task, case_dir.name, prompt_wav, text)) return cases def run_batch( pipeline: LlamaCppPipeline, cases: list[tuple[str, str, Path, str]], result_dir: Path, max_cases: int = 0, skip_existing: bool = True, ) -> list[CaseResult]: results: list[CaseResult] = [] total = len(cases) if max_cases <= 0 else min(max_cases, len(cases)) cases = cases[:total] log.info("Running %d evaluation cases, output -> %s", total, result_dir) pbar = tqdm(cases, desc="Evaluation", unit="case", total=total, dynamic_ncols=True) for i, (task, case_id, prompt_wav, text) in enumerate(pbar): pbar.set_postfix_str(f"{task}/{case_id}") out_dir = result_dir / task / case_id out_wav = out_dir / "pred.wav" if skip_existing and out_wav.exists(): log.info("[%d/%d] %s/%s — skipped (exists)", i + 1, total, task, case_id) results.append(CaseResult(task=task, case_id=case_id, success=True)) continue log.info("[%d/%d] %s/%s — %s", i + 1, total, task, case_id, text[:60]) t0 = time.time() try: lang = TASK_LANGUAGE.get(task) ref_audio = str(prompt_wav) if prompt_wav.exists() else None waveform = pipeline.generate( text=text, reference_audio=ref_audio, language=lang, ) elapsed = time.time() - t0 if waveform.size == 0: results.append(CaseResult( task=task, case_id=case_id, success=False, generation_time=elapsed, error="empty waveform", )) continue out_dir.mkdir(parents=True, exist_ok=True) sf.write(str(out_wav), waveform, SAMPLE_RATE) audio_dur = len(waveform) / SAMPLE_RATE results.append(CaseResult( task=task, case_id=case_id, success=True, audio_duration=audio_dur, generation_time=elapsed, )) log.info( " -> %.2fs audio in %.2fs (RTF=%.2f)", audio_dur, elapsed, elapsed / max(audio_dur, 1e-6), ) except Exception as e: elapsed = time.time() - t0 log.error(" -> FAILED: %s", e) results.append(CaseResult( task=task, case_id=case_id, success=False, generation_time=elapsed, error=str(e), )) return results def write_summary(results: list[CaseResult], result_dir: Path) -> None: succeeded = [r for r in results if r.success] failed = [r for r in results if not r.success] per_task: dict[str, dict] = {} for r in results: if r.task not in per_task: per_task[r.task] = {"total": 0, "success": 0, "failed": 0, "total_audio_s": 0.0, "total_gen_s": 0.0} per_task[r.task]["total"] += 1 if r.success: per_task[r.task]["success"] += 1 per_task[r.task]["total_audio_s"] += r.audio_duration per_task[r.task]["total_gen_s"] += r.generation_time else: per_task[r.task]["failed"] += 1 for task, stats in per_task.items(): if stats["total_audio_s"] > 0: stats["avg_rtf"] = round(stats["total_gen_s"] / stats["total_audio_s"], 3) summary = { "total_cases": len(results), "succeeded": len(succeeded), "failed": len(failed), "per_task": per_task, } if failed: summary["failures"] = [ {"task": r.task, "case_id": r.case_id, "error": r.error} for r in failed ] summary_path = result_dir / "inference_summary.json" with open(summary_path, "w") as f: json.dump(summary, f, indent=2, ensure_ascii=False) log.info("Summary written to %s", summary_path) print("\n" + "=" * 60) print(" BATCH INFERENCE SUMMARY") print("=" * 60) print(f" Total: {len(results)}") print(f" Succeeded: {len(succeeded)}") print(f" Failed: {len(failed)}") for task, stats in per_task.items(): rtf = stats.get("avg_rtf", "N/A") print(f" {task}: {stats['success']}/{stats['total']} RTF={rtf}") print("=" * 60 + "\n") def main(): logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) parser = argparse.ArgumentParser( description="Batch TTS evaluation (llama.cpp backend)", formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("--config", required=True, help="Pipeline YAML config") parser.add_argument( "--benchmark-dir", default="/inspire/hdd/project/embodied-multimodality/public/speech_generation/data/eval/tts", ) parser.add_argument("--result-dir", required=True) parser.add_argument("--tasks", nargs="+", default=None) parser.add_argument("--suite", choices=["seed-tts", "cv3", "all"], default=None) parser.add_argument("--max-cases", type=int, default=0) parser.add_argument("--no-skip", action="store_true") parser.add_argument("--text-temp", type=float, default=None) parser.add_argument("--audio-temp", type=float, default=None) parser.add_argument("--audio-top-p", type=float, default=None) parser.add_argument("--audio-top-k", type=int, default=None) parser.add_argument("--audio-rep-penalty", type=float, default=None) parser.add_argument("--n-gpu-layers", type=int, default=None) parser.add_argument("--max-tokens", type=int, default=None) parser.add_argument("--heads-backend", choices=["auto", "numpy", "torch"], default=None) args = parser.parse_args() config = PipelineConfig.from_yaml(args.config) if args.text_temp is not None: config.text_temperature = args.text_temp if args.audio_temp is not None: config.audio_temperature = args.audio_temp if args.audio_top_p is not None: config.audio_top_p = args.audio_top_p if args.audio_top_k is not None: config.audio_top_k = args.audio_top_k if args.audio_rep_penalty is not None: config.audio_repetition_penalty = args.audio_rep_penalty if args.n_gpu_layers is not None: config.n_gpu_layers = args.n_gpu_layers if args.max_tokens is not None: config.max_new_tokens = args.max_tokens if args.heads_backend is not None: config.heads_backend = args.heads_backend if args.tasks: tasks = args.tasks elif args.suite == "seed-tts": tasks = SEED_TTS_TASKS elif args.suite == "cv3": tasks = CV3_TASKS else: tasks = ALL_TASKS for t in tasks: if t not in ALL_TASKS: log.error("Unknown task: %s. Valid tasks: %s", t, ALL_TASKS) sys.exit(1) benchmark_dir = Path(args.benchmark_dir) result_dir = Path(args.result_dir) result_dir.mkdir(parents=True, exist_ok=True) cases = discover_cases(benchmark_dir, tasks) if not cases: log.error("No cases found in %s for tasks %s", benchmark_dir, tasks) sys.exit(1) log.info("Discovered %d cases across %d tasks", len(cases), len(tasks)) run_meta = { "config": args.config, "benchmark_dir": str(benchmark_dir), "tasks": tasks, "sampling": { "text_temperature": config.text_temperature, "text_top_p": config.text_top_p, "text_top_k": config.text_top_k, "audio_temperature": config.audio_temperature, "audio_top_p": config.audio_top_p, "audio_top_k": config.audio_top_k, "audio_repetition_penalty": config.audio_repetition_penalty, }, "max_new_tokens": config.max_new_tokens, "backbone_gguf": config.backbone_gguf, "heads_backend": config.heads_backend, } with open(result_dir / "run_meta.json", "w") as f: json.dump(run_meta, f, indent=2, ensure_ascii=False) with LlamaCppPipeline(config) as pipeline: results = run_batch( pipeline, cases, result_dir, max_cases=args.max_cases, skip_existing=not args.no_skip, ) write_summary(results, result_dir) if __name__ == "__main__": main()