Habibi-TTS-ALG-Prod / scripts /01_epss_optimization.py
medyas's picture
Add EPSS optimization script (Priority 1)
eff0680 verified
#!/usr/bin/env python3
"""
================================================================================
Priority 1: EPSS + Sway Sampling Optimization for Habibi-TTS ALG
================================================================================
EPSS (Empirically Pruned Step Sampling) is built into F5-TTS since v1.1.20.
When `use_epss=True` (default in `CFM.sample()`) and `steps` is one of
{5, 6, 7, 10, 12, 16}, the model automatically uses the pruned timestep
schedule from the paper (arxiv:2505.19931).
Key findings from research:
- EPSS at 7 NFE: 4x speedup vs 32 NFE, ~0.04% WER degradation
- Sway Sampling is already enabled by default (sway_sampling_coef=-1.0)
- Combined BF16 + EPSS(7) + torch.compile: expected RTF ~0.018-0.022 on A10G
This script provides:
1. Optimized inference with EPSS
2. Benchmarking against baseline (32 NFE)
3. Quality comparison metrics
Usage:
python 01_epss_optimization.py \
--ref_audio reference.wav \
--ref_text "..." \
--gen_text "..." \
--nfe 7 \
--benchmark
================================================================================
"""
import argparse
import hashlib
import os
import sys
import time
import warnings
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from f5_tts.infer.utils_infer import load_model, load_vocoder, preprocess_ref_audio_text
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
from habibi_tts.infer.utils_infer import infer_process
from habibi_tts.model.utils import dialect_id_map, text_list_formatter
from hydra.utils import get_class
from omegaconf import OmegaConf
# NOTE: this filter is process-wide and hides *every* Python warning, not just
# noisy third-party ones.
warnings.filterwarnings("ignore")
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Backbone/architecture YAML, looked up next to this script.
MODEL_CFG_PATH = str(Path(__file__).parent.joinpath("configs", "F5TTS_v1_Base.yaml"))
# Hugging Face Hub location of the ALG specialised checkpoint and its vocabulary.
_HF_ALG_ROOT = "hf://SWivid/Habibi-TTS/Specialized/ALG"
CKPT_URL = f"{_HF_ALG_ROOT}/model_100000.safetensors"
VOCAB_URL = f"{_HF_ALG_ROOT}/vocab.txt"
# Mel-spectrogram front-end settings (fed to CFM's mel_spec_kwargs at load time).
TARGET_SAMPLE_RATE = 24_000
N_MEL_CHANNELS = 100
HOP_LENGTH = 256
N_FFT = WIN_LENGTH = 1024
# NFE counts that have a pruned EPSS timestep schedule
# (mirrors f5_tts.model.utils.get_epss_timesteps).
EPSS_SUPPORTED = {5, 6, 7, 10, 12, 16}
# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------
def _load_ema_weights(model, ckpt_file, device):
    """Load the EMA weights stored in a safetensors checkpoint into ``model`` in place.

    Args:
        model: The freshly constructed ``CFM`` module to populate.
        ckpt_file: Local path to the ``.safetensors`` checkpoint.
        device: Device the checkpoint tensors are materialised on while loading.
    """
    from safetensors.torch import load_file

    ema_state = load_file(ckpt_file, device=device)
    # Strip the "ema_model." key prefix and skip the "initted"/"step" entries,
    # which are filtered out rather than loaded as weights.
    state_dict = {
        key.replace("ema_model.", ""): tensor
        for key, tensor in ema_state.items()
        if key not in ("initted", "step")
    }
    del ema_state
    # NOTE(review): presumably fixed mel front-end buffers that the current CFM
    # does not register; they are dropped so the (strict) load below succeeds.
    for key in ("mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"):
        state_dict.pop(key, None)
    model.load_state_dict(state_dict)


