| """ |
compare_onnx_vs_pytorch.py — Finnish Chatterbox ONNX browser-worker simulator
| |
| SPEC: inference_example.py is the behavioral ground truth. |
| This script must produce equivalent output via ONNX (= what the browser WebGPU worker does). |
| |
| TWO MODES |
| --mode parity (default) Full pipeline: PyTorch vs ONNX side-by-side, quality report. |
| --mode debug Component-level comparison of intermediate tensors. |
| |
| REQUIREMENTS |
| - CUDA is required (raises RuntimeError if not available) |
| - Runs are deterministic via --seed |
| |
| KNOWN DIFFERENCES from inference_example.py (must be justified here): |
| 1. T3 conditioning (cond_emb): |
| PyTorch: cond_enc(T3Cond{ve_embed, prompt_tokens, exaggeration}) with Finnish weights |
| ONNX: precomputed finnish_cond_emb.bin (Finnish cond_enc not yet exported as ONNX) |
| Impact: conditioning is fixed to the voice used when exporting the .bin file. |
| For custom reference audio, a finnish_cond_enc.onnx export is needed (see export_finnish_embeddings.py). |
| 2. Watermarking: PyTorch applies Perth watermark; ONNX skips it (Perth not in ONNX pipeline). |
| 3. Speaker embedding: PyTorch T3 uses Perth WavLM (256-dim) for cond_enc. |
| ONNX speech_encoder uses CAMPPlus (192-dim) for conditional_decoder. |
| Both are correct: PyTorch S3Gen also uses CAMPPlus internally. |
| 4. prompt_tokens MUST be prepended: conditional_decoder expects [prompt_tokens | generated_tokens]. |
| Without this, the decoder produces ~0.18s of noise. Verified empirically. |
| |
| Usage: |
| # Full parity check |
| LD_LIBRARY_PATH=... conda run -n chatterbox-onnx python compare_onnx_vs_pytorch.py |
| |
| # Component debug |
| LD_LIBRARY_PATH=... conda run -n chatterbox-onnx python compare_onnx_vs_pytorch.py --mode debug |
| |
| # With analyze_audio for MOS |
| LD_LIBRARY_PATH=... conda run -n chatterbox-onnx python compare_onnx_vs_pytorch.py --mode parity --analyze |
| """ |
|
|
| import argparse, os, sys, time |
| import numpy as np |
| import soundfile as sf |
| import requests |
| from pathlib import Path |
|
|
| |
# --- Inputs & paths -------------------------------------------------------
# NOTE(review): the "Γ€" sequences below look like mojibake for "ä" — confirm
# the file is saved and read as UTF-8 before trusting this literal.
TEXT = "Tervetuloa kokeilemaan hienoviritettyΓ€ suomenkielistΓ€ Chatterbox-puhesynteesiΓ€."
REF_AUDIO = "Chatterbox-Finnish/samples/reference_finnish.wav"  # reference voice clip
FINETUNED_W = "Chatterbox-Finnish/models/best_finnish_multilingual_cp986.safetensors"  # Finnish T3 weights
PRETRAINED_DIR = "Chatterbox-Finnish/pretrained_models"  # base Chatterbox checkpoint dir
OUT_DIR = Path("_cmp"); OUT_DIR.mkdir(exist_ok=True)  # output dir for generated wavs

# --- Sampling hyperparameters (shared by the PyTorch and ONNX paths) ------
REPETITION_PENALTY = 1.2
TEMPERATURE = 0.8
EXAGGERATION = 0.6  # emotion-exaggeration conditioning scalar
CFG_WEIGHT = 0.5    # classifier-free guidance weight
MIN_P = 0.05        # min-p filtering threshold (relative to max prob)

# --- HuggingFace repos / local ONNX cache ---------------------------------
HF_BASE = "onnx-community/chatterbox-multilingual-ONNX"  # base multilingual ONNX export
HF_FI = "RASMUS/Chatterbox-Finnish-ONNX"                 # Finnish fine-tuned LM + cond_emb
CACHE = Path("_onnx_cache"); CACHE.mkdir(exist_ok=True)

