#!/usr/bin/env python3
"""
================================================================================
Priority 1: EPSS + Sway Sampling Optimization for Habibi-TTS ALG
================================================================================

EPSS (Empirically Pruned Step Sampling) is built into F5-TTS since v1.1.20.
When `use_epss=True` (default in `CFM.sample()`) and `steps` is one of
{5, 6, 7, 10, 12, 16}, the model automatically uses the pruned timestep
schedule from the paper (arxiv:2505.19931).

Key findings from research:
  - EPSS at 7 NFE: 4x speedup vs 32 NFE, ~0.04% WER degradation
  - Sway Sampling is already enabled by default (sway_sampling_coef=-1.0)
  - Combined BF16 + EPSS(7) + torch.compile: expected RTF ~0.018-0.022 on A10G

This script provides:
  1. Optimized inference with EPSS
  2. Benchmarking (latency / RTF) against the 32-NFE baseline
  (Quality metrics such as WER / speaker similarity are NOT computed here.)

Usage:
    python 01_epss_optimization.py \\
        --ref_audio reference.wav \\
        --ref_text "..." \\
        --gen_text "..." \\
        --nfe 7 \\
        --benchmark
================================================================================
"""

import argparse
import hashlib
import os
import sys
import time
import warnings
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from f5_tts.infer.utils_infer import load_model, load_vocoder, preprocess_ref_audio_text
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
from habibi_tts.infer.utils_infer import infer_process
from habibi_tts.model.utils import dialect_id_map, text_list_formatter
from hydra.utils import get_class
from omegaconf import OmegaConf

warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_CFG_PATH = str(Path(__file__).parent / "configs" / "F5TTS_v1_Base.yaml")
CKPT_URL = "hf://SWivid/Habibi-TTS/Specialized/ALG/model_100000.safetensors"
VOCAB_URL = "hf://SWivid/Habibi-TTS/Specialized/ALG/vocab.txt"

TARGET_SAMPLE_RATE = 24000
N_MEL_CHANNELS = 100
HOP_LENGTH = 256
WIN_LENGTH = 1024
N_FFT = 1024

# EPSS supported NFE values (from f5_tts.model.utils.get_epss_timesteps)
EPSS_SUPPORTED = {5, 6, 7, 10, 12, 16}

# NFE value used as the speedup baseline in benchmark results ("speedup_vs_32").
BASELINE_NFE = 32


# ---------------------------------------------------------------------------
# Device helpers
# ---------------------------------------------------------------------------
def _is_cuda(device):
    """Return True if *device* (str or torch.device) names a CUDA device, e.g. "cuda" or "cuda:1"."""
    return str(device).startswith("cuda")


def _sync(device):
    """Wait for queued CUDA work on *device* so wall-clock timings are accurate; no-op on CPU."""
    if _is_cuda(device):
        torch.cuda.synchronize(device)


# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------
def load_habibi_alg(device=DEVICE, use_bf16=True, compile_model=False, compile_mode="reduce-overhead"):
    """Load the Habibi-TTS ALG specialized model with optional production optimizations.

    Args:
        device: Target device string ("cpu", "cuda", "cuda:N").
        use_bf16: Cast the CFM model to bfloat16 (only applied on CUDA devices).
        compile_model: Wrap the transformer backbone with ``torch.compile`` (CUDA only).
        compile_mode: ``mode`` passed to ``torch.compile``.

    Returns:
        The loaded ``CFM`` model, with EMA weights from the ALG checkpoint.
    """
    print(f"[LOAD] Loading Habibi-TTS ALG model on {device}...")

    model_cfg = OmegaConf.load(MODEL_CFG_PATH)
    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch

    ckpt_file = str(cached_path(CKPT_URL))
    vocab_file = str(cached_path(VOCAB_URL))
    vocab_char_map, vocab_size = get_tokenizer(vocab_file, "custom")

    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=N_MEL_CHANNELS),
        mel_spec_kwargs=dict(
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=WIN_LENGTH,
            n_mel_channels=N_MEL_CHANNELS,
            target_sample_rate=TARGET_SAMPLE_RATE,
            mel_spec_type="vocos",
        ),
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # Load checkpoint: the safetensors file holds EMA weights prefixed with
    # "ema_model."; strip the prefix and drop EMA bookkeeping entries.
    from safetensors.torch import load_file

    ema_state = load_file(ckpt_file, device=str(device))
    state_dict = {
        k.replace("ema_model.", ""): v
        for k, v in ema_state.items()
        if k not in ("initted", "step")
    }
    # Mel-extractor buffers are rebuilt by the model itself; drop them if present.
    for key in ("mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"):
        state_dict.pop(key, None)
    model.load_state_dict(state_dict)
    del ema_state, state_dict
    if _is_cuda(device):
        torch.cuda.empty_cache()

    # BF16 optimization (A10G supports BF16 natively)
    if use_bf16 and _is_cuda(device):
        print("[OPT] Converting model to BF16...")
        model = model.to(torch.bfloat16)
        # Vocos vocoder stays FP32 for quality

    # torch.compile for ~20-30% additional speedup.
    # torch.compile(nn.Module) only compiles forward(); inference goes through
    # CFM.sample(), which would be delegated to the *uncompiled* module. Compile
    # the transformer backbone instead, which sample() evaluates at each ODE step.
    # NOTE(review): "reduce-overhead" uses CUDA graphs; variable input lengths may
    # trigger re-recording and cached tensors inside the backbone may conflict with
    # graph replay -- try compile_mode="default" if errors appear.
    if compile_model and _is_cuda(device):
        print(f"[OPT] torch.compile(model.transformer, mode='{compile_mode}')...")
        model.transformer = torch.compile(model.transformer, mode=compile_mode, fullgraph=False)

    print("[LOAD] Model loaded successfully.")
    return model


