# Chatterbox-Finnish-ONNX / scripts/export_finnish_embeddings.py
# Uploaded by RASMUS — "Add scripts/export_finnish_embeddings.py" (commit e23a410, verified)
"""
export_finnish_embeddings.py
Exports two ONNX components from the Finnish fine-tuned model that are currently
missing from the browser pipeline:
1. embed_tokens.onnx β€” Finnish T3's text_emb + position embeddings
(base version has slightly different weights)
2. voice_encoder.onnx β€” Perth WavLM VoiceEncoder β†’ 256-dim speaker embedding
(enables custom reference audio in browser without precomputed cond_emb)
These two, combined with the already-uploaded finnish_cond_enc.onnx, give the
browser the full custom-voice pipeline:
voice_encoder β†’ speaker_emb β†’ cond_enc β†’ cond_emb β†’ language_model β†’ decoder
Outputs:
_onnx_export/embed_tokens.onnx (small, ~140 MB)
_onnx_export/voice_encoder.onnx (small, ~65 MB)
Usage:
cd /workspaces/work
conda run -n chatterbox-onnx python export_finnish_embeddings.py
"""
import os, sys
import numpy as np
import torch
import onnx
from onnx.external_data_helper import convert_model_to_external_data
from pathlib import Path
from safetensors.torch import load_file
# Make the fine-tuned repo importable as `src.*` (used inside load_engine).
sys.path.insert(0, "Chatterbox-Finnish")
PRETRAINED_DIR = "Chatterbox-Finnish/pretrained_models"  # base Chatterbox checkpoint directory
FINETUNED_W = "Chatterbox-Finnish/models/best_finnish_multilingual_cp986.safetensors"  # Finnish fine-tune weights
OUT_DIR = Path("_onnx_export"); OUT_DIR.mkdir(exist_ok=True)  # all exported ONNX artifacts land here
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU for tracing when available
def load_engine():
    """Load the base Chatterbox engine and inject the Finnish fine-tuned T3 weights.

    Returns:
        The ChatterboxTTS engine with the Finnish checkpoint loaded into `.t3`.
    """
    # Local import: requires sys.path entry added at module top.
    from src.chatterbox_.tts import ChatterboxTTS
    print(f" loading base engine ({DEVICE})...")
    engine = ChatterboxTTS.from_local(PRETRAINED_DIR, device=DEVICE)
    print(" injecting Finnish weights...")
    ckpt = load_file(FINETUNED_W)
    # Strip the "t3." prefix so checkpoint keys line up with the T3 submodule.
    t3_state = {k[3:] if k.startswith("t3.") else k: v for k, v in ckpt.items()}
    missing, unexpected = engine.t3.load_state_dict(t3_state, strict=False)
    # Fix: the count of checkpoint keys actually consumed is
    # len(t3_state) - len(unexpected). `missing` counts *model* keys absent
    # from the checkpoint, so subtracting it misreported the loaded total.
    print(f" loaded: {len(t3_state)-len(unexpected)} keys, missing={len(missing)}, unexpected={len(unexpected)}")
    return engine
