RASMUS committed on
Commit
a7c083e
·
verified ·
1 Parent(s): 24c8aaf

Upload scripts/compare_onnx_vs_pytorch_parity.py with huggingface_hub

Browse files
scripts/compare_onnx_vs_pytorch_parity.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ compare_onnx_vs_pytorch_parity.py
3
+
4
+ Parity-first Finnish Chatterbox comparison harness.
5
+
6
+ Purpose:
7
+ - Keep `Chatterbox-Finnish/inference_example.py` as the behavioral source of truth.
8
+ - Compare that PyTorch path against an explicit ONNX path that mirrors the future
9
+ browser worker logic as closely as possible.
10
+ - Avoid mixing debug shortcuts into the default comparison path.
11
+
12
+ Notes:
13
+ - This script requires CUDA for ONNX Runtime. If CUDA is unavailable, it fails
14
+ loudly instead of silently falling back to CPU.
15
+ - The ONNX path still uses precomputed Finnish conditioning because the Finnish
16
+ `cond_enc` path is not yet fully packaged for the browser runtime.
17
+ """
18
+
19
+ import os
20
+ import sys
21
+ import time
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ import requests
26
+ import soundfile as sf
27
+
28
+
29
# ---------------------------------------------------------------------------
# Run configuration. Edit values here; there is no CLI.
# MODE: "parity" runs the PyTorch path, the ONNX path, and compares outputs;
# "debug" runs component-level checks only (see main()).
# ---------------------------------------------------------------------------
CONFIG = {
    "MODE": "parity",
    "TEXT": "Tervetuloa kokeilemaan hienoviritettyä suomenkielistä Chatterbox-puhesynteesiä.",
    "REFERENCE_AUDIO": "Chatterbox-Finnish/samples/reference_finnish.wav",
    "FINETUNED_WEIGHTS": "Chatterbox-Finnish/models/best_finnish_multilingual_cp986.safetensors",
    "PRETRAINED_DIR": "Chatterbox-Finnish/pretrained_models",
    "OUT_DIR": "_cmp",
    "ONNX_CACHE_DIR": "_onnx_cache",
    "HF_BASE_REPO": "onnx-community/chatterbox-multilingual-ONNX",
    "HF_FINNISH_REPO": "RASMUS/Chatterbox-Finnish-ONNX",
    "SEED": 42,
    # Sampling hyperparameters shared by the PyTorch and ONNX paths.
    "REPETITION_PENALTY": 1.2,
    "TEMPERATURE": 0.8,
    "EXAGGERATION": 0.6,
    "CFG_WEIGHT": 0.5,
    "MIN_P": 0.05,
    "MIN_SPEECH_TOKENS": 40,      # EOS is ignored until this many speech tokens exist
    "MAX_GENERATION_STEPS": 800,  # hard cap on autoregressive steps in the ONNX loop
    "RUN_TRANSCRIPTION": True,    # transcribe both outputs via Groq Whisper (needs GROQ_API_KEY)
    "RUN_ANALYZE_AUDIO": False,   # shell out to analyze_audio.py for extra metrics
}


# Output and download-cache directories are created eagerly at import time.
OUT_DIR = Path(CONFIG["OUT_DIR"])
OUT_DIR.mkdir(exist_ok=True)
CACHE_DIR = Path(CONFIG["ONNX_CACHE_DIR"])
CACHE_DIR.mkdir(exist_ok=True)

# Special token ids used by the T3 language-model head. START_SPEECH doubles
# as the speech BOS; ids strictly below it are treated as speech tokens by the
# generation loop. SOT_TEXT / EOT_TEXT wrap the encoded text ids.
START_SPEECH = 6561
STOP_SPEECH = 6562  # speech EOS
SOT_TEXT = 255      # start-of-text token prepended to text ids
EOT_TEXT = 0        # end-of-text token appended to text ids
61
+
62
+
63
def seed_everything(seed: int) -> None:
    """Make NumPy and PyTorch (CPU and all CUDA devices) random streams deterministic."""
    import torch

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
70
+
71
+
72
def hf_download(repo_id: str, filename: str) -> str:
    """Fetch *filename* from the Hub repo *repo_id* into the local ONNX cache.

    Returns the local filesystem path of the downloaded file.
    """
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )
    return local_path
81
+
82
+
83
def require_cuda_providers():
    """Return the ONNX Runtime provider list, failing loudly when CUDA is missing."""
    import onnxruntime as ort

    available = ort.get_available_providers()
    if "CUDAExecutionProvider" in available:
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    # Deliberately no CPU fallback: silent CPU runs would invalidate timings.
    raise RuntimeError(
        "CUDAExecutionProvider not available. "
        "Set LD_LIBRARY_PATH to the conda env cuDNN path before running."
    )
93
+
94
+
95
def save_wav(arr: np.ndarray, path: str, sr: int) -> None:
    """Write *arr* to *path* at sample rate *sr* and log duration and peak amplitude."""
    sf.write(path, arr, sr)
    seconds = len(arr) / sr
    peak_amp = float(np.abs(arr).max())
    print(f" saved -> {path} ({seconds:.2f}s, peak={peak_amp:.4f})")
100
+
101
+
102
def transcribe(wav_path: str, lang: str = "fi") -> str:
    """Transcribe *wav_path* with Groq's Whisper API.

    Returns a placeholder string when GROQ_API_KEY is unset so callers can
    print the result unconditionally.
    """
    api_key = os.environ.get("GROQ_API_KEY", "")
    if not api_key:
        return "(no GROQ_API_KEY)"

    with open(wav_path, "rb") as audio_file:
        result = requests.post(
            "https://api.groq.com/openai/v1/audio/transcriptions",
            headers={"Authorization": f"Bearer {api_key}"},
            files={"file": (os.path.basename(wav_path), audio_file, "audio/wav")},
            data={"model": "whisper-large-v3", "language": lang, "response_format": "text"},
            timeout=300,
        )
    result.raise_for_status()
    return result.text.strip()
117
+
118
+
119
def apply_rep_penalty(logits: np.ndarray, generated: list[int], penalty: float) -> np.ndarray:
    """Return a copy of *logits* with already-generated token ids penalized.

    Positive logits are divided by *penalty* and non-positive logits are
    multiplied by it (the standard HF-style repetition penalty); each unique
    token is penalized exactly once.
    """
    penalized = logits.copy()
    for token_id in {*generated}:
        value = penalized[token_id]
        penalized[token_id] = value / penalty if value > 0 else value * penalty
    return penalized
127
+
128
+
129
def apply_min_p(logits: np.ndarray, min_p: float) -> np.ndarray:
    """Return a copy of *logits* with low-probability tokens masked to -1e9.

    A token survives only if its softmax probability is at least *min_p*
    times the top token's probability (min-p filtering).
    """
    masked = logits.copy()
    shifted = masked - masked.max()
    probs = np.exp(shifted)
    probs = probs / probs.sum()
    cutoff = probs.max() * min_p
    masked[probs < cutoff] = -1e9
    return masked
135
+
136
+
137
def sample_with_temperature(logits: np.ndarray, temperature: float) -> int:
    """Draw one token index from the temperature-scaled softmax of *logits*."""
    z = logits / temperature
    z = z - z.max()  # stabilize exp before normalizing
    weights = np.exp(z)
    weights = weights / weights.sum()
    return int(np.random.choice(len(weights), p=weights))
143
+
144
+
145
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of *a* and *b* flattened to 1-D (epsilon-guarded denominator)."""
    va = np.ravel(np.asarray(a))
    vb = np.ravel(np.asarray(b))
    scale = np.linalg.norm(va) * np.linalg.norm(vb) + 1e-12
    return float(va @ vb / scale)