def load_vocoder_prod(device=DEVICE):
    """Load the Vocos vocoder from the hub onto *device*."""
    print("[LOAD] Loading Vocos vocoder...")
    return load_vocoder("vocos", is_local=False, local_path="", device=device)


# ---------------------------------------------------------------------------
# Inference with EPSS
# ---------------------------------------------------------------------------
def infer_epss(
    ref_audio,
    ref_text,
    gen_text,
    model_obj,
    vocoder,
    nfe_step=7,
    cfg_strength=2.0,
    sway_sampling_coef=-1.0,
    speed=1.0,
    device=DEVICE,
):
    """Synthesize *gen_text* in the ALG dialect, cloning the voice of *ref_audio*.

    EPSS is not passed explicitly: it relies on the model's own default
    (``use_epss``) kicking in when ``nfe_step`` is in ``EPSS_SUPPORTED``; the
    log line below only reports which case applies.

    Returns:
        ``(audio, sample_rate)`` as returned by Habibi's ``infer_process``.
    """
    use_epss = nfe_step in EPSS_SUPPORTED
    if use_epss:
        print(f"[EPSS] Using EPSS schedule for NFE={nfe_step}")
    else:
        print(f"[EPSS] NFE={nfe_step} not in EPSS schedule, using default (non-EPSS) timesteps")

    audio, sr, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        model_obj,
        vocoder,
        mel_spec_type="vocos",
        nfe_step=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        speed=speed,
        device=device,
        dialect_id=dialect_id_map["ALG"],
    )
    return audio, sr


