RASMUS committed
Commit e23a410 · verified · 1 Parent(s): 87e8f1b

Add scripts/export_finnish_embeddings.py

Files changed (1)
  1. scripts/export_finnish_embeddings.py +180 -0
scripts/export_finnish_embeddings.py ADDED
@@ -0,0 +1,180 @@
"""
export_finnish_embeddings.py

Exports two ONNX components from the Finnish fine-tuned model that are currently
missing from the browser pipeline:

1. embed_tokens.onnx  - Finnish T3's text_emb + position embeddings
   (the base version has slightly different weights)
2. voice_encoder.onnx - Perth WavLM VoiceEncoder → 256-dim speaker embedding
   (enables custom reference audio in the browser without a precomputed cond_emb)

These two, combined with the already-uploaded finnish_cond_enc.onnx, give the
browser the full custom-voice pipeline:

    voice_encoder → speaker_emb → cond_enc → cond_emb → language_model → decoder

Outputs:
    _onnx_export/embed_tokens.onnx  (small, ~140 MB)
    _onnx_export/voice_encoder.onnx (small, ~65 MB)

Usage:
    cd /workspaces/work
    conda run -n chatterbox-onnx python scripts/export_finnish_embeddings.py
"""

import os
import sys
from pathlib import Path

import numpy as np
import onnx
import torch
from safetensors.torch import load_file

sys.path.insert(0, "Chatterbox-Finnish")

PRETRAINED_DIR = "Chatterbox-Finnish/pretrained_models"
FINETUNED_W = "Chatterbox-Finnish/models/best_finnish_multilingual_cp986.safetensors"
OUT_DIR = Path("_onnx_export")
OUT_DIR.mkdir(exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_engine():
    from src.chatterbox_.tts import ChatterboxTTS

    print(f"  loading base engine ({DEVICE})...")
    engine = ChatterboxTTS.from_local(PRETRAINED_DIR, device=DEVICE)
    print("  injecting Finnish weights...")
    ckpt = load_file(FINETUNED_W)
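    # The checkpoint's keys carry a "t3." prefix; strip it so they line up with
    # the keys of engine.t3's own state_dict before loading.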
    t3_state = {k[3:] if k.startswith("t3.") else k: v for k, v in ckpt.items()}
    missing, unexpected = engine.t3.load_state_dict(t3_state, strict=False)
    print(f"  loaded: {len(t3_state) - len(unexpected)} keys, missing={len(missing)}, unexpected={len(unexpected)}")
    return engine


# ── 1. embed_tokens.onnx ─────────────────────────────────────────────────────
def export_embed_tokens(engine):
    """
    Wraps T3's token embedding table.
        Input:  input_ids [batch, seq] int64
        Output: embeds    [batch, seq, 1024] float32

    Note: T3 uses a single embedding table (text_emb) for both text tokens and
    speech tokens. The base ONNX repo exports this the same way.
    """
    print("\n── export_embed_tokens ──")
    out_path = str(OUT_DIR / "embed_tokens.onnx")

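    # A minimal wrapper Module gives torch.onnx.export a clean single-input,
    # single-output graph instead of tracing the full T3 model.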
    class EmbedTokens(torch.nn.Module):
        def __init__(self, emb: torch.nn.Embedding):
            super().__init__()
            self.emb = emb

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            return self.emb(input_ids)

    # T3's text_emb is the token embedding table
    emb_module = EmbedTokens(engine.t3.text_emb).to(DEVICE).eval()
    vocab_size = engine.t3.text_emb.weight.shape[0]
    print(f"  vocab_size={vocab_size}, embed_dim={engine.t3.text_emb.weight.shape[1]}")

    dummy_ids = torch.zeros(1, 5, dtype=torch.long, device=DEVICE)

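    # Both batch and seq are declared dynamic below, so the dummy shape (1, 5)
    # only fixes the rank; the browser can feed prompts of any length.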
    with torch.no_grad():
        torch.onnx.export(
            emb_module,
            (dummy_ids,),
            out_path,
            input_names=["input_ids"],
            output_names=["embeds"],
            dynamic_axes={"input_ids": {0: "batch", 1: "seq"}, "embeds": {0: "batch", 1: "seq"}},
            opset_version=17,
            do_constant_folding=True,
        )

    # Validate
    model = onnx.load(out_path)
    onnx.checker.check_model(model)
    size_mb = os.path.getsize(out_path) / 1e6
    print(f"  ✓ {out_path} ({size_mb:.1f} MB)")
    return out_path


# ── 2. voice_encoder.onnx ────────────────────────────────────────────────────
def export_voice_encoder(engine):
    """
    Wraps the Perth WavLM VoiceEncoder.
        Input:  audio [batch, samples] float32 (16 kHz, variable length)
        Output: speaker_emb [batch, 256] float32

    This allows the browser to compute speaker embeddings from arbitrary
    reference audio (instead of loading the precomputed finnish_cond_emb.bin).
    """
    print("\n── export_voice_encoder ──")
    out_path = str(OUT_DIR / "voice_encoder.onnx")

    ve = engine.ve.to(DEVICE).eval()

    # Perth VoiceEncoder takes raw audio at 16 kHz.
    # Trace with 3 seconds of audio to expose the dynamic shapes.
    dummy_audio = torch.zeros(1, 48000, device=DEVICE)  # 3 s @ 16 kHz
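    # NOTE: tracing with silence assumes the encoder has no input-dependent
    # control flow; if it branches on signal content, the exported graph may
    # not generalize to real audio.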

    with torch.no_grad():
        torch.onnx.export(
            ve,
            (dummy_audio,),
            out_path,
            input_names=["audio"],
            output_names=["speaker_emb"],
            dynamic_axes={"audio": {0: "batch", 1: "samples"}, "speaker_emb": {0: "batch"}},
            opset_version=17,
            do_constant_folding=True,
        )

    model = onnx.load(out_path)
    onnx.checker.check_model(model)
    size_mb = os.path.getsize(out_path) / 1e6
    print(f"  ✓ {out_path} ({size_mb:.1f} MB)")
    return out_path


# ── Validate both exports match PyTorch ──────────────────────────────────────
def validate(engine, embed_path: str, ve_path: str):
    import librosa
    import onnxruntime as ort

    print("\n── Validation ──")
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

    # Validate embed_tokens
    sess_et = ort.InferenceSession(embed_path, providers=providers)
    test_ids = np.array([[255, 284, 18, 22, 7, 0]], dtype=np.int64)  # SOT + some tokens + EOT
    with torch.no_grad():
        pt_emb = engine.t3.text_emb(torch.tensor(test_ids, device=DEVICE)).cpu().numpy()
    onnx_emb = sess_et.run(None, {"input_ids": test_ids})[0]
    max_diff = np.abs(pt_emb - onnx_emb).max()
    print(f"  embed_tokens  max_diff={max_diff:.6f} {'✓' if max_diff < 1e-4 else '✗ MISMATCH'}")

    # Validate voice_encoder
    ref_audio, ref_sr = librosa.load("Chatterbox-Finnish/samples/reference_finnish.wav", sr=None)
    ref_16k = librosa.resample(ref_audio, orig_sr=ref_sr, target_sr=16000).astype(np.float32)
    ref_input_np = ref_16k[np.newaxis, :]
    ref_input_pt = torch.tensor(ref_input_np, device=DEVICE)

    sess_ve = ort.InferenceSession(ve_path, providers=providers)
    with torch.no_grad():
        pt_spk = engine.ve(ref_input_pt).cpu().numpy()
    onnx_spk = sess_ve.run(None, {"audio": ref_input_np})[0]
    max_diff = np.abs(pt_spk - onnx_spk).max()
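    # Speaker embeddings are compared by angle, so the cosine similarity below
    # is the pass/fail criterion; max_diff is reported for information only.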
    cos_sim = float(np.dot(pt_spk.flatten(), onnx_spk.flatten()) /
                    (np.linalg.norm(pt_spk) * np.linalg.norm(onnx_spk)))
    print(f"  voice_encoder max_diff={max_diff:.6f} cosine={cos_sim:.6f} {'✓' if cos_sim > 0.999 else '✗ MISMATCH'}")


if __name__ == "__main__":
    engine = load_engine()
    embed_path = export_embed_tokens(engine)
    ve_path = export_voice_encoder(engine)
    validate(engine, embed_path, ve_path)
    print("\nDone. Upload to RASMUS/Chatterbox-Finnish-ONNX:")
    print(f"  huggingface-cli upload RASMUS/Chatterbox-Finnish-ONNX {OUT_DIR}/embed_tokens.onnx onnx/embed_tokens_finnish.onnx")
    print(f"  huggingface-cli upload RASMUS/Chatterbox-Finnish-ONNX {OUT_DIR}/voice_encoder.onnx onnx/voice_encoder.onnx")