# --- Special token ids (T3 vocabulary) ------------------------------------
START_SPEECH = 6561  # BOS of the speech-token stream
STOP_SPEECH = 6562   # EOS of the speech-token stream
SOT_TEXT = 255       # start-of-text token
EOT_TEXT = 0         # end-of-text token

# Groq API key for Whisper transcription in compare(); empty string disables it.
GROQ_KEY = os.environ.get("GROQ_API_KEY", "")
|
|
|
|
| |
|
|
def require_cuda():
    """Verify that onnxruntime exposes the CUDA execution provider.

    Returns the provider priority list (CUDA first, CPU fallback); raises
    RuntimeError with setup instructions when CUDA is unavailable.
    """
    import onnxruntime as ort

    if "CUDAExecutionProvider" in ort.get_available_providers():
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    raise RuntimeError(
        "CUDA provider not available! "
        "Set LD_LIBRARY_PATH to include the bundled cuDNN and retry.\n"
        "Example:\n"
        " export LD_LIBRARY_PATH=/opt/conda/envs/chatterbox-onnx/lib/python3.11/"
        "site-packages/nvidia/cudnn/lib:$LD_LIBRARY_PATH"
    )
|
|
|
|
def hf_download(repo_id: str, filename: str) -> str:
    """Fetch *filename* from the HF Hub repo *repo_id* into the local CACHE dir.

    Returns the local filesystem path of the downloaded file.
    """
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=str(CACHE),
        local_dir_use_symlinks=False,
    )
    return local_path
|
|
|
|
def save_wav(arr: np.ndarray, path: str, sr: int = 24000):
    """Write a mono float waveform to *path* and log its duration and peak."""
    sf.write(path, arr, sr)
    seconds = arr.shape[0] / sr
    peak_amp = float(np.abs(arr).max())
    print(f" saved β {path} ({seconds:.2f}s, peak={peak_amp:.4f})")
|
|
|
|
def transcribe(wav_path: str, lang: str = "fi") -> str:
    """Transcribe *wav_path* via Groq's Whisper API; no-op without an API key."""
    if not GROQ_KEY:
        return "(no GROQ_API_KEY)"
    with open(wav_path, "rb") as audio_file:
        response = requests.post(
            "https://api.groq.com/openai/v1/audio/transcriptions",
            headers={"Authorization": f"Bearer {GROQ_KEY}"},
            files={"file": (os.path.basename(wav_path), audio_file, "audio/wav")},
            data={"model": "whisper-large-v3", "language": lang, "response_format": "text"},
        )
    response.raise_for_status()
    return response.text.strip()
|
|
|
|
| |
|
|
def apply_rep_penalty(logits: np.ndarray, generated: list, penalty: float) -> np.ndarray:
    """Return a copy of *logits* with a repetition penalty on seen tokens.

    For every distinct token id in *generated*, positive logits are divided
    by *penalty* and non-positive logits multiplied by it (standard CTRL-style
    repetition penalty). The input array is left untouched.
    """
    penalized = logits.copy()
    for token_id in set(generated):
        score = penalized[token_id]
        penalized[token_id] = score / penalty if score > 0 else score * penalty
    return penalized
|
|
|
|
def apply_min_p(logits: np.ndarray, p: float = 0.05) -> np.ndarray:
    """Return a copy of *logits* with min-p filtering applied.

    Tokens whose softmax probability falls below p * max_probability are
    masked to -1e9 so they cannot be sampled. The input array is untouched.
    """
    filtered = logits.copy()
    shifted = np.exp(filtered - filtered.max())
    probs = shifted / shifted.sum()
    cutoff = probs.max() * p
    filtered[probs < cutoff] = -1e9
    return filtered
|
|
|
|
def sample_temperature(logits: np.ndarray, temperature: float) -> int:
    """Draw one token id from softmax(logits / temperature) via np.random."""
    scaled = logits / temperature
    weights = np.exp(scaled - scaled.max())
    weights /= weights.sum()
    return int(np.random.choice(len(weights), p=weights))