# ── 1. embed_tokens.onnx ─────────────────────────────────────────────────────
def export_embed_tokens(engine):
    """
    Export T3's token embedding table as a standalone ONNX graph.

    Input:  input_ids [batch, seq] int64
    Output: embeds    [batch, seq, 1024] float32

    T3 shares a single embedding table (text_emb) between text tokens and
    speech tokens; the base ONNX repo exports it the same way.
    """
    print("\n── export_embed_tokens ──")
    target = str(OUT_DIR / "embed_tokens.onnx")

    class EmbedTokens(torch.nn.Module):
        """Thin wrapper so torch.onnx.export sees a single-input module."""

        def __init__(self, emb: torch.nn.Embedding):
            super().__init__()
            self.emb = emb

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            return self.emb(input_ids)

    # engine.t3.text_emb is the shared token embedding table.
    table = engine.t3.text_emb
    wrapper = EmbedTokens(table).to(DEVICE).eval()
    vocab_size = table.weight.shape[0]
    print(f" vocab_size={vocab_size}, embed_dim={engine.t3.text_emb.weight.shape[1]}")
    sample_ids = torch.zeros(1, 5, dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        torch.onnx.export(
            wrapper,
            (sample_ids,),
            target,
            input_names=["input_ids"],
            output_names=["embeds"],
            dynamic_axes={
                "input_ids": {0: "batch", 1: "seq"},
                "embeds": {0: "batch", 1: "seq"},
            },
            opset_version=17,
            do_constant_folding=True,
        )
    # Sanity-check the exported graph before reporting success.
    onnx.checker.check_model(onnx.load(target))
    size_mb = os.path.getsize(target) / 1e6
    print(f" βœ“ {target} ({size_mb:.1f} MB)")
    return target
# ── 2. voice_encoder.onnx ────────────────────────────────────────────────────
def export_voice_encoder(engine):
    """
    Export the Perth WavLM VoiceEncoder as a standalone ONNX graph.

    Input:  audio       [batch, samples] float32 (16kHz, variable length)
    Output: speaker_emb [batch, 256] float32

    Lets the browser derive speaker embeddings from arbitrary reference
    audio instead of loading a precomputed finnish_cond_emb.bin.
    """
    print("\n── export_voice_encoder ──")
    target = str(OUT_DIR / "voice_encoder.onnx")
    encoder = engine.ve.to(DEVICE).eval()
    # The Perth VoiceEncoder consumes raw 16 kHz audio; trace with a
    # 3-second dummy clip so the dynamic sample axis is exercised.
    silence = torch.zeros(1, 48000, device=DEVICE)  # 3s @ 16kHz
    with torch.no_grad():
        torch.onnx.export(
            encoder,
            (silence,),
            target,
            input_names=["audio"],
            output_names=["speaker_emb"],
            dynamic_axes={
                "audio": {0: "batch", 1: "samples"},
                "speaker_emb": {0: "batch"},
            },
            opset_version=17,
            do_constant_folding=True,
        )
    # Validate the serialized graph before reporting success.
    onnx.checker.check_model(onnx.load(target))
    size_mb = os.path.getsize(target) / 1e6
    print(f" βœ“ {target} ({size_mb:.1f} MB)")
    return target
# ── Validate both exports match PyTorch ──────────────────────────────────────
def validate(engine, embed_path: str, ve_path: str):
    """Compare both ONNX exports against the PyTorch modules they wrap."""
    import onnxruntime as ort
    import librosa
    print("\n── Validation ──")
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

    # --- embed_tokens: a pure table lookup, should match to float precision ---
    token_ids = np.array([[255, 284, 18, 22, 7, 0]], dtype=np.int64)  # SOT + some tokens + EOT
    session = ort.InferenceSession(embed_path, providers=providers)
    with torch.no_grad():
        ref_emb = engine.t3.text_emb(torch.tensor(token_ids, device=DEVICE)).cpu().numpy()
    got_emb = session.run(None, {"input_ids": token_ids})[0]
    emb_diff = np.abs(ref_emb - got_emb).max()
    print(f" embed_tokens max_diff={emb_diff:.6f} {'βœ“' if emb_diff < 1e-4 else 'βœ— MISMATCH'}")

    # --- voice_encoder: compare speaker embeddings on real reference audio ---
    wav, sr = librosa.load("Chatterbox-Finnish/samples/reference_finnish.wav", sr=None)
    wav_16k = librosa.resample(wav, orig_sr=sr, target_sr=16000).astype(np.float32)
    batch_np = wav_16k[np.newaxis, :]
    batch_pt = torch.tensor(batch_np, device=DEVICE)
    session = ort.InferenceSession(ve_path, providers=providers)
    with torch.no_grad():
        ref_spk = engine.ve(batch_pt).cpu().numpy()
    got_spk = session.run(None, {"audio": batch_np})[0]
    spk_diff = np.abs(ref_spk - got_spk).max()
    cosine = float(np.dot(ref_spk.flatten(), got_spk.flatten()) /
                   (np.linalg.norm(ref_spk) * np.linalg.norm(got_spk)))
    print(f" voice_encoder max_diff={spk_diff:.6f} cosine={cosine:.6f} {'βœ“' if cosine > 0.999 else 'βœ— MISMATCH'}")
if __name__ == "__main__":
engine = load_engine()
embed_path = export_embed_tokens(engine)
ve_path = export_voice_encoder(engine)
validate(engine, embed_path, ve_path)
print("\nDone. Upload to RASMUS/Chatterbox-Finnish-ONNX:")
print(f" huggingface-cli upload RASMUS/Chatterbox-Finnish-ONNX {OUT_DIR}/embed_tokens.onnx onnx/embed_tokens_finnish.onnx")
print(f" huggingface-cli upload RASMUS/Chatterbox-Finnish-ONNX {OUT_DIR}/voice_encoder.onnx onnx/voice_encoder.onnx")