Prince-1 commited on
Commit
ae1c6df
·
verified ·
1 Parent(s): e50f2d1

Add files using upload-large-folder tool

Browse files
Files changed (1) hide show
  1. convert_to_onnx.py +734 -0
convert_to_onnx.py ADDED
@@ -0,0 +1,734 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ convert_to_onnx.py - Export VibeVoice ASR components to ONNX opset 20.
4
+
5
+ Exports (written to --output-dir, default: onnx_outputs/):
6
+ acoustic_encoder.onnx audio [B,1,T] -> acoustic_latent [B,F,64]
7
+ acoustic_decoder.onnx acoustic_latent [B,64,F] -> audio [B,1,T]
8
+ semantic_encoder.onnx audio [B,1,T] -> semantic_latent [B,F,128]
9
+ acoustic_connector.onnx acoustic_latent [B,F,64] -> lm_features [B,F,3584]
10
+ semantic_connector.onnx semantic_latent [B,F,128] -> lm_features [B,F,3584]
11
+ diffusion_head.onnx (noisy[N,L], timesteps[N], condition[N,H]) -> predicted[N,L]
12
+ llm_embed_tokens.onnx token_ids [B,T] -> embeddings [B,T,3584]
13
+ lm_head.onnx hidden_states [B,T,3584] -> logits [B,T,152064]
14
+
15
+ Architecture facts (from content/ configs):
16
+ Encoder ratios (applied order) : 2, 2, 4, 5, 5, 8 (reversed from config [8,5,5,4,2,2])
17
+ Total hop length : 2*2*4*5*5*8 = 1600 samples (~66.7 ms at 24 kHz)
18
+ Acoustic VAE dim : 64
19
+ Semantic VAE dim : 128
20
+ LM hidden size (Qwen2.5-7B) : 3584
21
+ Vocab size : 152 064
22
+
23
+ Reference input size (REF_AUDIO_LEN = 48 000 samples = 2 s at 24 kHz):
24
+ This length gives an exact integer frame count at EVERY downsampling stage,
25
+ so no extra padding is baked into the ONNX graph as a constant.
26
+ For variable-length inference pad audio to multiples of REF_AUDIO_LEN, OR
27
+ use --dynamo to export with fully dynamic shapes.
28
+
29
+ Usage:
30
+ python convert_to_onnx.py
31
+ python convert_to_onnx.py --output-dir onnx_out --device cpu
32
+ python convert_to_onnx.py --skip-llm # skip 7 B LLM (saves ~30 GB RAM)
33
+ python convert_to_onnx.py --dynamo # use torch.onnx.dynamo_export
34
+ python convert_to_onnx.py --components acoustic_encoder acoustic_connector
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import sys
40
+ import os
41
+ import logging
42
+ import argparse
43
+ import warnings
44
+ from pathlib import Path
45
+ from typing import Dict, List, Optional, Tuple
46
+
47
# Default TOKENIZERS_PARALLELISM to "false" unless the caller already set it.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Suppress warning categories/messages that only add noise to an export run.
for _filter_kwargs in (
    {"category": FutureWarning},
    {"message": ".*Torch was not compiled with flash attention.*"},
):
    warnings.filterwarnings("ignore", **_filter_kwargs)