|
|
|
|
| |
|
|
def run_pytorch(seed: int) -> str:
    """Run the ground-truth PyTorch pipeline (inference_example.py spec).

    Loads the base Chatterbox checkpoint, overlays the Finnish fine-tuned T3
    weights, synthesizes TEXT against REF_AUDIO and writes the waveform to
    _cmp/pytorch_output.wav. Returns that output path.
    """
    print("\n" + "=" * 65)
    print("1. PYTORCH INFERENCE (inference_example.py spec)")
    print("=" * 65)
    import torch
    from safetensors.torch import load_file

    # Seed both torch and numpy BEFORE model load/generation so this run is
    # reproducible and comparable with run_onnx (which seeds numpy the same way).
    torch.manual_seed(seed)
    np.random.seed(seed)

    sys.path.insert(0, "Chatterbox-Finnish")
    from src.chatterbox_.tts import ChatterboxTTS

    device = "cuda" if torch.cuda.is_available() else "cpu"
    engine = ChatterboxTTS.from_local(PRETRAINED_DIR, device=device)
    # Fine-tuned checkpoint keys are prefixed "t3."; strip that prefix so they
    # match the T3 submodule's state_dict. strict=False tolerates extra keys.
    ckpt = load_file(FINETUNED_W)
    t3_state = {k[3:] if k.startswith("t3.") else k: v for k, v in ckpt.items()}
    engine.t3.load_state_dict(t3_state, strict=False)
    print(f" device={device} text='{TEXT[:60]}...'")

    t0 = time.time()
    wav = engine.generate(
        text=TEXT,
        audio_prompt_path=REF_AUDIO,
        repetition_penalty=REPETITION_PENALTY,
        temperature=TEMPERATURE,
        exaggeration=EXAGGERATION,
        cfg_weight=CFG_WEIGHT,
        min_p=MIN_P,
    )
    elapsed = time.time() - t0

    # engine.generate returns a torch tensor; flatten to 1-D numpy for writing.
    arr = wav.squeeze().cpu().numpy()
    out = str(OUT_DIR / "pytorch_output.wav")
    save_wav(arr, out, engine.sr)
    print(f" inference time: {elapsed:.1f}s")
    return out