def load_habibi_alg(device=DEVICE, use_bf16=True, compile_model=False):
    """Load the Habibi-TTS ALG specialised model with optional production optimisations.

    Args:
        device: Target device, e.g. ``"cuda"``, ``"cuda:1"`` or ``"cpu"`` (a
            ``torch.device`` also works). Any CUDA device, with or without an
            index, enables the GPU-only optimisations below.
        use_bf16: Cast the whole ``CFM`` module to ``torch.bfloat16`` (CUDA only).
        compile_model: Wrap the transformer backbone with ``torch.compile``
            (CUDA only).

    Returns:
        The ``CFM`` instance itself. When compiled, only its ``transformer``
        attribute is replaced by the compiled wrapper.
    """
    print(f"[LOAD] Loading Habibi-TTS ALG model on {device}...")
    # Compare the device *type* so that "cuda:0" is treated like "cuda".
    on_cuda = torch.device(device).type == "cuda"

    model_cfg = OmegaConf.load(MODEL_CFG_PATH)
    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch
    ckpt_file = str(cached_path(CKPT_URL))
    vocab_file = str(cached_path(VOCAB_URL))
    vocab_char_map, vocab_size = get_tokenizer(vocab_file, "custom")
    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=N_MEL_CHANNELS),
        mel_spec_kwargs=dict(
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=WIN_LENGTH,
            n_mel_channels=N_MEL_CHANNELS,
            target_sample_rate=TARGET_SAMPLE_RATE,
            mel_spec_type="vocos",
        ),
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)

    _load_ema_weights(model, ckpt_file, device)
    if on_cuda:
        # Release the cached blocks that held the raw checkpoint tensors.
        torch.cuda.empty_cache()

    # BF16 optimization (assumes the GPU supports BF16, e.g. A10G).
    if use_bf16 and on_cuda:
        print("[OPT] Converting model to BF16...")
        model = model.to(torch.bfloat16)
        # Only this CFM module is cast; the separately loaded vocoder is untouched.

    if compile_model and on_cuda:
        # torch.compile(model) would only compile CFM.forward, and attribute
        # access such as model.sample() on the returned OptimizedModule goes to
        # the *uncompiled* original module, so inference would gain nothing.
        # Compile the transformer backbone instead.
        # NOTE(review): this assumes the installed f5_tts CFM.sample() invokes
        # self.transformer once per ODE step. Also, "reduce-overhead" uses CUDA
        # graphs and may recompile for new input lengths.
        print("[OPT] torch.compile(model.transformer, mode='reduce-overhead')...")
        model.transformer = torch.compile(model.transformer, mode="reduce-overhead", fullgraph=False)

    print("[LOAD] Model loaded successfully.")
    return model
def load_vocoder_prod(device=DEVICE):
    """Return the Vocos vocoder produced by ``load_vocoder`` on ``device``.

    ``is_local=False`` with an empty ``local_path`` means no local checkpoint
    directory is supplied to ``load_vocoder``.
    """
    vocoder_options = dict(is_local=False, local_path="", device=device)
    print("[LOAD] Loading Vocos vocoder...")
    return load_vocoder("vocos", **vocoder_options)
# ---------------------------------------------------------------------------
# Inference with EPSS
# ---------------------------------------------------------------------------
def infer_epss(
    ref_audio,
    ref_text,
    gen_text,
    model_obj,
    vocoder,
    nfe_step=7,
    cfg_strength=2.0,
    sway_sampling_coef=-1.0,
    speed=1.0,
    device=DEVICE,
):
    """Synthesise ``gen_text`` in the reference voice using the ALG dialect ID.

    ``nfe_step`` is forwarded unchanged to ``infer_process``. This wrapper only
    reports whether that step count is one of the EPSS-scheduled values listed
    in ``EPSS_SUPPORTED``.

    Returns:
        tuple: ``(audio, sample_rate)``, the first two items of the 3-tuple
        returned by ``infer_process``.
    """
    # NOTE(review): when sway_sampling_coef is set, the non-EPSS timesteps are
    # presumably sway-warped, so "uniform" in the message below may overstate it.
    if nfe_step not in EPSS_SUPPORTED:
        banner = f"[EPSS] NFE={nfe_step} not in EPSS schedule, using uniform timesteps"
    else:
        banner = f"[EPSS] Using EPSS schedule for NFE={nfe_step}"
    print(banner)

    synthesis_options = dict(
        mel_spec_type="vocos",
        nfe_step=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        speed=speed,
        device=device,
        dialect_id=dialect_id_map["ALG"],
    )
    wave, sample_rate, _ = infer_process(
        ref_audio, ref_text, gen_text, model_obj, vocoder, **synthesis_options
    )
    return wave, sample_rate
