"""
compare_onnx_vs_pytorch.py — Finnish Chatterbox ONNX browser-worker simulator
SPEC: inference_example.py is the behavioral ground truth.
This script must produce equivalent output via ONNX (= what the browser WebGPU worker does).
TWO MODES
--mode parity (default) Full pipeline: PyTorch vs ONNX side-by-side, quality report.
--mode debug Component-level comparison of intermediate tensors.
REQUIREMENTS
- CUDA is required (raises RuntimeError if not available)
- Runs are deterministic via --seed
KNOWN DIFFERENCES from inference_example.py (must be justified here):
1. T3 conditioning (cond_emb):
PyTorch: cond_enc(T3Cond{ve_embed, prompt_tokens, exaggeration}) with Finnish weights
ONNX: precomputed finnish_cond_emb.bin (Finnish cond_enc not yet exported as ONNX)
Impact: conditioning is fixed to the voice used when exporting the .bin file.
For custom reference audio, a finnish_cond_enc.onnx export is needed (see export_finnish_embeddings.py).
2. Watermarking: PyTorch applies Perth watermark; ONNX skips it (Perth not in ONNX pipeline).
3. Speaker embedding: PyTorch T3 conditioning uses the 256-dim voice-encoder embedding (ve_embed) for cond_enc.
ONNX speech_encoder uses CAMPPlus (192-dim) for conditional_decoder.
Both are correct: PyTorch S3Gen also uses CAMPPlus internally.
4. prompt_tokens MUST be prepended: conditional_decoder expects [prompt_tokens | generated_tokens].
Without this, the decoder produces ~0.18s of noise. Verified empirically.
Usage:
# Full parity check
LD_LIBRARY_PATH=... conda run -n chatterbox-onnx python compare_onnx_vs_pytorch.py
# Component debug
LD_LIBRARY_PATH=... conda run -n chatterbox-onnx python compare_onnx_vs_pytorch.py --mode debug
# With analyze_audio for MOS
LD_LIBRARY_PATH=... conda run -n chatterbox-onnx python compare_onnx_vs_pytorch.py --mode parity --analyze
"""
import argparse, os, sys, time
import numpy as np
import soundfile as sf
import requests
from pathlib import Path
# ── Config — mirrors inference_example.py exactly ─────────────────────────────
TEXT = "Tervetuloa kokeilemaan hienoviritettyΓ€ suomenkielistΓ€ Chatterbox-puhesynteesiΓ€."
REF_AUDIO = "Chatterbox-Finnish/samples/reference_finnish.wav"
FINETUNED_W = "Chatterbox-Finnish/models/best_finnish_multilingual_cp986.safetensors"
PRETRAINED_DIR = "Chatterbox-Finnish/pretrained_models"
OUT_DIR = Path("_cmp"); OUT_DIR.mkdir(exist_ok=True)
# inference_example.py params
REPETITION_PENALTY = 1.2
TEMPERATURE = 0.8
EXAGGERATION = 0.6
CFG_WEIGHT = 0.5
MIN_P = 0.05
# ONNX model repos
HF_BASE = "onnx-community/chatterbox-multilingual-ONNX"
HF_FI = "RASMUS/Chatterbox-Finnish-ONNX"
CACHE = Path("_onnx_cache"); CACHE.mkdir(exist_ok=True)
# Token constants (matches s3tokenizer: SPEECH_VOCAB_SIZE=6561, SOS=6561, EOS=6562)
START_SPEECH = 6561
STOP_SPEECH = 6562
SOT_TEXT = 255 # EnTokenizer start-of-text
EOT_TEXT = 0 # EnTokenizer end-of-text
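# Token-space layout assumed throughout this script:
#   0 .. 6560  S3 speech tokens (SPEECH_VOCAB_SIZE=6561), i.e. any id < START_SPEECH
#   6561       start-of-speech, also used as the BOS embedding during prefill
#   6562       end-of-speech, terminates generation
# Text ids are wrapped as [SOT_TEXT] + tokenizer ids + [EOT_TEXT] before embedding.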
GROQ_KEY = os.environ.get("GROQ_API_KEY", "")
# ── Helpers ───────────────────────────────────────────────────────────────────
def require_cuda():
import onnxruntime as ort
available = ort.get_available_providers()
if "CUDAExecutionProvider" not in available:
raise RuntimeError(
"CUDA provider not available! "
"Set LD_LIBRARY_PATH to include the bundled cuDNN and retry.\n"
"Example:\n"
" export LD_LIBRARY_PATH=/opt/conda/envs/chatterbox-onnx/lib/python3.11/"
"site-packages/nvidia/cudnn/lib:$LD_LIBRARY_PATH"
)
return ["CUDAExecutionProvider", "CPUExecutionProvider"]
def hf_download(repo_id: str, filename: str) -> str:
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=repo_id, filename=filename,
local_dir=str(CACHE), local_dir_use_symlinks=False)
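# Note: local_dir_use_symlinks is deprecated (and ignored) in recent huggingface_hub
# releases; it is kept here for compatibility with older environments.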
def save_wav(arr: np.ndarray, path: str, sr: int = 24000):
sf.write(path, arr, sr)
dur = arr.shape[0] / sr
peak = float(np.abs(arr).max())
print(f" saved β†’ {path} ({dur:.2f}s, peak={peak:.4f})")
def transcribe(wav_path: str, lang: str = "fi") -> str:
if not GROQ_KEY:
return "(no GROQ_API_KEY)"
with open(wav_path, "rb") as f:
r = requests.post(
"https://api.groq.com/openai/v1/audio/transcriptions",
headers={"Authorization": f"Bearer {GROQ_KEY}"},
files={"file": (os.path.basename(wav_path), f, "audio/wav")},
data={"model": "whisper-large-v3", "language": lang, "response_format": "text"},
)
r.raise_for_status()
return r.text.strip()
# ── Sampling (mirrors inference_example.py params) ────────────────────────────
def apply_rep_penalty(logits: np.ndarray, generated: list, penalty: float) -> np.ndarray:
logits = logits.copy()
for tok in set(generated):
if logits[tok] > 0:
logits[tok] /= penalty
else:
logits[tok] *= penalty
return logits
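# apply_rep_penalty follows the standard transformers-style rule: every previously
# generated token has its logit divided by `penalty` when positive and multiplied by it
# when negative, making repeats less likely regardless of sign.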
def apply_min_p(logits: np.ndarray, p: float = 0.05) -> np.ndarray:
logits = logits.copy()
probs = np.exp(logits - logits.max())
probs /= probs.sum()
logits[probs < probs.max() * p] = -1e9
return logits
def sample_temperature(logits: np.ndarray, temperature: float) -> int:
logits = logits / temperature
logits -= logits.max()
probs = np.exp(logits)
probs /= probs.sum()
return int(np.random.choice(len(probs), p=probs))
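# The three helpers above are composed once per decode step in get_next_token()
# inside run_onnx(), in this order:
#   logits = apply_rep_penalty(logits, generated, REPETITION_PENALTY)
#   logits = apply_min_p(logits, MIN_P)
#   token  = sample_temperature(logits, TEMPERATURE)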
# ── 1. PyTorch inference (SPEC: inference_example.py) ────────────────────────
def run_pytorch(seed: int) -> str:
print("\n" + "=" * 65)
print("1. PYTORCH INFERENCE (inference_example.py spec)")
print("=" * 65)
import torch
from safetensors.torch import load_file
torch.manual_seed(seed)
np.random.seed(seed)
sys.path.insert(0, "Chatterbox-Finnish")
from src.chatterbox_.tts import ChatterboxTTS
device = "cuda" if torch.cuda.is_available() else "cpu"
engine = ChatterboxTTS.from_local(PRETRAINED_DIR, device=device)
ckpt = load_file(FINETUNED_W)
t3_state = {k[3:] if k.startswith("t3.") else k: v for k, v in ckpt.items()}
engine.t3.load_state_dict(t3_state, strict=False)
print(f" device={device} text='{TEXT[:60]}...'")
t0 = time.time()
wav = engine.generate(
text=TEXT,
audio_prompt_path=REF_AUDIO,
repetition_penalty=REPETITION_PENALTY,
temperature=TEMPERATURE,
exaggeration=EXAGGERATION,
cfg_weight=CFG_WEIGHT,
min_p=MIN_P,
)
elapsed = time.time() - t0
arr = wav.squeeze().cpu().numpy()
out = str(OUT_DIR / "pytorch_output.wav")
save_wav(arr, out, engine.sr)
print(f" inference time: {elapsed:.1f}s")
return out
# ── 2. ONNX inference (browser-worker simulator) ─────────────────────────────
def run_onnx(seed: int) -> str:
print("\n" + "=" * 65)
print("2. ONNX INFERENCE (browser-worker simulator)")
print("=" * 65)
import onnxruntime as ort
import librosa
np.random.seed(seed)
providers = require_cuda()
opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# ── Load ONNX models ──
print(" loading ONNX models (cached after first run)...")
sess_se = ort.InferenceSession(
hf_download(HF_BASE, "onnx/speech_encoder.onnx"),
sess_options=opts, providers=providers)
hf_download(HF_BASE, "onnx/speech_encoder.onnx_data")
sess_et = ort.InferenceSession(
hf_download(HF_BASE, "onnx/embed_tokens.onnx"),
sess_options=opts, providers=providers)
hf_download(HF_BASE, "onnx/embed_tokens.onnx_data")
sess_lm = ort.InferenceSession(
hf_download(HF_FI, "onnx/language_model.onnx"),
sess_options=opts, providers=providers)
hf_download(HF_FI, "onnx/language_model.onnx_data")
sess_cd = ort.InferenceSession(
hf_download(HF_BASE, "onnx/conditional_decoder.onnx"),
sess_options=opts, providers=providers)
hf_download(HF_BASE, "onnx/conditional_decoder.onnx_data")
# ── Conditioning: Finnish cond_emb (see KNOWN DIFFERENCES #1) ──
cond_emb_bin = hf_download(HF_FI, "onnx/finnish_cond_emb.bin")
cond_emb = np.frombuffer(open(cond_emb_bin, "rb").read(), dtype=np.float32).reshape(1, 34, 1024)
print(f" cond_emb: {cond_emb.shape} [Finnish precomputed; see KNOWN DIFFERENCES #1]")
# ── Speech encoder: reference audio at 24kHz (S3GEN_SR) ──
# Note: speech_encoder expects S3GEN_SR=24000, NOT 16kHz.
    # PyTorch prepare_conditionals() also loads at S3GEN_SR=24000 and then resamples to 16 kHz separately.
ref_24k, _ = librosa.load(REF_AUDIO, sr=24000)
se_out = sess_se.run(None, {"audio_values": ref_24k[np.newaxis, :].astype(np.float32)})
# Outputs: audio_features[0], audio_tokens[1], speaker_embeddings[2], speaker_features[3]
    prompt_tokens = se_out[1] # [1, N] — reference S3 tokens (see KNOWN DIFFERENCES #4)
    speaker_emb = se_out[2] # [1, 192] — CAMPPlus x-vector (see KNOWN DIFFERENCES #3)
    speaker_features = se_out[3] # [1, T, 80] — reference mel spec
print(f" prompt_tokens: {prompt_tokens.shape} speaker_emb: {speaker_emb.shape} speaker_features: {speaker_features.shape}")
    # ── Tokenize: EnTokenizer matching PyTorch (punc_norm → encode → wrap SOT/EOT) ──
sys.path.insert(0, "Chatterbox-Finnish")
from src.chatterbox_.models.tokenizers.tokenizer import EnTokenizer
from src.chatterbox_.tts import punc_norm
tok = EnTokenizer(os.path.join(PRETRAINED_DIR, "tokenizer.json"))
normed = punc_norm(TEXT)
token_ids = tok.encode(normed)
text_ids = np.array([[SOT_TEXT] + token_ids + [EOT_TEXT]], dtype=np.int64) # [1, T]
print(f" text tokens: {text_ids.shape} '{normed[:50]}...'")
exag = np.array([EXAGGERATION], dtype=np.float32)
# ── Embed text tokens ──
pos_ids = np.arange(text_ids.shape[1], dtype=np.int64)[np.newaxis, :]
text_embeds = sess_et.run(None, {
"input_ids": text_ids,
"position_ids": pos_ids,
"exaggeration": exag,
})[0] # [1, T, 1024]
# ── Embed BOS speech token ──
bos_emb = sess_et.run(None, {
"input_ids": np.array([[START_SPEECH]], dtype=np.int64),
"position_ids": np.array([[0]], dtype=np.int64),
"exaggeration": exag,
})[0] # [1, 1, 1024]
# ── CFG prefill: conditioned and unconditioned (matching PyTorch cfg_weight=0.5) ──
    # PyTorch: text_tokens duplicated → batch[0]=cond, batch[1]=uncond; cond has real text, uncond has zeros
prefill_c = np.concatenate([cond_emb, text_embeds, bos_emb], axis=1)
prefill_u = np.concatenate([cond_emb, np.zeros_like(text_embeds), bos_emb], axis=1)
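    # Both prefill sequences are [1, 34 (cond) + T (text) + 1 (BOS speech), 1024]; the
    # unconditional stream zeroes the text embeddings but keeps the same length, matching
    # the PyTorch cond/uncond batching described above.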
# ── KV cache setup ──
kv_key = "past_key_values.0.key"
kv_input = next(inp for inp in sess_lm.get_inputs() if inp.name == kv_key)
kv_dtype = np.float16 if "float16" in kv_input.type else np.float32
kv_empty = np.zeros((1, 16, 0, 64), dtype=kv_dtype)
n_layers = 30
def make_kv_feeds(kv):
return {f"past_key_values.{i}.{kv_}": kv[i][j]
for i in range(n_layers) for j, kv_ in enumerate(("key", "value"))}
kv_c = [(kv_empty.copy(), kv_empty.copy()) for _ in range(n_layers)]
kv_u = [(kv_empty.copy(), kv_empty.copy()) for _ in range(n_layers)]
def lm_step(emb, mask, kv):
feeds = {"inputs_embeds": emb, "attention_mask": mask}
feeds.update(make_kv_feeds(kv))
outs = sess_lm.run(None, feeds)
new_kv = [(outs[1 + i * 2], outs[1 + i * 2 + 1]) for i in range(n_layers)]
return outs[0], new_kv
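    # lm_step assumes the exported language_model.onnx output layout: outs[0] = logits,
    # followed by one (present key, present value) pair per layer for all 30 layers.
    # Each KV tensor is [1, 16 heads, seq_len, 64] and grows by one position per step.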
# ── Prefill both streams ──
print(" prefilling (cond + uncond)...")
t0 = time.time()
lc, kv_c = lm_step(prefill_c, np.ones((1, prefill_c.shape[1]), dtype=np.int64), kv_c)
lu, kv_u = lm_step(prefill_u, np.ones((1, prefill_u.shape[1]), dtype=np.int64), kv_u)
print(f" prefill done ({time.time() - t0:.1f}s)")
# ── Autoregressive generation with CFG ──
def get_next_token(lc, lu, generated):
lc_f = lc[0, -1].astype(np.float32)
lu_f = lu[0, -1].astype(np.float32)
logits = lc_f + CFG_WEIGHT * (lc_f - lu_f)
logits = apply_rep_penalty(logits, generated, REPETITION_PENALTY)
logits = apply_min_p(logits, MIN_P)
return sample_temperature(logits, TEMPERATURE)
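    # Classifier-free guidance: guided = cond + CFG_WEIGHT * (cond - uncond), i.e. the
    # conditional logits pushed away from the unconditional ones, mirroring the
    # cfg_weight=0.5 passed to engine.generate() in run_pytorch().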
generated = [START_SPEECH]
speech_tokens = []
token = get_next_token(lc, lu, generated)
generated.append(token)
if token < START_SPEECH:
speech_tokens.append(token)
print(f" generating tokens (max 800 steps)...")
t0 = time.time()
for step in range(1, 800):
last_emb = sess_et.run(None, {
"input_ids": np.array([[generated[-1]]], dtype=np.int64),
"position_ids": np.array([[step]], dtype=np.int64),
"exaggeration": exag,
})[0]
mask_c = np.ones((1, kv_c[0][0].shape[2] + 1), dtype=np.int64)
mask_u = np.ones((1, kv_u[0][0].shape[2] + 1), dtype=np.int64)
lc, kv_c = lm_step(last_emb, mask_c, kv_c)
lu, kv_u = lm_step(last_emb, mask_u, kv_u)
token = get_next_token(lc, lu, generated)
if token == STOP_SPEECH and len(speech_tokens) >= 40:
print(f" EOS at step {step} ({len(speech_tokens)} speech tokens)")
break
generated.append(token)
if token < START_SPEECH:
speech_tokens.append(token)
if (step + 1) % 100 == 0:
elapsed = time.time() - t0
print(f" step {step + 1:4d}: {len(speech_tokens):3d} speech tokens ({(step+1)/elapsed:.1f} tok/s)")
elapsed = time.time() - t0
print(f" generation: {elapsed:.1f}s β†’ {len(speech_tokens)} speech tokens")
    # ── Decode → waveform ──
# VERIFIED: prompt_tokens MUST be prepended (see KNOWN DIFFERENCES #4).
# The conditional_decoder is a CosyVoice-style flow model: it processes
# [prompt | generated] and returns ONLY the generated portion as audio.
    # Without prepending, output is ~(generated - prompt_len) ≈ 0.18s of noise.
generated_arr = np.array([speech_tokens], dtype=np.int64)
decoder_input = np.concatenate([prompt_tokens, generated_arr], axis=1)
print(f" decoder: {prompt_tokens.shape[1]} prompt + {len(speech_tokens)} generated = {decoder_input.shape[1]} total")
wav_out = sess_cd.run(None, {
"speech_tokens": decoder_input,
"speaker_embeddings": speaker_emb,
"speaker_features": speaker_features,
})[0]
arr_raw = wav_out.squeeze().astype(np.float32)
peak = float(np.abs(arr_raw).max())
print(f" raw output peak={peak:.4f}")
# Save raw (unmodified decoder output) for listening/debugging
raw_out = str(OUT_DIR / "onnx_output_raw.wav")
save_wav(arr_raw, raw_out, 24000)
    # Rescale to 0.9 peak if the output is very quiet or would clip; np.clip below is the safety net
    arr = arr_raw.copy()
    if peak < 0.01:
        print(" WARNING: very low amplitude, auto-normalizing")
        arr = arr * (0.9 / max(peak, 1e-8))
    elif peak > 0.9:
        arr = arr * (0.9 / peak)
arr = np.clip(arr, -1.0, 1.0)
out = str(OUT_DIR / "onnx_output.wav")
save_wav(arr, out, 24000)
return out
# ── 3. Quality comparison ─────────────────────────────────────────────────────
def compare(pytorch_wav: str, onnx_wav: str, analyze: bool = False):
print("\n" + "=" * 65)
print("3. QUALITY COMPARISON")
print("=" * 65)
print(f" ref text: '{TEXT}'")
pt_tx = transcribe(pytorch_wav)
on_tx = transcribe(onnx_wav)
bl_tx = transcribe("Chatterbox-Finnish/output_finnish.wav") if os.path.exists("Chatterbox-Finnish/output_finnish.wav") else None
print(f"\n PyTorch: '{pt_tx}'")
print(f" ONNX: '{on_tx}'")
if bl_tx:
print(f" Baseline: '{bl_tx}'")
if analyze:
import subprocess
print("\n Running analyze_audio.py...")
subprocess.run([
sys.executable, "analyze_audio.py",
pytorch_wav, onnx_wav,
"--label-a", "PyTorch", "--label-b", "ONNX",
"--ref-text", TEXT, "--lang", "fi",
])
# ── 4. Component debug mode ───────────────────────────────────────────────────
def run_debug(seed: int):
"""Inspect intermediate tensors to verify ONNX matches PyTorch components."""
print("\n" + "=" * 65)
print("DEBUG MODE: component-level tensor comparison")
print("=" * 65)
import torch, onnxruntime as ort, librosa
from safetensors.torch import load_file
np.random.seed(seed)
providers = require_cuda()
sys.path.insert(0, "Chatterbox-Finnish")
from src.chatterbox_.models.tokenizers.tokenizer import EnTokenizer
from src.chatterbox_.tts import ChatterboxTTS, punc_norm
device = "cuda"
# Load PyTorch engine
engine = ChatterboxTTS.from_local(PRETRAINED_DIR, device=device)
ckpt = load_file(FINETUNED_W)
t3_state = {k[3:] if k.startswith("t3.") else k: v for k, v in ckpt.items()}
engine.t3.load_state_dict(t3_state, strict=False)
# Load ONNX sessions
opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_se = ort.InferenceSession(hf_download(HF_BASE, "onnx/speech_encoder.onnx"), opts, providers)
hf_download(HF_BASE, "onnx/speech_encoder.onnx_data")
sess_et = ort.InferenceSession(hf_download(HF_BASE, "onnx/embed_tokens.onnx"), opts, providers)
hf_download(HF_BASE, "onnx/embed_tokens.onnx_data")
print("\n── [1] Speech encoder: ONNX audio_tokens vs PyTorch s3gen prompt_tokens ──")
ref_24k, _ = librosa.load(REF_AUDIO, sr=24000)
se_out = sess_se.run(None, {"audio_values": ref_24k[np.newaxis, :].astype(np.float32)})
onnx_prompt_tokens = se_out[1]
onnx_speaker_emb = se_out[2]
onnx_speaker_features = se_out[3]
with torch.no_grad():
engine.prepare_conditionals(REF_AUDIO, exaggeration=EXAGGERATION)
pt_ref_dict = engine.conds.gen
pt_prompt_tokens = pt_ref_dict["prompt_token"].cpu().numpy()
pt_speaker_feat = pt_ref_dict["prompt_feat"].cpu().numpy() # [1, T, 80]
pt_speaker_emb = pt_ref_dict["embedding"].cpu().numpy() # [1, 192]
print(f" ONNX prompt_tokens: {onnx_prompt_tokens.shape} match={np.array_equal(onnx_prompt_tokens, pt_prompt_tokens)}")
print(f" ONNX speaker_emb: {onnx_speaker_emb.shape} cosine={_cosine(onnx_speaker_emb, pt_speaker_emb):.6f}")
print(f" ONNX speaker_features: {onnx_speaker_features.shape} vs PyTorch {pt_speaker_feat.shape}")
if onnx_speaker_features.shape == pt_speaker_feat.shape:
diff = np.abs(onnx_speaker_features - pt_speaker_feat).max()
print(f" max_diff={diff:.6f}")
print("\n── [2] Text tokenization: ONNX EnTokenizer vs PyTorch ──")
tok = EnTokenizer(os.path.join(PRETRAINED_DIR, "tokenizer.json"))
normed = punc_norm(TEXT)
onnx_ids = [SOT_TEXT] + tok.encode(normed) + [EOT_TEXT]
with torch.no_grad():
pt_ids = engine.tokenizer.text_to_tokens(normed)
pt_ids_list = pt_ids[0].tolist()
    # PyTorch prepends/appends SOT/EOT during generate()
pt_ids_padded = [engine.t3.hp.start_text_token] + pt_ids_list + [engine.t3.hp.stop_text_token]
match = onnx_ids == pt_ids_padded
print(f" ONNX ids ({len(onnx_ids)}): {onnx_ids[:8]}...")
print(f" PyTorch ({len(pt_ids_padded)}): {pt_ids_padded[:8]}...")
print(f" Exact match: {match}")
print("\n── [3] embed_tokens: ONNX vs PyTorch ──")
exag = np.array([EXAGGERATION], dtype=np.float32)
onnx_ids_arr = np.array([onnx_ids], dtype=np.int64)
pos_arr = np.arange(len(onnx_ids), dtype=np.int64)[np.newaxis, :]
onnx_embeds = sess_et.run(None, {"input_ids": onnx_ids_arr, "position_ids": pos_arr, "exaggeration": exag})[0]
with torch.no_grad():
pt_embeds = engine.t3.text_emb(torch.tensor(onnx_ids_arr, device=device)).cpu().numpy()
diff = np.abs(onnx_embeds - pt_embeds).max()
cos = _cosine(onnx_embeds.flatten(), pt_embeds.flatten())
print(f" embed_tokens max_diff={diff:.6f} cosine={cos:.6f}")
if diff > 0.01:
print(" WARNING: embed_tokens differ β€” Finnish model may need its own embed_tokens.onnx export")
def _cosine(a, b):
a, b = np.array(a).flatten(), np.array(b).flatten()
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))
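# _cosine returns 1.0 for identical directions and ~0 for unrelated vectors; the debug
# checks above treat values near 1.0 (together with a small max_diff) as a match.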
# ── Main ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
p = argparse.ArgumentParser(description="Finnish Chatterbox ONNX browser-worker simulator")
p.add_argument("--mode", choices=["parity", "debug"], default="parity")
p.add_argument("--seed", type=int, default=42, help="Random seed for determinism")
p.add_argument("--analyze", action="store_true", help="Run analyze_audio.py for MOS scoring")
p.add_argument("--skip-pytorch", action="store_true", help="Skip PyTorch run, use existing _cmp/pytorch_output.wav")
args = p.parse_args()
if args.mode == "debug":
run_debug(args.seed)
else:
if args.skip_pytorch:
pt_wav = str(OUT_DIR / "pytorch_output.wav")
assert Path(pt_wav).exists(), f"--skip-pytorch but {pt_wav} not found"
print(f" using existing {pt_wav}")
else:
pt_wav = run_pytorch(args.seed)
onnx_wav = run_onnx(args.seed)
compare(pt_wav, onnx_wav, analyze=args.analyze)
print(f"\nDone. Audio files in {OUT_DIR}/")
print(" Next: python analyze_audio.py _cmp/pytorch_output.wav _cmp/onnx_output.wav")