50
+
51
+ import torch
52
+ import torch.nn as nn
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Paths
56
+ # ---------------------------------------------------------------------------
57
+ ROOT = Path(__file__).parent.resolve()
58
+ CONTENT = ROOT / "content"
59
+ VIBEVOICE_SRC = ROOT / "VibeVoice"
60
+
61
+ if str(VIBEVOICE_SRC) not in sys.path:
62
+ sys.path.insert(0, str(VIBEVOICE_SRC))
63
+
64
+ # ---------------------------------------------------------------------------
65
+ # Constants
66
+ # ---------------------------------------------------------------------------
67
+ OPSET = 20
68
+ SAMPLE_RATE = 24_000 # Hz - fixed by the VibeVoice architecture
69
+ HOP_LENGTH = 1600 # 2*2*4*5*5*8 - total encoder downsampling factor
70
+
71
+ # 48 000 samples = 2 s at 24 kHz. This is the smallest T where every
72
+ # downsampling stage (strides 2,2,4,5,5,8) produces an exact integer
73
+ # frame count, so extra_padding=0 everywhere and the ONNX graph has
74
+ # no baked-in padding constants.
75
+ REF_AUDIO_LEN = 48_000
76
+
77
+ ACOUSTIC_VAE_DIM = 64
78
+ SEMANTIC_VAE_DIM = 128
79
+ LM_HIDDEN = 3584
80
+ LM_VOCAB = 152_064
81
+
82
+ logging.basicConfig(
83
+ level=logging.INFO,
84
+ format="%(asctime)s %(levelname)-7s %(message)s",
85
+ datefmt="%H:%M:%S",
86
+ )
87
+ log = logging.getLogger(__name__)
88
+
89
+
90
+ # ---------------------------------------------------------------------------
91
+ # 1. Register VibeVoice custom classes with Transformers AutoModel
92
+ # ---------------------------------------------------------------------------
93
+
94
def _register_vibevoice():
    """Import the VibeVoice classes and register them with Transformers ``AutoModel``.

    Registration is best-effort: if ``AutoModel.register`` raises for a pair
    (for example because that pair was already registered by an earlier call)
    the error is logged at DEBUG level and the remaining pairs are still
    processed, so the function is safe to call more than once.

    Returns:
        The 6-tuple ``(acoustic_config_cls, semantic_config_cls,
        diffusion_head_config_cls, acoustic_model_cls, semantic_model_cls,
        diffusion_head_cls)``.

    Raises:
        ImportError: if ``vibevoice`` or ``transformers`` cannot be imported.
    """
    from vibevoice.modular.configuration_vibevoice import (
        VibeVoiceAcousticTokenizerConfig,
        VibeVoiceSemanticTokenizerConfig,
        VibeVoiceDiffusionHeadConfig,
    )
    from vibevoice.modular.modular_vibevoice_tokenizer import (
        VibeVoiceAcousticTokenizerModel,
        VibeVoiceSemanticTokenizerModel,
    )
    from vibevoice.modular.modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
    from transformers.models.auto import AutoModel

    for cfg, mdl in [
        (VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel),
        (VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel),
        (VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead),
    ]:
        try:
            AutoModel.register(cfg, mdl)
        except Exception as exc:
            # Usually "already registered", which is harmless.  Keep going,
            # but leave a trace so a genuine registration problem is not
            # completely invisible when debugging.
            log.debug(
                "AutoModel.register(%s, %s) skipped: %s",
                getattr(cfg, "__name__", cfg),
                getattr(mdl, "__name__", mdl),
                exc,
            )

    log.info("VibeVoice model classes registered with AutoModel")
    return (
        VibeVoiceAcousticTokenizerConfig,
        VibeVoiceSemanticTokenizerConfig,
        VibeVoiceDiffusionHeadConfig,
        VibeVoiceAcousticTokenizerModel,
        VibeVoiceSemanticTokenizerModel,
        VibeVoiceDiffusionHead,
    )
127
+
128
+
129
+ # ---------------------------------------------------------------------------
130
+ # 2. ONNX-friendly wrapper modules
131
+ # ---------------------------------------------------------------------------
132
+
133
class AcousticEncoderONNX(nn.Module):
    """Export wrapper around the acoustic tokenizer's encoder.

    Maps audio (documented by this script as ``[B, 1, T]``) to the encoder
    output with its last two axes swapped (``[B, F, 64]`` per this script's
    header).  Only ``tokenizer.encoder`` is retained, and it is invoked with
    the audio tensor alone - no cache or sampling arguments are passed.
    """

    def __init__(self, tokenizer: nn.Module):
        super().__init__()
        # Keep just the encoder sub-module so nothing else ends up in the graph.
        self.encoder = tokenizer.encoder

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        # Encoder output is channel-first: [B, vae_dim, F].
        channel_first = self.encoder(audio)
        # Swap to frame-major [B, F, vae_dim], the layout the connector consumes.
        frame_major = channel_first.permute(0, 2, 1)
        return frame_major
147
+
148
+
149
class AcousticDecoderONNX(nn.Module):
    """Export wrapper around the acoustic tokenizer's decoder.

    Maps latents (``[B, 64, F]`` per this script's header) to audio
    (``[B, 1, T]``).  In eager mode a frame-major ``[B, F, vae_dim]`` input is
    also accepted and transposed first.

    NOTE(review): the layout test is ordinary Python on ``latents.shape``.
    When the module is traced for ONNX export the branch is presumably
    decided once from the sample input, so the exported graph should be fed
    the same layout that was used at export time - confirm before relying on
    both layouts in ONNX Runtime.
    """

    def __init__(self, tokenizer: nn.Module):
        super().__init__()
        self.decoder = tokenizer.decoder
        # Channel count used to recognise the channel-first layout.
        self.vae_dim = tokenizer.config.vae_dim

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        already_channel_first = latents.shape[1] == self.vae_dim
        if not already_channel_first:
            # Input arrived as [B, F, vae_dim]; move channels to axis 1.
            latents = latents.permute(0, 2, 1)
        audio = self.decoder(latents)
        return audio