150
+
151
+
152
def run_pytorch() -> str:
    """Run the reference PyTorch inference path and return the output wav path.

    This mirrors `Chatterbox-Finnish/inference_example.py`: load the pretrained
    engine, overlay the Finnish fine-tuned T3 weights, then generate.
    """
    print("\n" + "=" * 64)
    print("1. PYTORCH INFERENCE")
    print("=" * 64)

    import torch
    from safetensors.torch import load_file

    sys.path.insert(0, "Chatterbox-Finnish")
    from src.chatterbox_.tts import ChatterboxTTS

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts = ChatterboxTTS.from_local(CONFIG["PRETRAINED_DIR"], device=device)

    # Strip the "t3." prefix from fine-tuned keys before loading into the T3 module.
    finetuned = load_file(CONFIG["FINETUNED_WEIGHTS"])
    stripped = {}
    for key, value in finetuned.items():
        stripped[key[3:] if key.startswith("t3.") else key] = value
    tts.t3.load_state_dict(stripped, strict=False)

    t0 = time.time()
    wav = tts.generate(
        text=CONFIG["TEXT"],
        audio_prompt_path=CONFIG["REFERENCE_AUDIO"],
        repetition_penalty=CONFIG["REPETITION_PENALTY"],
        temperature=CONFIG["TEMPERATURE"],
        exaggeration=CONFIG["EXAGGERATION"],
        cfg_weight=CONFIG["CFG_WEIGHT"],
        min_p=CONFIG["MIN_P"],
    )
    took = time.time() - t0

    samples = wav.squeeze().cpu().numpy()
    output_path = str(OUT_DIR / "pytorch_output.wav")
    save_wav(samples, output_path, tts.sr)
    print(f" inference time: {took:.1f}s")
    return output_path