|
|
|
|
| |
|
|
def run_onnx(seed: int) -> str:
    """Run the full ONNX pipeline (what the browser WebGPU worker executes).

    Steps: speech-encode the reference audio, tokenize and embed the text,
    prefill the language model twice (conditional + unconditional, for CFG),
    autoregressively sample speech tokens, then vocode with the conditional
    decoder. Writes a raw and a loudness-adjusted wav to _cmp/ and returns
    the path of the adjusted one.
    """
    print("\n" + "=" * 65)
    print("2. ONNX INFERENCE (browser-worker simulator)")
    print("=" * 65)
    import onnxruntime as ort
    import librosa

    # Only numpy's RNG matters here: sample_temperature() uses np.random.
    np.random.seed(seed)
    providers = require_cuda()

    opts = ort.SessionOptions()
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    # Each .onnx graph stores its weights in a sibling .onnx_data file, so both
    # must be downloaded even though the session only takes the .onnx path.
    # The language model comes from the Finnish fine-tune repo (HF_FI);
    # everything else is the base multilingual export (HF_BASE).
    print(" loading ONNX models (cached after first run)...")
    sess_se = ort.InferenceSession(
        hf_download(HF_BASE, "onnx/speech_encoder.onnx"),
        sess_options=opts, providers=providers)
    hf_download(HF_BASE, "onnx/speech_encoder.onnx_data")
    sess_et = ort.InferenceSession(
        hf_download(HF_BASE, "onnx/embed_tokens.onnx"),
        sess_options=opts, providers=providers)
    hf_download(HF_BASE, "onnx/embed_tokens.onnx_data")
    sess_lm = ort.InferenceSession(
        hf_download(HF_FI, "onnx/language_model.onnx"),
        sess_options=opts, providers=providers)
    hf_download(HF_FI, "onnx/language_model.onnx_data")
    sess_cd = ort.InferenceSession(
        hf_download(HF_BASE, "onnx/conditional_decoder.onnx"),
        sess_options=opts, providers=providers)
    hf_download(HF_BASE, "onnx/conditional_decoder.onnx_data")

    # Precomputed Finnish T3 conditioning (KNOWN DIFFERENCES #1 in the module
    # docstring): fixed (1, 34, 1024) float32 blob baked at export time.
    cond_emb_bin = hf_download(HF_FI, "onnx/finnish_cond_emb.bin")
    cond_emb = np.frombuffer(open(cond_emb_bin, "rb").read(), dtype=np.float32).reshape(1, 34, 1024)
    print(f" cond_emb: {cond_emb.shape} [Finnish precomputed; see KNOWN DIFFERENCES #1]")

    # Speech-encode the 24 kHz reference audio.
    # NOTE(review): output indices 1/2/3 are assumed to be
    # (prompt_tokens, speaker_emb, speaker_features) — confirm against the
    # speech_encoder export's output list if the model is regenerated.
    ref_24k, _ = librosa.load(REF_AUDIO, sr=24000)
    se_out = sess_se.run(None, {"audio_values": ref_24k[np.newaxis, :].astype(np.float32)})
    prompt_tokens = se_out[1]
    speaker_emb = se_out[2]
    speaker_features = se_out[3]
    print(f" prompt_tokens: {prompt_tokens.shape} speaker_emb: {speaker_emb.shape} speaker_features: {speaker_features.shape}")

    # Tokenize text with the same tokenizer + punctuation normalizer PyTorch uses.
    sys.path.insert(0, "Chatterbox-Finnish")
    from src.chatterbox_.models.tokenizers.tokenizer import EnTokenizer
    from src.chatterbox_.tts import punc_norm

    tok = EnTokenizer(os.path.join(PRETRAINED_DIR, "tokenizer.json"))
    normed = punc_norm(TEXT)
    token_ids = tok.encode(normed)
    text_ids = np.array([[SOT_TEXT] + token_ids + [EOT_TEXT]], dtype=np.int64)
    print(f" text tokens: {text_ids.shape} '{normed[:50]}...'")

    exag = np.array([EXAGGERATION], dtype=np.float32)

    # Embed the full text sequence at positions 0..len-1.
    pos_ids = np.arange(text_ids.shape[1], dtype=np.int64)[np.newaxis, :]
    text_embeds = sess_et.run(None, {
        "input_ids": text_ids,
        "position_ids": pos_ids,
        "exaggeration": exag,
    })[0]

    # Embed the speech BOS token; the speech stream restarts positions at 0.
    bos_emb = sess_et.run(None, {
        "input_ids": np.array([[START_SPEECH]], dtype=np.int64),
        "position_ids": np.array([[0]], dtype=np.int64),
        "exaggeration": exag,
    })[0]

    # Classifier-free guidance needs two prefills: the conditional one keeps
    # the text embeddings, the unconditional one zeroes them (cond_emb and the
    # BOS embedding stay in both).
    prefill_c = np.concatenate([cond_emb, text_embeds, bos_emb], axis=1)
    prefill_u = np.concatenate([cond_emb, np.zeros_like(text_embeds), bos_emb], axis=1)

    # Build an empty KV cache matching the LM export. dtype (fp16 vs fp32) is
    # detected from the graph input; the layout is (batch, heads, seq, head_dim).
    # NOTE(review): 16 heads / 64 head-dim / 30 layers are hard-coded to this
    # particular language_model.onnx — confirm on re-export.
    kv_key = "past_key_values.0.key"
    kv_input = next(inp for inp in sess_lm.get_inputs() if inp.name == kv_key)
    kv_dtype = np.float16 if "float16" in kv_input.type else np.float32
    kv_empty = np.zeros((1, 16, 0, 64), dtype=kv_dtype)
    n_layers = 30

    def make_kv_feeds(kv):
        # Flatten [(key, value), ...] per layer into past_key_values.* feeds.
        return {f"past_key_values.{i}.{kv_}": kv[i][j]
                for i in range(n_layers) for j, kv_ in enumerate(("key", "value"))}

    kv_c = [(kv_empty.copy(), kv_empty.copy()) for _ in range(n_layers)]
    kv_u = [(kv_empty.copy(), kv_empty.copy()) for _ in range(n_layers)]

    def lm_step(emb, mask, kv):
        # One LM forward pass: outs[0] = logits, outs[1:] = interleaved
        # (key, value) pairs forming the updated cache.
        feeds = {"inputs_embeds": emb, "attention_mask": mask}
        feeds.update(make_kv_feeds(kv))
        outs = sess_lm.run(None, feeds)
        new_kv = [(outs[1 + i * 2], outs[1 + i * 2 + 1]) for i in range(n_layers)]
        return outs[0], new_kv

    print(" prefilling (cond + uncond)...")
    t0 = time.time()
    lc, kv_c = lm_step(prefill_c, np.ones((1, prefill_c.shape[1]), dtype=np.int64), kv_c)
    lu, kv_u = lm_step(prefill_u, np.ones((1, prefill_u.shape[1]), dtype=np.int64), kv_u)
    print(f" prefill done ({time.time() - t0:.1f}s)")

    def get_next_token(lc, lu, generated):
        # CFG-combine the last-position logits, then repetition penalty,
        # min-p filtering, and temperature sampling (same order as PyTorch).
        lc_f = lc[0, -1].astype(np.float32)
        lu_f = lu[0, -1].astype(np.float32)
        logits = lc_f + CFG_WEIGHT * (lc_f - lu_f)
        logits = apply_rep_penalty(logits, generated, REPETITION_PENALTY)
        logits = apply_min_p(logits, MIN_P)
        return sample_temperature(logits, TEMPERATURE)

    generated = [START_SPEECH]  # full history (drives the repetition penalty)
    speech_tokens = []          # only real speech tokens (ids < START_SPEECH)

    # The first token is sampled directly from the prefill logits.
    token = get_next_token(lc, lu, generated)
    generated.append(token)
    if token < START_SPEECH:
        speech_tokens.append(token)

    print(f" generating tokens (max 800 steps)...")
    t0 = time.time()
    for step in range(1, 800):
        # Only the newest token is embedded; the KV cache carries the history.
        last_emb = sess_et.run(None, {
            "input_ids": np.array([[generated[-1]]], dtype=np.int64),
            "position_ids": np.array([[step]], dtype=np.int64),
            "exaggeration": exag,
        })[0]
        mask_c = np.ones((1, kv_c[0][0].shape[2] + 1), dtype=np.int64)
        mask_u = np.ones((1, kv_u[0][0].shape[2] + 1), dtype=np.int64)
        lc, kv_c = lm_step(last_emb, mask_c, kv_c)
        lu, kv_u = lm_step(last_emb, mask_u, kv_u)
        token = get_next_token(lc, lu, generated)

        # Honor EOS only after at least 40 speech tokens, to suppress
        # premature stops.
        if token == STOP_SPEECH and len(speech_tokens) >= 40:
            print(f" EOS at step {step} ({len(speech_tokens)} speech tokens)")
            break

        generated.append(token)
        if token < START_SPEECH:
            speech_tokens.append(token)

        if (step + 1) % 100 == 0:
            elapsed = time.time() - t0
            print(f" step {step + 1:4d}: {len(speech_tokens):3d} speech tokens ({(step+1)/elapsed:.1f} tok/s)")

    elapsed = time.time() - t0
    print(f" generation: {elapsed:.1f}s β {len(speech_tokens)} speech tokens")

    # KNOWN DIFFERENCES #4: the conditional decoder needs the reference
    # prompt tokens PREPENDED to the generated ones, otherwise it emits noise.
    generated_arr = np.array([speech_tokens], dtype=np.int64)
    decoder_input = np.concatenate([prompt_tokens, generated_arr], axis=1)
    print(f" decoder: {prompt_tokens.shape[1]} prompt + {len(speech_tokens)} generated = {decoder_input.shape[1]} total")

    wav_out = sess_cd.run(None, {
        "speech_tokens": decoder_input,
        "speaker_embeddings": speaker_emb,
        "speaker_features": speaker_features,
    })[0]

    arr_raw = wav_out.squeeze().astype(np.float32)
    peak = float(np.abs(arr_raw).max())
    print(f" raw output peak={peak:.4f}")

    # Keep the untouched decoder output for debugging.
    raw_out = str(OUT_DIR / "onnx_output_raw.wav")
    save_wav(arr_raw, raw_out, 24000)

    # Loudness guard: boost near-silent output, attenuate near-clipping output.
    arr = arr_raw.copy()
    if peak < 0.01:
        print(f" WARNING: very low amplitude, auto-normalizing")
        arr = arr * (0.9 / peak)
    elif peak > 0.9:
        arr = arr * (0.9 / peak)
    arr = np.clip(arr, -1.0, 1.0)

    out = str(OUT_DIR / "onnx_output.wav")
    save_wav(arr, out, 24000)
    return out