161
+
162
+
163
class SemanticEncoderONNX(nn.Module):
    """Export wrapper around the semantic tokenizer's encoder.

    Maps audio (``[B, 1, T]`` per this script's header) to the encoder output
    with its last two axes swapped (``[B, F, 128]``).
    """

    def __init__(self, tokenizer: nn.Module):
        super().__init__()
        self.encoder = tokenizer.encoder

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        # Channel-first encoder output: [B, vae_dim, F].
        channel_first = self.encoder(audio)
        # Frame-major layout [B, F, vae_dim] for downstream consumers.
        frame_major = channel_first.permute(0, 2, 1)
        return frame_major
172
+
173
+
174
class SpeechConnectorONNX(nn.Module):
    """Pass-through export wrapper for a speech connector module.

    The wrapped connector is described by this script as
    Linear -> RMSNorm -> Linear; this class adds no computation of its own.
    """

    def __init__(self, connector: nn.Module):
        super().__init__()
        self.connector = connector

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        projected = self.connector(features)
        return projected
182
+
183
+
184
class DiffusionHeadONNX(nn.Module):
    """Export wrapper that calls the diffusion head with three positional inputs.

    Having a fixed positional signature gives the exporter an unambiguous
    input order (noisy latent, timesteps, condition).
    """

    def __init__(self, head: nn.Module):
        super().__init__()
        self.head = head

    def forward(
        self,
        noisy_latent: torch.Tensor,  # [N, latent_size]
        timesteps: torch.Tensor,     # [N] float
        condition: torch.Tensor,     # [N, hidden_size]
    ) -> torch.Tensor:
        # Forward the arguments unchanged, in the order the head expects.
        prediction = self.head(noisy_latent, timesteps, condition)
        return prediction
197
+
198
+
199
class LLMEmbedTokensONNX(nn.Module):
    """Export wrapper for the token-embedding lookup.

    ``input_ids`` ``[B, T]`` -> embeddings ``[B, T, H]`` (shapes as documented
    in this script's header).
    """

    def __init__(self, embed_tokens: nn.Module):
        super().__init__()
        self.embed_tokens = embed_tokens

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embeddings = self.embed_tokens(input_ids)
        return embeddings
207
+
208
+
209
class LMHeadONNX(nn.Module):
    """Export wrapper for the LM output projection.

    ``hidden_states`` ``[B, T, H]`` -> logits ``[B, T, V]`` (shapes as
    documented in this script's header).
    """

    def __init__(self, lm_head: nn.Module):
        super().__init__()
        self.lm_head = lm_head

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.lm_head(hidden_states)
        return logits
217
+
218
+
219
+ # ---------------------------------------------------------------------------
220
+ # 3. Core export helper
221
+ # ---------------------------------------------------------------------------
222
+
223
def _export_onnx(
    model: nn.Module,
    sample_args: tuple,
    out_path: Path,
    input_names: List[str],
    output_names: List[str],
    dynamic_axes: Optional[Dict] = None,
    use_dynamo: bool = False,
) -> None:
    """Export *model* to ONNX (opset ``OPSET``) at *out_path* and validate it.

    Args:
        model: Module to export; switched to eval mode first.
        sample_args: Positional example inputs used to trace/export the model.
        out_path: Destination ``.onnx`` file.
        input_names: Graph input names, in positional order.
        output_names: Graph output names.
        dynamic_axes: ``{tensor_name: {axis: symbolic_name}}`` mapping of axes
            that must stay dynamic.  ``None`` means all axes are static.
        use_dynamo: Use the dynamo-based exporter instead of the tracer.

    Raises:
        Whatever the exporter or ``onnx.checker.check_model`` raises.
    """
    import onnx

    model.eval()
    with torch.no_grad():
        if use_dynamo:
            # Forward the dynamic-axis request too; without it the dynamo
            # exporter is never asked for anything but the sample shapes.
            _export_dynamo(
                model, sample_args, out_path, input_names, output_names,
                dynamic_axes=dynamic_axes,
            )
        else:
            _export_traditional(
                model, sample_args, out_path,
                input_names, output_names, dynamic_axes or {},
            )

    # Validate by *path* rather than by an in-memory ModelProto: the checker
    # cannot handle a proto larger than 2 GB (the fp32 embedding table and
    # lm_head are each ~2.2 GB), whereas the path form supports models whose
    # weights live in external-data files.  It also avoids loading the whole
    # model back into RAM just to check it.
    onnx.checker.check_model(str(out_path))

    # NOTE: only the .onnx file itself is measured; weights stored as external
    # data next to it are not included in this figure.
    size_mb = out_path.stat().st_size / 1e6
    log.info(" [OK] %-38s %.1f MB", out_path.name, size_mb)


def _export_traditional(
    model, sample_args, out_path, input_names, output_names, dynamic_axes
):
    """Export with the tracing-based ``torch.onnx.export`` (universally supported).

    *dynamic_axes* must be a dict (possibly empty).
    """
    with torch.no_grad():
        torch.onnx.export(
            model,
            sample_args,
            str(out_path),
            opset_version=OPSET,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            export_params=True,
        )


