"""
export_finnish_embeddings.py

Exports two ONNX components from the Finnish fine-tuned model that are
currently missing from the browser pipeline:

  1. embed_tokens.onnx  — Finnish T3's token embedding table (``t3.text_emb``)
                          (base version has slightly different weights).
                          NOTE(review): an earlier description said position
                          embeddings were included as well; the code below
                          exports only ``text_emb`` — confirm the browser side
                          adds position embeddings itself.
  2. voice_encoder.onnx — the engine's VoiceEncoder (``engine.ve``, described
                          upstream as the Perth WavLM VoiceEncoder — TODO
                          confirm) → 256-dim speaker embedding (enables custom
                          reference audio in browser without precomputed
                          cond_emb)

These two, combined with the already-uploaded finnish_cond_enc.onnx, give the
browser the full custom-voice pipeline:

  voice_encoder → speaker_emb → cond_enc → cond_emb → language_model → decoder

Outputs:
  _onnx_export/embed_tokens.onnx   (small, ~140 MB)
  _onnx_export/voice_encoder.onnx  (small, ~65 MB)

Each export is run through the ONNX checker and then compared numerically
against the PyTorch module; the script exits with status 1 (and does not
print upload commands) if either comparison fails.

Usage:
  cd /workspaces/work
  conda run -n chatterbox-onnx python export_finnish_embeddings.py
"""
import os
import sys
import numpy as np
import torch
import onnx
from onnx.external_data_helper import convert_model_to_external_data
from pathlib import Path
from safetensors.torch import load_file

# All paths below are relative to the current working directory, so the
# script must be launched from the directory shown in "Usage".
sys.path.insert(0, "Chatterbox-Finnish")

PRETRAINED_DIR = "Chatterbox-Finnish/pretrained_models"
FINETUNED_W = "Chatterbox-Finnish/models/best_finnish_multilingual_cp986.safetensors"
OUT_DIR = Path("_onnx_export")
OUT_DIR.mkdir(exist_ok=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_engine():
    """Load the base Chatterbox engine and overlay the Finnish T3 weights.

    Checkpoint keys prefixed with ``"t3."`` have that prefix stripped so they
    match ``engine.t3``'s own state-dict names. The load is non-strict, so
    key mismatches are reported in the log rather than raised.

    Returns:
        The ``ChatterboxTTS`` engine with the fine-tuned T3 weights applied.
    """
    # Project-local import, resolvable because of the sys.path insert above.
    from src.chatterbox_.tts import ChatterboxTTS

    print(f" loading base engine ({DEVICE})...")
    engine = ChatterboxTTS.from_local(PRETRAINED_DIR, device=DEVICE)

    print(" injecting Finnish weights...")
    ckpt = load_file(FINETUNED_W)
    prefix = "t3."
    t3_state = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt.items()
    }
    missing, unexpected = engine.t3.load_state_dict(t3_state, strict=False)
    # Keys actually applied = checkpoint keys the model recognised.
    # `missing` lists model params absent from the checkpoint, so it must not
    # be subtracted from the checkpoint's key count.
    loaded = len(t3_state) - len(unexpected)
    print(f" loaded: {loaded} keys, missing={len(missing)}, unexpected={len(unexpected)}")
    return engine


def _check_and_report(out_path: str) -> None:
    """Run the ONNX structural checker on *out_path* and print its file size."""
    model = onnx.load(out_path)
    onnx.checker.check_model(model)
    size_mb = os.path.getsize(out_path) / 1e6
    print(f" ✓ {out_path} ({size_mb:.1f} MB)")