|
|
|
|
| |
|
|
def compare(pytorch_wav: str, onnx_wav: str, analyze: bool = False):
    """Transcribe both outputs (plus an optional baseline wav) and report.

    With analyze=True also shells out to analyze_audio.py for MOS scoring.
    """
    print("\n" + "=" * 65)
    print("3. QUALITY COMPARISON")
    print("=" * 65)
    print(f" ref text: '{TEXT}'")

    pytorch_text = transcribe(pytorch_wav)
    onnx_text = transcribe(onnx_wav)
    baseline_path = "Chatterbox-Finnish/output_finnish.wav"
    baseline_text = transcribe(baseline_path) if os.path.exists(baseline_path) else None

    print(f"\n PyTorch: '{pytorch_text}'")
    print(f" ONNX: '{onnx_text}'")
    if baseline_text:
        print(f" Baseline: '{baseline_text}'")

    if not analyze:
        return
    import subprocess
    print("\n Running analyze_audio.py...")
    subprocess.run([
        sys.executable, "analyze_audio.py",
        pytorch_wav, onnx_wav,
        "--label-a", "PyTorch", "--label-b", "ONNX",
        "--ref-text", TEXT, "--lang", "fi",
    ])
|
|
|
|
| |
|
|
def run_debug(seed: int):
    """Inspect intermediate tensors to verify ONNX matches PyTorch components.

    Compares side by side: [1] speech-encoder outputs, [2] text tokenization,
    [3] embed_tokens embeddings. Prints shapes, max diffs and cosine
    similarities; no audio is generated.
    """
    print("\n" + "=" * 65)
    print("DEBUG MODE: component-level tensor comparison")
    print("=" * 65)
    import torch, onnxruntime as ort, librosa
    from safetensors.torch import load_file

    np.random.seed(seed)
    providers = require_cuda()
    sys.path.insert(0, "Chatterbox-Finnish")
    from src.chatterbox_.models.tokenizers.tokenizer import EnTokenizer
    from src.chatterbox_.tts import ChatterboxTTS, punc_norm

    device = "cuda"

    # PyTorch reference: base checkpoint + Finnish fine-tuned T3 weights
    # (same load path as run_pytorch()).
    engine = ChatterboxTTS.from_local(PRETRAINED_DIR, device=device)
    ckpt = load_file(FINETUNED_W)
    t3_state = {k[3:] if k.startswith("t3.") else k: v for k, v in ckpt.items()}
    engine.t3.load_state_dict(t3_state, strict=False)

    # Only the two ONNX components under test are loaded here.
    opts = ort.SessionOptions()
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_se = ort.InferenceSession(hf_download(HF_BASE, "onnx/speech_encoder.onnx"), opts, providers)
    hf_download(HF_BASE, "onnx/speech_encoder.onnx_data")
    sess_et = ort.InferenceSession(hf_download(HF_BASE, "onnx/embed_tokens.onnx"), opts, providers)
    hf_download(HF_BASE, "onnx/embed_tokens.onnx_data")

    print("\nββ [1] Speech encoder: ONNX audio_tokens vs PyTorch s3gen prompt_tokens ββ")
    ref_24k, _ = librosa.load(REF_AUDIO, sr=24000)
    se_out = sess_se.run(None, {"audio_values": ref_24k[np.newaxis, :].astype(np.float32)})
    # NOTE(review): same output-index assumption as run_onnx() — indices 1/2/3
    # = (prompt_tokens, speaker_emb, speaker_features); confirm on re-export.
    onnx_prompt_tokens = se_out[1]
    onnx_speaker_emb = se_out[2]
    onnx_speaker_features = se_out[3]

    with torch.no_grad():
        engine.prepare_conditionals(REF_AUDIO, exaggeration=EXAGGERATION)
        pt_ref_dict = engine.conds.gen
        pt_prompt_tokens = pt_ref_dict["prompt_token"].cpu().numpy()
        pt_speaker_feat = pt_ref_dict["prompt_feat"].cpu().numpy()
        pt_speaker_emb = pt_ref_dict["embedding"].cpu().numpy()

    print(f" ONNX prompt_tokens: {onnx_prompt_tokens.shape} match={np.array_equal(onnx_prompt_tokens, pt_prompt_tokens)}")
    print(f" ONNX speaker_emb: {onnx_speaker_emb.shape} cosine={_cosine(onnx_speaker_emb, pt_speaker_emb):.6f}")
    print(f" ONNX speaker_features: {onnx_speaker_features.shape} vs PyTorch {pt_speaker_feat.shape}")
    if onnx_speaker_features.shape == pt_speaker_feat.shape:
        diff = np.abs(onnx_speaker_features - pt_speaker_feat).max()
        print(f" max_diff={diff:.6f}")

    print("\nββ [2] Text tokenization: ONNX EnTokenizer vs PyTorch ββ")
    tok = EnTokenizer(os.path.join(PRETRAINED_DIR, "tokenizer.json"))
    normed = punc_norm(TEXT)
    onnx_ids = [SOT_TEXT] + tok.encode(normed) + [EOT_TEXT]

    with torch.no_grad():
        pt_ids = engine.tokenizer.text_to_tokens(normed)
        pt_ids_list = pt_ids[0].tolist()
    # Wrap with the model's own start/stop ids so both sides are comparable.
    pt_ids_padded = [engine.t3.hp.start_text_token] + pt_ids_list + [engine.t3.hp.stop_text_token]

    match = onnx_ids == pt_ids_padded
    print(f" ONNX ids ({len(onnx_ids)}): {onnx_ids[:8]}...")
    print(f" PyTorch ({len(pt_ids_padded)}): {pt_ids_padded[:8]}...")
    print(f" Exact match: {match}")

    print("\nββ [3] embed_tokens: ONNX vs PyTorch ββ")
    exag = np.array([EXAGGERATION], dtype=np.float32)
    onnx_ids_arr = np.array([onnx_ids], dtype=np.int64)
    pos_arr = np.arange(len(onnx_ids), dtype=np.int64)[np.newaxis, :]
    onnx_embeds = sess_et.run(None, {"input_ids": onnx_ids_arr, "position_ids": pos_arr, "exaggeration": exag})[0]

    with torch.no_grad():
        # Compared against the bare text-embedding table; a large diff means
        # the Finnish model needs its own embed_tokens export.
        pt_embeds = engine.t3.text_emb(torch.tensor(onnx_ids_arr, device=device)).cpu().numpy()

    diff = np.abs(onnx_embeds - pt_embeds).max()
    cos = _cosine(onnx_embeds.flatten(), pt_embeds.flatten())
    print(f" embed_tokens max_diff={diff:.6f} cosine={cos:.6f}")
    if diff > 0.01:
        print(" WARNING: embed_tokens differ β Finnish model may need its own embed_tokens.onnx export")