def _export_dynamo(
    model, sample_args, out_path, input_names, output_names, dynamic_axes=None
):
    """Export with the dynamo-based ONNX exporter.

    Args:
        model, sample_args, out_path, input_names, output_names: as for
            ``_export_onnx``.
        dynamic_axes: Optional ``{tensor_name: {axis: symbolic_name}}`` mapping.
            Honoured on the unified API (PyTorch >= 2.6); the legacy
            ``dynamo_export`` API has no per-axis control, so there dynamic
            shapes are requested globally when the fallback path is taken.

    Raises:
        RuntimeError: if the installed PyTorch is older than 2.1.
    """
    # Major/minor only; local build suffixes such as "+cu124" sit in later fields.
    pt_ver = tuple(int(x) for x in torch.__version__.split(".")[:2] if x.isdigit())

    if pt_ver >= (2, 6):
        # Unified API (PyTorch >= 2.6).  An empty mapping is passed as None.
        torch.onnx.export(
            model,
            sample_args,
            str(out_path),
            dynamo=True,
            opset_version=OPSET,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes or None,
        )
    elif pt_ver >= (2, 1):
        # Legacy dynamo API (PyTorch 2.1 - 2.5).
        try:
            export_opts = torch.onnx.ExportOptions(opset_version=OPSET)
        except TypeError:
            # ExportOptions in these releases may not accept `opset_version`
            # (the opset is fixed by the exporter's registry).  Fall back to
            # the options it does take instead of aborting the export.
            log.warning(
                "torch.onnx.ExportOptions does not accept opset_version on "
                "PyTorch %s; exporting with the exporter's default opset",
                torch.__version__,
            )
            export_opts = torch.onnx.ExportOptions(dynamic_shapes=True)
        prog = torch.onnx.dynamo_export(
            model, *sample_args, export_options=export_opts
        )
        prog.save(str(out_path))
    else:
        raise RuntimeError(
            f"--dynamo requires PyTorch >= 2.1; found {torch.__version__}"
        )
296
+
297
+
298
+ # ---------------------------------------------------------------------------
299
+ # 4. Model loading helpers
300
+ # ---------------------------------------------------------------------------
301
+
302
def _load_pth_state(path: Path) -> Dict:
    """Load a ``.pth`` checkpoint on CPU and unwrap one level of common wrapper dicts.

    If the loaded object is a dict holding a dict under ``"state_dict"``,
    ``"model"`` or ``"model_state_dict"`` (checked in that order), that inner
    dict is returned; otherwise the loaded object is returned unchanged.
    """
    # NOTE(security): weights_only=False allows torch.load to unpickle
    # arbitrary Python objects, i.e. to execute code embedded in the file.
    # Only point this at checkpoints from a trusted source.
    loaded = torch.load(str(path), map_location="cpu", weights_only=False)
    if not isinstance(loaded, dict):
        return loaded

    for wrap_key in ("state_dict", "model", "model_state_dict"):
        inner = loaded.get(wrap_key)
        if isinstance(inner, dict):
            return inner
    return loaded
310
+
311
+
312
def _strip_prefix(sd: Dict, prefix: str) -> Dict:
    """Return a new dict where *prefix* is removed from every key that starts with it.

    Keys without the prefix are copied unchanged; values are never touched.
    """
    cut = len(prefix)
    renamed = {}
    for key, value in sd.items():
        new_key = key[cut:] if key.startswith(prefix) else key
        renamed[new_key] = value
    return renamed