# ---------------------------------------------------------------------------
# Benchmarking
# ---------------------------------------------------------------------------
def _sync_if_cuda(device):
    """Wait for queued CUDA work on ``device`` to finish; no-op for non-CUDA devices."""
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize(device)


def benchmark_inference(
    ref_audio,
    ref_text,
    gen_text,
    model_obj,
    vocoder,
    nfe_values=(32, 16, 12, 10, 7, 6, 5),
    num_runs=3,
    warmup=1,
    device=DEVICE,
):
    """Time ``infer_epss`` for each NFE value and report RTF and speedup.

    Args:
        ref_audio, ref_text, gen_text, model_obj, vocoder: Forwarded to ``infer_epss``.
        nfe_values: Iterable of NFE step counts to benchmark, in order.
        num_runs: Timed runs per NFE value; must be >= 1.
        warmup: Untimed warmup runs (at NFE=7) before measuring.
        device: Device passed to inference. Any CUDA device is synchronised
            around each timed run so that asynchronous kernels are counted.

    Returns:
        list[dict]: One dict per NFE with keys ``nfe``, ``avg_time_sec``,
        ``std_time_sec``, ``audio_duration_sec``, ``rtf`` and ``speedup_vs_32``.
        ``speedup_vs_32`` stays ``None`` when 32 is not among ``nfe_values``.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    results = []

    print(f"[BENCH] Warmup ({warmup} runs)...")
    for _ in range(warmup):
        infer_epss(ref_audio, ref_text, gen_text, model_obj, vocoder, nfe_step=7, device=device)
    _sync_if_cuda(device)

    for nfe in nfe_values:
        print(f"\n[BENCH] NFE={nfe} ({num_runs} runs)...")
        times = []
        for _ in range(num_runs):
            _sync_if_cuda(device)
            t0 = time.perf_counter()
            audio, sr = infer_epss(
                ref_audio, ref_text, gen_text, model_obj, vocoder, nfe_step=nfe, device=device
            )
            _sync_if_cuda(device)
            times.append(time.perf_counter() - t0)

        avg_time = np.mean(times)
        std_time = np.std(times)
        # The audio length comes from the last run; it does not vary with NFE.
        audio_duration = len(audio) / sr if audio is not None else 0
        rtf = avg_time / audio_duration if audio_duration > 0 else float("inf")
        results.append(
            {
                "nfe": nfe,
                "avg_time_sec": avg_time,
                "std_time_sec": std_time,
                "audio_duration_sec": audio_duration,
                "rtf": rtf,
                "speedup_vs_32": None,
            }
        )
        print(f"  Time: {avg_time:.3f}s ± {std_time:.3f}s | Audio: {audio_duration:.2f}s | RTF: {rtf:.4f}")

    # Speedups are relative to the NFE=32 row. Without one, leave them as None
    # instead of raising StopIteration.
    baseline_rtf = next((r["rtf"] for r in results if r["nfe"] == 32), None)
    if baseline_rtf is not None:
        for r in results:
            r["speedup_vs_32"] = baseline_rtf / r["rtf"] if r["rtf"] > 0 else 0
    return results
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _print_benchmark_table(results):
    """Pretty-print the per-NFE result dicts produced by ``benchmark_inference``."""
    print("\n" + "=" * 70)
    print("BENCHMARK RESULTS")
    print("=" * 70)
    print(f"{'NFE':>5} | {'Time(s)':>10} | {'Audio(s)':>10} | {'RTF':>8} | {'Speedup':>8}")
    print("-" * 70)
    for r in results:
        speedup = r["speedup_vs_32"]
        # speedup_vs_32 starts as None and may stay None without an NFE=32 row.
        speedup_str = f"{speedup:>8.2f}x" if speedup is not None else f"{'n/a':>8} "
        print(
            f"{r['nfe']:>5} | {r['avg_time_sec']:>10.3f} | {r['audio_duration_sec']:>10.2f} | {r['rtf']:>8.4f} | {speedup_str}"
        )


def main():
    """CLI entry point: load the ALG model and vocoder, then benchmark or synthesise.

    With ``--benchmark``, NFE values (32, 16, 12, 10, 7, 6, 5) are timed using
    ``infer_epss``'s default CFG strength, sway coefficient and speed. Otherwise
    one utterance is synthesised with the requested settings and written to
    ``--output``.
    """
    parser = argparse.ArgumentParser(description="EPSS Optimization for Habibi-TTS ALG")
    parser.add_argument("--ref_audio", required=True, help="Reference audio file")
    parser.add_argument("--ref_text", required=True, help="Reference text")
    parser.add_argument("--gen_text", required=True, help="Text to synthesize")
    parser.add_argument("--nfe", type=int, default=7, help="NFE steps (5,6,7,10,12,16,32)")
    parser.add_argument(
        "--cfg_strength", type=float, default=2.0, help="CFG strength (ignored with --benchmark)"
    )
    parser.add_argument(
        "--sway_coef", type=float, default=-1.0, help="Sway sampling coefficient (ignored with --benchmark)"
    )
    parser.add_argument("--speed", type=float, default=1.0, help="Speech speed factor (ignored with --benchmark)")
    parser.add_argument("--output", default="output_epss.wav", help="Output WAV file")
    parser.add_argument("--benchmark", action="store_true", help="Run benchmark across NFE values")
    parser.add_argument("--no_bf16", action="store_true", help="Disable BF16")
    parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
    parser.add_argument("--device", default=DEVICE, help="Torch device, e.g. cuda, cuda:0 or cpu")
    args = parser.parse_args()

    model = load_habibi_alg(
        device=args.device, use_bf16=not args.no_bf16, compile_model=args.compile
    )
    vocoder = load_vocoder_prod(device=args.device)
    ref_audio, ref_text = preprocess_ref_audio_text(args.ref_audio, args.ref_text)

    if args.benchmark:
        results = benchmark_inference(
            ref_audio,
            ref_text,
            args.gen_text,
            model,
            vocoder,
            nfe_values=[32, 16, 12, 10, 7, 6, 5],
            device=args.device,
        )
        _print_benchmark_table(results)
        return

    # Single inference
    print(f"\n[INFO] Running inference with NFE={args.nfe}...")
    t0 = time.perf_counter()
    audio, sr = infer_epss(
        ref_audio,
        ref_text,
        args.gen_text,
        model,
        vocoder,
        nfe_step=args.nfe,
        cfg_strength=args.cfg_strength,
        sway_sampling_coef=args.sway_coef,
        speed=args.speed,
        device=args.device,
    )
    elapsed = time.perf_counter() - t0
    # Guard against empty output instead of crashing on a zero-duration RTF.
    if audio is None or len(audio) == 0:
        sys.exit("[ERROR] Inference produced no audio; nothing to save.")
    audio_duration = len(audio) / sr
    rtf = elapsed / audio_duration
    print(f"[DONE] Generated {audio_duration:.2f}s audio in {elapsed:.3f}s (RTF={rtf:.4f})")
    # Create the destination directory so that nested --output paths work.
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    sf.write(args.output, audio, sr)
    print(f"[SAVE] Saved to {args.output}")


if __name__ == "__main__":
    main()