186
+
187
+
188
def run_onnx() -> str:
    """Run the browser-mirroring ONNX inference path and return the output wav path.

    Pipeline: speech_encoder -> embed_tokens -> language_model (CFG sampling
    with separate conditional/unconditional KV caches) -> conditional_decoder.
    Uses precomputed Finnish conditioning embeddings instead of a live
    cond_enc pass (see module docstring).
    """
    print("\n" + "=" * 64)
    print("2. ONNX INFERENCE")
    print("=" * 64)

    import librosa
    import onnxruntime as ort

    # Fails loudly if CUDA is unavailable; no silent CPU fallback.
    providers = require_cuda_providers()
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    # Base graphs come from the official multilingual export; the language
    # model from the Finnish fine-tuned repo. Each .onnx_data external-weight
    # file must sit next to its graph, hence the paired downloads.
    speech_encoder_path = hf_download(CONFIG["HF_BASE_REPO"], "onnx/speech_encoder.onnx")
    hf_download(CONFIG["HF_BASE_REPO"], "onnx/speech_encoder.onnx_data")
    embed_tokens_path = hf_download(CONFIG["HF_BASE_REPO"], "onnx/embed_tokens.onnx")
    hf_download(CONFIG["HF_BASE_REPO"], "onnx/embed_tokens.onnx_data")
    conditional_decoder_path = hf_download(CONFIG["HF_BASE_REPO"], "onnx/conditional_decoder.onnx")
    hf_download(CONFIG["HF_BASE_REPO"], "onnx/conditional_decoder.onnx_data")
    language_model_path = hf_download(CONFIG["HF_FINNISH_REPO"], "onnx/language_model.onnx")
    hf_download(CONFIG["HF_FINNISH_REPO"], "onnx/language_model.onnx_data")

    print(" loading sessions...")
    sess_se = ort.InferenceSession(speech_encoder_path, sess_options=options, providers=providers)
    sess_et = ort.InferenceSession(embed_tokens_path, sess_options=options, providers=providers)
    sess_lm = ort.InferenceSession(language_model_path, sess_options=options, providers=providers)
    sess_cd = ort.InferenceSession(conditional_decoder_path, sess_options=options, providers=providers)

    # Precomputed Finnish conditioning: raw float32 with fixed (1, 34, 1024) layout.
    cond_emb_path = hf_download(CONFIG["HF_FINNISH_REPO"], "onnx/finnish_cond_emb.bin")
    with open(cond_emb_path, "rb") as handle:
        cond_emb = np.frombuffer(handle.read(), dtype=np.float32).reshape(1, 34, 1024)
    print(f" cond_emb: {cond_emb.shape}")

    # Tokenize with the same tokenizer and punctuation normalization as the
    # PyTorch path, wrapped in SOT/EOT markers.
    sys.path.insert(0, "Chatterbox-Finnish")
    from src.chatterbox_.models.tokenizers.tokenizer import EnTokenizer
    from src.chatterbox_.tts import punc_norm

    tokenizer = EnTokenizer(os.path.join(CONFIG["PRETRAINED_DIR"], "tokenizer.json"))
    normalized_text = punc_norm(CONFIG["TEXT"])
    token_ids = tokenizer.encode(normalized_text)
    text_ids = np.array([[SOT_TEXT] + token_ids + [EOT_TEXT]], dtype=np.int64)
    print(f" text tokens: {text_ids.shape}")

    # Use 24kHz reference audio for speech_encoder, matching the official ONNX pipeline.
    ref_24k, _ = librosa.load(CONFIG["REFERENCE_AUDIO"], sr=24000)
    se_out = sess_se.run(None, {"audio_values": ref_24k[np.newaxis, :].astype(np.float32)})
    # Output indices follow the export's ordering — presumably
    # (0: unused, 1: prompt tokens, 2: speaker embeddings, 3: speaker features);
    # verify against the export if the model is re-exported.
    prompt_tokens = se_out[1]
    speaker_embeddings = se_out[2]
    speaker_features = se_out[3]
    print(f" prompt_tokens: {prompt_tokens.shape}")

    # Embed the text ids plus a speech-BOS token, both modulated by exaggeration.
    exaggeration = np.array([CONFIG["EXAGGERATION"]], dtype=np.float32)
    text_pos = np.arange(text_ids.shape[1], dtype=np.int64)[np.newaxis, :]
    text_embeds = sess_et.run(
        None,
        {"input_ids": text_ids, "position_ids": text_pos, "exaggeration": exaggeration},
    )[0]

    bos_emb = sess_et.run(
        None,
        {
            "input_ids": np.array([[START_SPEECH]], dtype=np.int64),
            "position_ids": np.array([[0]], dtype=np.int64),
            "exaggeration": exaggeration,
        },
    )[0]

    # CFG prefills: the unconditional branch replaces text embeddings with zeros.
    prefill_cond = np.concatenate([cond_emb, text_embeds, bos_emb], axis=1)
    prefill_uncond = np.concatenate([cond_emb, np.zeros_like(text_embeds), bos_emb], axis=1)

    # Probe the LM graph's KV input dtype (fp16 vs fp32 export) and build empty
    # caches: one per branch, (batch=1, heads=16, seq=0, head_dim=64) x 30 layers.
    kv_meta = next(inp for inp in sess_lm.get_inputs() if inp.name == "past_key_values.0.key")
    kv_dtype = np.float16 if "float16" in kv_meta.type else np.float32
    kv_empty = np.zeros((1, 16, 0, 64), dtype=kv_dtype)
    layer_count = 30
    kv_cond = [(kv_empty.copy(), kv_empty.copy()) for _ in range(layer_count)]
    kv_uncond = [(kv_empty.copy(), kv_empty.copy()) for _ in range(layer_count)]

    def make_kv_feeds(kv_cache):
        # Flatten the per-layer (key, value) pairs into the LM's named inputs.
        feeds = {}
        for layer_index in range(layer_count):
            feeds[f"past_key_values.{layer_index}.key"] = kv_cache[layer_index][0]
            feeds[f"past_key_values.{layer_index}.value"] = kv_cache[layer_index][1]
        return feeds

    def lm_step(inputs_embeds: np.ndarray, attention_mask: np.ndarray, kv_cache):
        # One LM forward pass; returns logits and the updated KV cache
        # (outputs are laid out as [logits, k0, v0, k1, v1, ...]).
        feeds = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
        feeds.update(make_kv_feeds(kv_cache))
        outputs = sess_lm.run(None, feeds)
        logits = outputs[0]
        next_kv = [(outputs[1 + 2 * idx], outputs[2 + 2 * idx]) for idx in range(layer_count)]
        return logits, next_kv

    # Prefill both CFG branches over their full conditioning sequences.
    mask_cond = np.ones((1, prefill_cond.shape[1]), dtype=np.int64)
    mask_uncond = np.ones((1, prefill_uncond.shape[1]), dtype=np.int64)
    logits_cond, kv_cond = lm_step(prefill_cond, mask_cond, kv_cond)
    logits_uncond, kv_uncond = lm_step(prefill_uncond, mask_uncond, kv_uncond)

    generated = [START_SPEECH]
    speech_tokens = []

    def choose_next_token(logits_c: np.ndarray, logits_u: np.ndarray) -> int:
        # CFG merge of the last-position logits, then repetition penalty,
        # min-p filtering, and temperature sampling.
        cond_last = logits_c[0, -1].astype(np.float32)
        uncond_last = logits_u[0, -1].astype(np.float32)
        merged = cond_last + CONFIG["CFG_WEIGHT"] * (cond_last - uncond_last)
        merged = apply_rep_penalty(merged, generated, CONFIG["REPETITION_PENALTY"])
        merged = apply_min_p(merged, CONFIG["MIN_P"])
        return sample_with_temperature(merged, CONFIG["TEMPERATURE"])

    # First token comes straight from the prefill logits.
    first_token = choose_next_token(logits_cond, logits_uncond)
    generated.append(first_token)
    if first_token < START_SPEECH:
        speech_tokens.append(first_token)

    print(" generating...")
    generation_start = time.time()
    for step in range(1, CONFIG["MAX_GENERATION_STEPS"]):
        # Embed only the newest token; position_ids advances with the step index.
        step_emb = sess_et.run(
            None,
            {
                "input_ids": np.array([[generated[-1]]], dtype=np.int64),
                "position_ids": np.array([[step]], dtype=np.int64),
                "exaggeration": exaggeration,
            },
        )[0]

        # Attention masks cover the cached sequence plus the new token.
        step_mask_cond = np.ones((1, kv_cond[0][0].shape[2] + 1), dtype=np.int64)
        step_mask_uncond = np.ones((1, kv_uncond[0][0].shape[2] + 1), dtype=np.int64)

        logits_cond, kv_cond = lm_step(step_emb, step_mask_cond, kv_cond)
        logits_uncond, kv_uncond = lm_step(step_emb, step_mask_uncond, kv_uncond)
        next_token = choose_next_token(logits_cond, logits_uncond)

        # Honor EOS only after a minimum amount of speech has been produced.
        if next_token == STOP_SPEECH and len(speech_tokens) >= CONFIG["MIN_SPEECH_TOKENS"]:
            print(f" EOS at step {step} ({len(speech_tokens)} speech tokens)")
            break

        generated.append(next_token)
        if next_token < START_SPEECH:
            speech_tokens.append(next_token)

        if (step + 1) % 100 == 0:
            elapsed = time.time() - generation_start
            rate = (step + 1) / elapsed
            print(f" step {step + 1}: {len(speech_tokens)} speech tokens ({rate:.1f} tok/s)")

    print(f" generation time: {time.time() - generation_start:.1f}s")

    # Decode the reference prompt tokens plus the generated speech tokens to audio.
    generated_arr = np.array([speech_tokens], dtype=np.int64)
    decoder_tokens = np.concatenate([prompt_tokens, generated_arr], axis=1)
    print(f" decoder input: {decoder_tokens.shape}")
    wav = sess_cd.run(
        None,
        {
            "speech_tokens": decoder_tokens,
            "speaker_embeddings": speaker_embeddings,
            "speaker_features": speaker_features,
        },
    )[0].squeeze().astype(np.float32)

    # Rescale only near-silent (but nonzero) output, then clamp to valid range.
    peak = float(np.abs(wav).max())
    if peak < 0.01 and peak > 0:
        wav = wav * (0.9 / peak)
    wav = np.clip(wav, -1.0, 1.0)

    output_path = str(OUT_DIR / "onnx_output_parity.wav")
    save_wav(wav, output_path, 24000)
    return output_path