317
+
318
+
319
def _load_acoustic_tokenizer(device: torch.device):
    """Load the acoustic tokenizer from ``content/acoustic`` in float32, eval mode, on *device*."""
    from transformers import AutoModel

    checkpoint_dir = str(CONTENT / "acoustic")
    tokenizer = AutoModel.from_pretrained(
        checkpoint_dir,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    tokenizer = tokenizer.to(device).eval()
    log.info(" Acoustic tokenizer loaded (VAE dim=%d)", tokenizer.config.vae_dim)
    return tokenizer
328
+
329
+
330
def _load_semantic_tokenizer(device: torch.device):
    """Load the semantic tokenizer from ``content/semantic`` in float32, eval mode, on *device*."""
    from transformers import AutoModel

    checkpoint_dir = str(CONTENT / "semantic")
    tokenizer = AutoModel.from_pretrained(
        checkpoint_dir,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    tokenizer = tokenizer.to(device).eval()
    log.info(" Semantic tokenizer loaded (VAE dim=%d)", tokenizer.config.vae_dim)
    return tokenizer
339
+
340
+
341
def _load_connector(
    path: Path,
    input_dim: int,
    output_dim: int,
    device: torch.device,
) -> nn.Module:
    """Build a ``SpeechConnector(input_dim, output_dim)`` and load its weights from *path*.

    If the checkpoint keys carry one of the known full-model prefixes, the
    first prefix that matches any key is stripped before a strict load.

    Returns:
        The connector on *device*, in eval mode.
    """
    from vibevoice.modular.modeling_vibevoice import SpeechConnector

    connector = SpeechConnector(input_dim, output_dim).to(device)
    state = _load_pth_state(path)

    # Prefixes that may be present when the weights were saved from a full model.
    known_prefixes = (
        "model.acoustic_connector.", "model.semantic_connector.",
        "acoustic_connector.", "semantic_connector.",
    )
    matched = next(
        (p for p in known_prefixes if any(k.startswith(p) for k in state)),
        None,
    )
    if matched is not None:
        state = _strip_prefix(state, matched)

    connector.load_state_dict(state, strict=True)
    connector.eval()
    log.info(" Connector loaded from %s (%d -> %d)", path.name, input_dim, output_dim)
    return connector
365
+
366
+
367
def _infer_diffusion_head_config(sd: Dict):
    """Infer a ``VibeVoiceDiffusionHeadConfig`` from state-dict tensor shapes.

    * ``hidden_size`` / ``latent_size`` come from the shape of the first key
      ending in ``noisy_images_proj.weight``.
    * ``head_layers`` is the number of distinct ``layers.<i>.`` indices found
      in the keys (with or without a leading prefix).
    * ``head_ffn_ratio`` is ``gate_proj`` rows divided by ``hidden_size``,
      defaulting to 3.0 when no ``ffn.gate_proj.weight`` key exists.

    Raises:
        KeyError: if no ``noisy_images_proj.weight`` key is present.
    """
    import re

    from vibevoice.modular.configuration_vibevoice import VibeVoiceDiffusionHeadConfig

    # Find noisy_images_proj.weight regardless of prefix.
    proj_w = next(
        (v for k, v in sd.items() if k.endswith("noisy_images_proj.weight")),
        None,
    )
    if proj_w is None:
        raise KeyError(
            "'noisy_images_proj.weight' not found in diffusion head state dict. "
            f"Available keys (first 10): {list(sd.keys())[:10]}"
        )

    hidden_size, latent_size = proj_w.shape

    # Count layers by their index rather than by one particular parameter
    # name: this is prefix-agnostic (like the lookup above) and cannot
    # over-count when a layer holds more than one "*.norm.weight" tensor.
    layer_pattern = re.compile(r"(?:^|\.)layers\.(\d+)\.")
    layer_ids = set()
    for key in sd:
        found = layer_pattern.search(key)
        if found:
            layer_ids.add(int(found.group(1)))
    head_layers = len(layer_ids)
    if head_layers == 0:
        # Preserve the historical fallback, but make it visible.
        log.warning(
            "No 'layers.<i>.' keys found in diffusion head state dict; assuming 1 layer"
        )
        head_layers = 1

    # Infer FFN ratio (a float; exact when the FFN width is a multiple of hidden_size).
    ffn_w = next((v for k, v in sd.items() if k.endswith("ffn.gate_proj.weight")), None)
    head_ffn_ratio = (ffn_w.shape[0] / hidden_size) if ffn_w is not None else 3.0

    cfg = VibeVoiceDiffusionHeadConfig(
        hidden_size=hidden_size,
        latent_size=latent_size,
        head_layers=head_layers,
        head_ffn_ratio=head_ffn_ratio,
    )
    log.info(
        " Diffusion head config hidden=%d latent=%d layers=%d ffn_ratio=%.1f",
        hidden_size, latent_size, head_layers, head_ffn_ratio,
    )
    return cfg
406
+
407
+
408
def _load_diffusion_head(path: Path, device: torch.device):
    """Load the diffusion head from *path*, inferring its config from the weights.

    Returns:
        ``(head, cfg)`` - the head on *device* in eval mode, and the inferred config.
    """
    from vibevoice.modular.modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead

    state = _load_pth_state(path)

    # Drop the first full-model prefix that appears on any key, if there is one.
    matched = next(
        (
            p for p in ("model.prediction_head.", "prediction_head.")
            if any(k.startswith(p) for k in state)
        ),
        None,
    )
    if matched is not None:
        state = _strip_prefix(state, matched)

    cfg = _infer_diffusion_head_config(state)
    head = VibeVoiceDiffusionHead(cfg).to(device)
    head.load_state_dict(state, strict=True)
    head.eval()
    return head, cfg
422
+
423
+
424
def _load_llm_embed_and_head(device: torch.device):
    """Return the ``embed_tokens`` and ``lm_head`` modules of the LLM in ``content/llm``.

    The *entire* checkpoint is first materialised on CPU in float32 by
    ``from_pretrained``; only afterwards are the two sub-modules moved to
    *device* and the rest of the model dropped.  Peak host memory is therefore
    that of the full model - only the memory held during the subsequent
    exports is reduced.

    Returns:
        ``(embed_tokens, lm_head)``, both on *device* and in eval mode.
    """
    import gc

    from transformers import AutoModelForCausalLM

    log.info(" Loading Qwen2.5-7B (embed_tokens + lm_head only - may take a few minutes) …")
    llm = AutoModelForCausalLM.from_pretrained(
        str(CONTENT / "llm"),
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
    )
    embed_tokens = llm.model.embed_tokens.to(device).eval()
    lm_head = llm.lm_head.to(device).eval()

    # Drop the transformer body and force a collection so its weights are
    # released before the (memory-hungry) export starts, even if the model
    # object participates in reference cycles.
    del llm
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    log.info(" Qwen2.5-7B embed_tokens + lm_head ready")
    return embed_tokens, lm_head
442
+
443
+
444
+ # ---------------------------------------------------------------------------
445
+ # 5. Per-component export functions
446
+ # ---------------------------------------------------------------------------
447
+
448
def _dynamic_axes_audio():
    """Return the dynamic-axes mapping used when exporting the acoustic encoder.

    Batch and time are dynamic on the ``audio`` input; batch and frames are
    dynamic on the ``acoustic_latent`` output.
    """
    axes = {}
    axes["audio"] = {0: "batch", 2: "time"}
    axes["acoustic_latent"] = {0: "batch", 1: "frames"}
    return axes
453
+
454
+
455
def export_acoustic_encoder(out_dir: Path, device: torch.device, dynamo: bool) -> None:
    """Export the acoustic tokenizer encoder to ``<out_dir>/acoustic_encoder.onnx``.

    The sample input is random audio of shape ``[1, 1, REF_AUDIO_LEN]``.
    """
    log.info("Exporting acoustic_encoder.onnx …")
    tokenizer = _load_acoustic_tokenizer(device)
    module = AcousticEncoderONNX(tokenizer).to(device)

    sample_audio = torch.randn(1, 1, REF_AUDIO_LEN, device=device)
    target = out_dir / "acoustic_encoder.onnx"
    _export_onnx(
        module,
        (sample_audio,),
        target,
        input_names=["audio"],
        output_names=["acoustic_latent"],
        dynamic_axes=_dynamic_axes_audio(),
        use_dynamo=dynamo,
    )
469
+
470
+
471
def export_acoustic_decoder(out_dir: Path, device: torch.device, dynamo: bool) -> None:
    """Export the acoustic tokenizer decoder to ``<out_dir>/acoustic_decoder.onnx``.

    The sample input is a random channel-first latent of shape
    ``[1, ACOUSTIC_VAE_DIM, REF_AUDIO_LEN // HOP_LENGTH]``.
    """
    log.info("Exporting acoustic_decoder.onnx …")
    tokenizer = _load_acoustic_tokenizer(device)
    module = AcousticDecoderONNX(tokenizer).to(device)

    n_frames = REF_AUDIO_LEN // HOP_LENGTH  # 30
    sample_latents = torch.randn(1, ACOUSTIC_VAE_DIM, n_frames, device=device)

    axes = {
        "acoustic_latent": {0: "batch", 2: "frames"},
        "audio": {0: "batch", 2: "time"},
    }
    _export_onnx(
        module,
        (sample_latents,),
        out_dir / "acoustic_decoder.onnx",
        input_names=["acoustic_latent"],
        output_names=["audio"],
        dynamic_axes=axes,
        use_dynamo=dynamo,
    )
489
+
490
+
491
def export_semantic_encoder(out_dir: Path, device: torch.device, dynamo: bool) -> None:
    """Export the semantic tokenizer encoder to ``<out_dir>/semantic_encoder.onnx``.

    The sample input is random audio of shape ``[1, 1, REF_AUDIO_LEN]``.
    """
    log.info("Exporting semantic_encoder.onnx …")
    tokenizer = _load_semantic_tokenizer(device)
    module = SemanticEncoderONNX(tokenizer).to(device)

    sample_audio = torch.randn(1, 1, REF_AUDIO_LEN, device=device)

    axes = {
        "audio": {0: "batch", 2: "time"},
        "semantic_latent": {0: "batch", 1: "frames"},
    }
    _export_onnx(
        module,
        (sample_audio,),
        out_dir / "semantic_encoder.onnx",
        input_names=["audio"],
        output_names=["semantic_latent"],
        dynamic_axes=axes,
        use_dynamo=dynamo,
    )
508
+
509
+
510
def export_acoustic_connector(out_dir: Path, device: torch.device, dynamo: bool) -> None:
    """Export the acoustic connector to ``<out_dir>/acoustic_connector.onnx``.

    Weights come from ``content/acoustic_connector.pth``; the sample input is a
    random frame-major latent ``[1, REF_AUDIO_LEN // HOP_LENGTH, ACOUSTIC_VAE_DIM]``.
    """
    log.info("Exporting acoustic_connector.onnx …")
    weights_file = CONTENT / "acoustic_connector.pth"
    connector = _load_connector(weights_file, ACOUSTIC_VAE_DIM, LM_HIDDEN, device)
    module = SpeechConnectorONNX(connector).to(device)

    n_frames = REF_AUDIO_LEN // HOP_LENGTH
    sample_latents = torch.randn(1, n_frames, ACOUSTIC_VAE_DIM, device=device)

    # Batch and frame count stay dynamic on both the input and the output.
    axes = {
        "acoustic_latent": {0: "batch", 1: "frames"},
        "acoustic_features": {0: "batch", 1: "frames"},
    }
    _export_onnx(
        module,
        (sample_latents,),
        out_dir / "acoustic_connector.onnx",
        input_names=["acoustic_latent"],
        output_names=["acoustic_features"],
        dynamic_axes=axes,
        use_dynamo=dynamo,
    )
530
+
531
+
532
def export_semantic_connector(out_dir: Path, device: torch.device, dynamo: bool) -> None:
    """Export the semantic connector to ``<out_dir>/semantic_connector.onnx``.

    Weights come from ``content/semantic_connector.pth``; the sample input is a
    random frame-major latent ``[1, REF_AUDIO_LEN // HOP_LENGTH, SEMANTIC_VAE_DIM]``.
    """
    log.info("Exporting semantic_connector.onnx …")
    weights_file = CONTENT / "semantic_connector.pth"
    connector = _load_connector(weights_file, SEMANTIC_VAE_DIM, LM_HIDDEN, device)
    module = SpeechConnectorONNX(connector).to(device)

    n_frames = REF_AUDIO_LEN // HOP_LENGTH
    sample_latents = torch.randn(1, n_frames, SEMANTIC_VAE_DIM, device=device)

    # Batch and frame count stay dynamic on both the input and the output.
    axes = {
        "semantic_latent": {0: "batch", 1: "frames"},
        "semantic_features": {0: "batch", 1: "frames"},
    }
    _export_onnx(
        module,
        (sample_latents,),
        out_dir / "semantic_connector.onnx",
        input_names=["semantic_latent"],
        output_names=["semantic_features"],
        dynamic_axes=axes,
        use_dynamo=dynamo,
    )
552
+
553
+
554
def export_diffusion_head(out_dir: Path, device: torch.device, dynamo: bool) -> None:
    """Export the diffusion head to ``<out_dir>/diffusion_head.onnx``.

    Weights come from ``content/head.pth``.  The sample batch holds four
    entries; the leading axis ``N`` is dynamic on every input and on the output.
    """
    log.info("Exporting diffusion_head.onnx …")
    head, cfg = _load_diffusion_head(CONTENT / "head.pth", device)
    module = DiffusionHeadONNX(head).to(device)

    batch = 4  # batch of latent tokens
    # Sample tensors are drawn in this order: noisy latent, timesteps, condition.
    sample_noisy = torch.randn(batch, cfg.latent_size, device=device)
    sample_steps = torch.randint(0, 1000, (batch,), dtype=torch.float32, device=device)
    sample_cond = torch.randn(batch, cfg.hidden_size, device=device)

    tensor_names = ("noisy_latent", "timesteps", "condition", "predicted_noise")
    axes = {name: {0: "N"} for name in tensor_names}

    _export_onnx(
        module,
        (sample_noisy, sample_steps, sample_cond),
        out_dir / "diffusion_head.onnx",
        input_names=["noisy_latent", "timesteps", "condition"],
        output_names=["predicted_noise"],
        dynamic_axes=axes,
        use_dynamo=dynamo,
    )
577
+
578
+
579
def export_llm_parts(out_dir: Path, device: torch.device, dynamo: bool) -> None:
    """Export the LLM embedding table and LM head as two separate ONNX files.

    Writes ``llm_embed_tokens.onnx`` and ``lm_head.onnx`` into *out_dir*; both
    are exported with dynamic batch and sequence axes from a ``[1, 32]``-token
    sample.
    """
    log.info("Exporting llm_embed_tokens.onnx …")
    embed_tokens, lm_head = _load_llm_embed_and_head(device)

    batch_seq = {0: "batch", 1: "seq"}

    # --- embedding lookup -------------------------------------------------
    sample_ids = torch.randint(0, LM_VOCAB, (1, 32), device=device)
    _export_onnx(
        LLMEmbedTokensONNX(embed_tokens),
        (sample_ids,),
        out_dir / "llm_embed_tokens.onnx",
        input_names=["input_ids"],
        output_names=["embeddings"],
        dynamic_axes={"input_ids": dict(batch_seq), "embeddings": dict(batch_seq)},
        use_dynamo=dynamo,
    )

    # --- output projection ------------------------------------------------
    log.info("Exporting lm_head.onnx …")
    sample_hidden = torch.randn(1, 32, LM_HIDDEN, device=device)
    _export_onnx(
        LMHeadONNX(lm_head),
        (sample_hidden,),
        out_dir / "lm_head.onnx",
        input_names=["hidden_states"],
        output_names=["logits"],
        dynamic_axes={"hidden_states": dict(batch_seq), "logits": dict(batch_seq)},
        use_dynamo=dynamo,
    )
609
+
610
+
611
+ # ---------------------------------------------------------------------------
612
+ # 6. CLI
613
+ # ---------------------------------------------------------------------------
614
+
615
# Component name -> export function.  Insertion order is the export order.
EXPORT_FNS = {
    "acoustic_encoder": export_acoustic_encoder,
    "acoustic_decoder": export_acoustic_decoder,
    "semantic_encoder": export_semantic_encoder,
    "acoustic_connector": export_acoustic_connector,
    "semantic_connector": export_semantic_connector,
    "diffusion_head": export_diffusion_head,
    "llm": export_llm_parts,
}

# Valid values for --components, in export order (same order as EXPORT_FNS).
ALL_COMPONENTS = list(EXPORT_FNS)
634
+
635
+
636
def main() -> int:
    """Command-line entry point: parse options, run the requested exports, report.

    Each selected component is exported independently; a failure is logged
    with its traceback and the remaining components are still attempted.

    Returns:
        ``1`` if the ``onnx`` package is missing or any component failed to
        export, otherwise ``0``.
    """
    cli = argparse.ArgumentParser(
        description="Export VibeVoice ASR components to ONNX opset 20",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    cli.add_argument(
        "--output-dir",
        default="onnx_outputs",
        help="Directory where ONNX files are written (default: onnx_outputs/)",
    )
    cli.add_argument(
        "--device",
        default="cpu",
        help="PyTorch device string, e.g. 'cpu' or 'cuda:0' (default: cpu)",
    )
    cli.add_argument(
        "--skip-llm",
        action="store_true",
        help="Skip llm_embed_tokens + lm_head (saves ~28 GB RAM for the 7 B LLM)",
    )
    cli.add_argument(
        "--dynamo",
        action="store_true",
        help=(
            "Use torch.onnx.dynamo_export for fully dynamic shapes "
            "(requires PyTorch >= 2.1). Slower but handles variable audio lengths."
        ),
    )
    cli.add_argument(
        "--components",
        nargs="+",
        choices=ALL_COMPONENTS,
        help="Subset of components to export (default: all)",
    )
    args = cli.parse_args()

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device(args.device)

    log.info(
        "VibeVoice ASR -> ONNX opset %d | device=%s | output=%s | dynamo=%s",
        OPSET, device, out_dir, args.dynamo,
    )
    log.info("PyTorch %s", torch.__version__)

    # Dependency check: bail out early with a hint instead of failing per component.
    try:
        import onnx
    except ImportError:
        log.error("'onnx' not installed. Run: pip install onnx onnxruntime")
        return 1
    log.info("onnx %s", onnx.__version__)

    _register_vibevoice()

    # Work out which components were requested; --skip-llm always wins.
    requested = set(args.components or ALL_COMPONENTS)
    if args.skip_llm:
        requested.discard("llm")

    succeeded: List[str] = []
    failed: List[str] = []

    # Iterate ALL_COMPONENTS (not the set) so the export order is deterministic.
    for name in (c for c in ALL_COMPONENTS if c in requested):
        export_fn = EXPORT_FNS[name]
        try:
            export_fn(out_dir, device, args.dynamo)
        except Exception as exc:
            log.error("FAILED %s: %s", name, exc, exc_info=True)
            failed.append(name)
        else:
            succeeded.append(name)

    log.info("")
    log.info("=== Summary ===")
    log.info("Succeeded : %s", ", ".join(succeeded) if succeeded else "(none)")
    if failed:
        log.warning("Failed : %s", ", ".join(failed))
    log.info("Output dir: %s", out_dir.resolve())

    if failed:
        return 1

    log.info("")
    log.info("Inference note:")
    log.info(
        " Tokenizer encoders were exported with REF_AUDIO_LEN=%d samples (%g s).",
        REF_AUDIO_LEN, REF_AUDIO_LEN / SAMPLE_RATE,
    )
    log.info(
        " For variable-length inference, pad audio to multiples of %d samples "
        "(%g ms) before feeding to acoustic_encoder / semantic_encoder.",
        HOP_LENGTH, HOP_LENGTH / SAMPLE_RATE * 1000,
    )
    log.info(
        " Or re-export with --dynamo for fully dynamic shape support."
    )
    return 0
730
+
731
+
732
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
734
+