# ---------------------------------------------------------------------------
# Benchmarking
# ---------------------------------------------------------------------------
def benchmark_inference(
    ref_audio,
    ref_text,
    gen_text,
    model_obj,
    vocoder,
    nfe_values=(32, 16, 12, 10, 7, 6, 5),
    num_runs=3,
    warmup=1,
    device=DEVICE,
    cfg_strength=2.0,
    sway_sampling_coef=-1.0,
    speed=1.0,
):
    """Benchmark wall-clock latency and RTF for each NFE value.

    Args:
        nfe_values: Iterable of NFE step counts to time.
        num_runs: Timed runs per NFE value (must be >= 1).
        warmup: Untimed runs at NFE=7 before timing (covers compile / cuDNN autotune).
        cfg_strength, sway_sampling_coef, speed: Forwarded to ``infer_epss``.

    Returns:
        A list of dicts, one per NFE value, with keys ``nfe``, ``avg_time_sec``,
        ``std_time_sec``, ``audio_duration_sec``, ``rtf`` and ``speedup_vs_32``.
        ``speedup_vs_32`` is ``None`` when NFE=32 was not benchmarked.

    Raises:
        ValueError: If ``num_runs`` < 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    infer_kwargs = dict(
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        speed=speed,
        device=device,
    )
    results = []

    # Warmup
    print(f"[BENCH] Warmup ({warmup} runs)...")
    for _ in range(warmup):
        infer_epss(ref_audio, ref_text, gen_text, model_obj, vocoder, nfe_step=7, **infer_kwargs)
    _sync(device)

    for nfe in nfe_values:
        print(f"\n[BENCH] NFE={nfe} ({num_runs} runs)...")
        times = []
        audio, sr = None, TARGET_SAMPLE_RATE
        for _ in range(num_runs):
            _sync(device)
            t0 = time.perf_counter()
            audio, sr = infer_epss(
                ref_audio, ref_text, gen_text, model_obj, vocoder, nfe_step=nfe, **infer_kwargs
            )
            _sync(device)  # include all queued GPU work in the measurement
            times.append(time.perf_counter() - t0)

        avg_time = np.mean(times)
        std_time = np.std(times)
        audio_duration = len(audio) / sr if audio is not None else 0
        rtf = avg_time / audio_duration if audio_duration > 0 else float("inf")

        results.append(
            {
                "nfe": nfe,
                "avg_time_sec": avg_time,
                "std_time_sec": std_time,
                "audio_duration_sec": audio_duration,
                "rtf": rtf,
                "speedup_vs_32": None,
            }
        )
        print(f"  Time: {avg_time:.3f}s ± {std_time:.3f}s | Audio: {audio_duration:.2f}s | RTF: {rtf:.4f}")

    # Speedups relative to the NFE=32 baseline (RTF-based, since audio length
    # may vary slightly between NFE settings). Skip if the baseline wasn't run.
    baseline_rtf = next((r["rtf"] for r in results if r["nfe"] == BASELINE_NFE), None)
    if baseline_rtf is not None:
        for r in results:
            r["speedup_vs_32"] = baseline_rtf / r["rtf"] if r["rtf"] > 0 else 0

    return results


def _print_benchmark_table(results):
    """Pretty-print the list of dicts returned by ``benchmark_inference``."""
    print("\n" + "=" * 70)
    print("BENCHMARK RESULTS")
    print("=" * 70)
    print(f"{'NFE':>5} | {'Time(s)':>10} | {'Audio(s)':>10} | {'RTF':>8} | {'Speedup':>8}")
    print("-" * 70)
    for r in results:
        speedup = r["speedup_vs_32"]
        speedup_str = f"{speedup:>8.2f}x" if speedup is not None else f"{'n/a':>9}"
        print(
            f"{r['nfe']:>5} | {r['avg_time_sec']:>10.3f} | {r['audio_duration_sec']:>10.2f} | {r['rtf']:>8.4f} | {speedup_str}"
        )


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: load model + vocoder, then benchmark or synthesize once and save a WAV."""
    parser = argparse.ArgumentParser(description="EPSS Optimization for Habibi-TTS ALG")
    parser.add_argument("--ref_audio", required=True, help="Reference audio file")
    parser.add_argument("--ref_text", required=True, help="Reference text")
    parser.add_argument("--gen_text", required=True, help="Text to synthesize")
    parser.add_argument("--nfe", type=int, default=7, help="NFE steps (5,6,7,10,12,16,32)")
    parser.add_argument("--cfg_strength", type=float, default=2.0)
    parser.add_argument("--sway_coef", type=float, default=-1.0)
    parser.add_argument("--speed", type=float, default=1.0)
    parser.add_argument("--output", default="output_epss.wav", help="Output WAV file")
    parser.add_argument("--benchmark", action="store_true", help="Run benchmark across NFE values")
    parser.add_argument("--no_bf16", action="store_true", help="Disable BF16")
    parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
    parser.add_argument("--device", default=DEVICE)
    args = parser.parse_args()

    # Load model
    model = load_habibi_alg(device=args.device, use_bf16=not args.no_bf16, compile_model=args.compile)
    vocoder = load_vocoder_prod(device=args.device)

    # Preprocess reference
    ref_audio, ref_text = preprocess_ref_audio_text(args.ref_audio, args.ref_text)

    if args.benchmark:
        # Sampling settings from the CLI apply to every benchmarked NFE value.
        results = benchmark_inference(
            ref_audio,
            ref_text,
            args.gen_text,
            model,
            vocoder,
            nfe_values=[32, 16, 12, 10, 7, 6, 5],
            device=args.device,
            cfg_strength=args.cfg_strength,
            sway_sampling_coef=args.sway_coef,
            speed=args.speed,
        )
        _print_benchmark_table(results)
        return

    # Single inference
    print(f"\n[INFO] Running inference with NFE={args.nfe}...")
    t0 = time.perf_counter()
    audio, sr = infer_epss(
        ref_audio,
        ref_text,
        args.gen_text,
        model,
        vocoder,
        nfe_step=args.nfe,
        cfg_strength=args.cfg_strength,
        sway_sampling_coef=args.sway_coef,
        speed=args.speed,
        device=args.device,
    )
    t1 = time.perf_counter()

    audio_duration = len(audio) / sr
    # Guard against empty output instead of raising ZeroDivisionError.
    rtf = (t1 - t0) / audio_duration if audio_duration > 0 else float("inf")
    print(f"[DONE] Generated {audio_duration:.2f}s audio in {t1-t0:.3f}s (RTF={rtf:.4f})")

    sf.write(args.output, audio, sr)
    print(f"[SAVE] Saved to {args.output}")


if __name__ == "__main__":
    main()