354
+
355
+
356
def compare_outputs(pytorch_wav: str, onnx_wav: str) -> None:
    """Compare the two generated wavs: optional transcription and optional external analysis."""
    print("\n" + "=" * 64)
    print("3. OUTPUT COMPARISON")
    print("=" * 64)

    if CONFIG["RUN_TRANSCRIPTION"]:
        print(f" ref text: {CONFIG['TEXT']}")
        print(f" PyTorch: {transcribe(pytorch_wav)}")
        print(f" ONNX: {transcribe(onnx_wav)}")

    if CONFIG["RUN_ANALYZE_AUDIO"]:
        import subprocess

        # Best-effort external analysis; a non-zero exit is not fatal (check=False).
        analyzer_cmd = [
            sys.executable,
            "analyze_audio.py",
            pytorch_wav,
            onnx_wav,
            "--label-a",
            "PyTorch",
            "--label-b",
            "ONNX",
            "--ref-text",
            CONFIG["TEXT"],
            "--lang",
            "fi",
        ]
        subprocess.run(analyzer_cmd, check=False)
386
+
387
+
388
def run_debug() -> None:
    """Component-level parity checks between the ONNX graphs and the PyTorch model.

    Compares, in order: speech_encoder outputs vs PyTorch conditionals,
    text tokenization, and embed_tokens vs the T3 text embedding table.
    Prints diagnostics only; no audio is generated.
    """
    print("\n" + "=" * 64)
    print("4. COMPONENT DEBUG")
    print("=" * 64)

    import librosa
    import onnxruntime as ort
    import torch
    from safetensors.torch import load_file

    providers = require_cuda_providers()
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    sys.path.insert(0, "Chatterbox-Finnish")
    from src.chatterbox_.models.tokenizers.tokenizer import EnTokenizer
    from src.chatterbox_.tts import ChatterboxTTS, punc_norm

    # Load the Finnish fine-tuned T3 weights (stripping the "t3." key prefix)
    # into the PyTorch engine, same as run_pytorch.
    checkpoint = load_file(CONFIG["FINETUNED_WEIGHTS"])
    t3_state = {k[3:] if k.startswith("t3.") else k: v for k, v in checkpoint.items()}

    engine = ChatterboxTTS.from_local(CONFIG["PRETRAINED_DIR"], device="cuda")
    engine.t3.load_state_dict(t3_state, strict=False)

    # Only the two graphs under test are needed here.
    speech_encoder_path = hf_download(CONFIG["HF_BASE_REPO"], "onnx/speech_encoder.onnx")
    hf_download(CONFIG["HF_BASE_REPO"], "onnx/speech_encoder.onnx_data")
    embed_tokens_path = hf_download(CONFIG["HF_BASE_REPO"], "onnx/embed_tokens.onnx")
    hf_download(CONFIG["HF_BASE_REPO"], "onnx/embed_tokens.onnx_data")

    sess_se = ort.InferenceSession(speech_encoder_path, sess_options=options, providers=providers)
    sess_et = ort.InferenceSession(embed_tokens_path, sess_options=options, providers=providers)

    # ONNX side: encode the 24kHz reference audio.
    ref_24k, _ = librosa.load(CONFIG["REFERENCE_AUDIO"], sr=24000)
    se_out = sess_se.run(None, {"audio_values": ref_24k[np.newaxis, :].astype(np.float32)})
    onnx_prompt_tokens = se_out[1]
    onnx_speaker_embeddings = se_out[2]
    onnx_speaker_features = se_out[3]

    # PyTorch side: build the same conditionals from the same reference audio.
    with torch.no_grad():
        engine.prepare_conditionals(CONFIG["REFERENCE_AUDIO"], exaggeration=CONFIG["EXAGGERATION"])
        pytorch_ref = engine.conds.gen
        pt_prompt_tokens = pytorch_ref["prompt_token"].cpu().numpy()
        pt_speaker_features = pytorch_ref["prompt_feat"].cpu().numpy()
        pt_speaker_embeddings = pytorch_ref["embedding"].cpu().numpy()

    print(" speech_encoder vs PyTorch conditionals")
    print(f" prompt_tokens exact match: {np.array_equal(onnx_prompt_tokens, pt_prompt_tokens)}")
    print(f" speaker_embeddings cosine: {cosine_similarity(onnx_speaker_embeddings, pt_speaker_embeddings):.6f}")
    print(f" onnx speaker_features: {onnx_speaker_features.shape}")
    print(f" pytorch speaker_features: {pt_speaker_features.shape}")
    # Elementwise diff is only meaningful when the shapes agree.
    if onnx_speaker_features.shape == pt_speaker_features.shape:
        max_diff = float(np.abs(onnx_speaker_features - pt_speaker_features).max())
        print(f" speaker_features max diff: {max_diff:.6f}")

    # Tokenization parity: manual SOT/EOT wrapping vs the engine's own path.
    tokenizer = EnTokenizer(os.path.join(CONFIG["PRETRAINED_DIR"], "tokenizer.json"))
    normalized_text = punc_norm(CONFIG["TEXT"])
    onnx_ids = [SOT_TEXT] + tokenizer.encode(normalized_text) + [EOT_TEXT]

    with torch.no_grad():
        pt_ids = engine.tokenizer.text_to_tokens(normalized_text)[0].tolist()
        pt_ids = [engine.t3.hp.start_text_token] + pt_ids + [engine.t3.hp.stop_text_token]

    print(" text tokenization")
    print(f" exact match: {onnx_ids == pt_ids}")
    print(f" onnx ids head: {onnx_ids[:8]}")
    print(f" pytorch ids head: {pt_ids[:8]}")

    # Embedding parity: ONNX embed_tokens vs the raw T3 text embedding table.
    # NOTE(review): the ONNX graph also takes position_ids and exaggeration,
    # while the PyTorch call is the bare embedding lookup — a nonzero diff
    # here may reflect that, not a broken export.
    exaggeration = np.array([CONFIG["EXAGGERATION"]], dtype=np.float32)
    onnx_ids_arr = np.array([onnx_ids], dtype=np.int64)
    pos_ids = np.arange(len(onnx_ids), dtype=np.int64)[np.newaxis, :]
    onnx_embeds = sess_et.run(
        None,
        {"input_ids": onnx_ids_arr, "position_ids": pos_ids, "exaggeration": exaggeration},
    )[0]

    with torch.no_grad():
        pt_embeds = engine.t3.text_emb(torch.tensor(onnx_ids_arr, device="cuda")).cpu().numpy()

    print(" embed_tokens")
    print(f" cosine: {cosine_similarity(onnx_embeds, pt_embeds):.6f}")
    print(f" max diff: {float(np.abs(onnx_embeds - pt_embeds).max()):.6f}")
470
+
471
def main() -> None:
    """Entry point: seed all RNGs, then run either debug checks or the full parity flow."""
    seed_everything(CONFIG["SEED"])

    if CONFIG["MODE"] == "debug":
        run_debug()
        print("\nDone. Debug checks completed.")
        return

    pt_path = run_pytorch()
    onnx_path = run_onnx()
    compare_outputs(pt_path, onnx_path)
    print(f"\nDone. Outputs saved in {OUT_DIR}/")
482
+
483
+
484
# Script entry point.
if __name__ == "__main__":
    main()