|
|
|
|
| def _cosine(a, b): |
| a, b = np.array(a).flatten(), np.array(b).flatten() |
| return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12)) |
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # CLI entry point: parity (full PyTorch vs ONNX run) or debug
    # (component-level tensor comparison).
    parser = argparse.ArgumentParser(description="Finnish Chatterbox ONNX browser-worker simulator")
    parser.add_argument("--mode", choices=["parity", "debug"], default="parity")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for determinism")
    parser.add_argument("--analyze", action="store_true", help="Run analyze_audio.py for MOS scoring")
    parser.add_argument("--skip-pytorch", action="store_true",
                        help="Skip PyTorch run, use existing _cmp/pytorch_output.wav")
    args = parser.parse_args()

    if args.mode == "debug":
        run_debug(args.seed)
    else:
        if args.skip_pytorch:
            pt_wav = str(OUT_DIR / "pytorch_output.wav")
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, silently skipping this validation.
            if not Path(pt_wav).exists():
                raise SystemExit(f"--skip-pytorch but {pt_wav} not found")
            print(f" using existing {pt_wav}")
        else:
            pt_wav = run_pytorch(args.seed)

        onnx_wav = run_onnx(args.seed)
        compare(pt_wav, onnx_wav, analyze=args.analyze)
        print(f"\nDone. Audio files in {OUT_DIR}/")
        print(" Next: python analyze_audio.py _cmp/pytorch_output.wav _cmp/onnx_output.wav")
|
|