# ── 1. embed_tokens.onnx ─────────────────────────────────────────────────────
def export_embed_tokens(engine):
    """Export T3's token embedding table (``engine.t3.text_emb``) to ONNX.

    Input:  input_ids [batch, seq] int64
    Output: embeds    [batch, seq, embed_dim] float32
            (embed_dim is read from the weight and printed; 1024 expected)

    NOTE(review): the original notes state that T3 uses this single table for
    both text and speech tokens and that the base ONNX repo exports it the
    same way — confirm; if T3 has a separate speech embedding it is NOT
    exported here.

    Returns:
        Path (str) of the written ``embed_tokens.onnx``.
    """
    print("\n── export_embed_tokens ──")
    out_path = str(OUT_DIR / "embed_tokens.onnx")

    class EmbedTokens(torch.nn.Module):
        """Thin wrapper so the exported graph has a single named input."""

        def __init__(self, emb: torch.nn.Embedding):
            super().__init__()
            self.emb = emb

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            return self.emb(input_ids)

    emb_module = EmbedTokens(engine.t3.text_emb).to(DEVICE).eval()
    vocab_size = engine.t3.text_emb.weight.shape[0]
    print(f" vocab_size={vocab_size}, embed_dim={engine.t3.text_emb.weight.shape[1]}")

    # Length-5 dummy; batch and seq are declared dynamic below.
    dummy_ids = torch.zeros(1, 5, dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        torch.onnx.export(
            emb_module,
            (dummy_ids,),
            out_path,
            input_names=["input_ids"],
            output_names=["embeds"],
            dynamic_axes={"input_ids": {0: "batch", 1: "seq"},
                          "embeds": {0: "batch", 1: "seq"}},
            opset_version=17,
            do_constant_folding=True,
        )

    _check_and_report(out_path)
    return out_path


# ── 2. voice_encoder.onnx ────────────────────────────────────────────────────
def export_voice_encoder(engine):
    """Export the engine's VoiceEncoder (``engine.ve``) to ONNX.

    Input:  audio       [batch, samples] float32 (16 kHz, variable length)
    Output: speaker_emb [batch, 256] float32

    This lets the browser compute speaker embeddings from arbitrary reference
    audio instead of loading the precomputed finnish_cond_emb.bin.

    NOTE(review): this assumes ``engine.ve``'s forward accepts a raw 16 kHz
    waveform. Some VoiceEncoder implementations expect mel spectrograms
    instead — confirm against the model code; the cosine check in
    ``validate`` only compares ONNX against the same PyTorch call.

    Returns:
        Path (str) of the written ``voice_encoder.onnx``.
    """
    print("\n── export_voice_encoder ──")
    out_path = str(OUT_DIR / "voice_encoder.onnx")

    ve = engine.ve.to(DEVICE).eval()

    # 3 s of silence at 16 kHz; batch and sample count are declared dynamic.
    dummy_audio = torch.zeros(1, 48000, device=DEVICE)
    with torch.no_grad():
        torch.onnx.export(
            ve,
            (dummy_audio,),
            out_path,
            input_names=["audio"],
            output_names=["speaker_emb"],
            dynamic_axes={"audio": {0: "batch", 1: "samples"},
                          "speaker_emb": {0: "batch"}},
            opset_version=17,
            do_constant_folding=True,
        )

    _check_and_report(out_path)
    return out_path


# ── Validate both exports match PyTorch ──────────────────────────────────────
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of the flattened arrays; 0.0 if either has zero norm."""
    a = a.ravel()
    b = b.ravel()
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    return float(np.dot(a, b) / denom)


def validate(engine, embed_path: str, ve_path: str) -> bool:
    """Compare both ONNX exports against their PyTorch modules.

    embed_tokens passes when max |diff| < 1e-4; voice_encoder passes when the
    cosine similarity of the speaker embeddings is > 0.999. A shape mismatch
    is reported as a failure instead of letting NumPy broadcast the diff.

    Returns:
        True only if both checks pass.
    """
    import onnxruntime as ort
    import librosa

    print("\n── Validation ──")
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    all_ok = True

    # embed_tokens: a plain table lookup, so the outputs should agree closely.
    sess_et = ort.InferenceSession(embed_path, providers=providers)
    test_ids = np.array([[255, 284, 18, 22, 7, 0]], dtype=np.int64)  # SOT + some tokens + EOT
    with torch.no_grad():
        pt_emb = engine.t3.text_emb(torch.tensor(test_ids, device=DEVICE)).cpu().numpy()
    onnx_emb = sess_et.run(None, {"input_ids": test_ids})[0]
    if pt_emb.shape != onnx_emb.shape:
        print(f" embed_tokens shape mismatch: torch={pt_emb.shape} onnx={onnx_emb.shape} ✗ MISMATCH")
        all_ok = False
    else:
        max_diff = np.abs(pt_emb - onnx_emb).max()
        et_ok = bool(max_diff < 1e-4)
        print(f" embed_tokens max_diff={max_diff:.6f} {'✓' if et_ok else '✗ MISMATCH'}")
        all_ok = all_ok and et_ok

    # voice_encoder: run real reference audio, resampled to 16 kHz.
    ref_audio, ref_sr = librosa.load("Chatterbox-Finnish/samples/reference_finnish.wav", sr=None)
    ref_16k = librosa.resample(ref_audio, orig_sr=ref_sr, target_sr=16000).astype(np.float32)
    ref_input_np = ref_16k[np.newaxis, :]
    ref_input_pt = torch.tensor(ref_input_np, device=DEVICE)

    sess_ve = ort.InferenceSession(ve_path, providers=providers)
    with torch.no_grad():
        pt_spk = engine.ve(ref_input_pt).cpu().numpy()
    onnx_spk = sess_ve.run(None, {"audio": ref_input_np})[0]
    if pt_spk.shape != onnx_spk.shape:
        print(f" voice_encoder shape mismatch: torch={pt_spk.shape} onnx={onnx_spk.shape} ✗ MISMATCH")
        all_ok = False
    else:
        max_diff = np.abs(pt_spk - onnx_spk).max()
        cos_sim = _cosine_similarity(pt_spk, onnx_spk)
        ve_ok = cos_sim > 0.999
        print(f" voice_encoder max_diff={max_diff:.6f} cosine={cos_sim:.6f} {'✓' if ve_ok else '✗ MISMATCH'}")
        all_ok = all_ok and ve_ok

    return all_ok


if __name__ == "__main__":
    engine = load_engine()
    embed_path = export_embed_tokens(engine)
    ve_path = export_voice_encoder(engine)
    if not validate(engine, embed_path, ve_path):
        # Don't suggest uploading exports that disagree with PyTorch.
        print("\nValidation failed — not printing upload commands. See mismatches above.")
        sys.exit(1)

    print("\nDone. Upload to RASMUS/Chatterbox-Finnish-ONNX:")
    print(f" huggingface-cli upload RASMUS/Chatterbox-Finnish-ONNX {OUT_DIR}/embed_tokens.onnx onnx/embed_tokens_finnish.onnx")
    print(f" huggingface-cli upload RASMUS/Chatterbox-Finnish-ONNX {OUT_DIR}/voice_encoder.onnx onnx/voice_encoder.onnx")