acul3 committed on
Commit
b3779d4
·
verified ·
1 Parent(s): 4005d54

Upload scripts/export_vocoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/export_vocoder.py +212 -0
scripts/export_vocoder.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Phase 5: Export Speech Tokenizer Decoder (Vocoder) to ExecuTorch .pte
4
+ ======================================================================
5
+ The vocoder converts codec tokens → audio waveform.
6
+
7
+ Architecture:
8
+ codes [B, 16, T] → VQ decode → [B, codebook_dim, T]
9
+ → Conv1d → Transformer (8 layers) → Conv1d
10
+ → Upsample (2x, 2x) via ConvTranspose1d + ConvNeXt
11
+ → Decoder (8x, 5x, 4x, 3x) via ConvTranspose1d + SnakeBeta + ResBlocks
12
+ → Conv1d → waveform [B, 1, T*1920]
13
+
14
+ Total upsample: 2*2*8*5*4*3 = 1920x, i.e. each code frame yields 1920 audio samples.
15
+ Note: the decoder forward uses total_upsample, which is upsample_rates * upsampling_ratios.
16
+ """
17
+
18
+ import sys
19
+ import os
20
+ import copy
21
+ import time
22
+ import math
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
# User-specific locations of the model checkpoint, the project virtualenv,
# the Qwen3-TTS source tree and the export output directory ("~" is expanded).
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Prepend the virtualenv packages and the source tree to sys.path (only once
# each). The source tree is inserted last, so it ends up ahead of the venv.
if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)

# Make sure the destination for exported artifacts exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Fixed code length for export (50 frames ≈ 4 seconds of audio)
FIXED_CODE_LEN = 50
NUM_QUANTIZERS = 16

print("=" * 70)
print("PHASE 5: Export Vocoder (Speech Tokenizer Decoder) → .pte")
print("=" * 70)
47
+
48
# ── 1. Load Model ───────────────────────────────────────────────────

print("\n[1/5] Loading model...")
# Imported here rather than at the top of the file: these are project modules.
# NOTE(review): presumably they only resolve after the sys.path setup earlier
# in the script — confirm.
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
# Load the full model in float32 on CPU with SDPA attention.
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    config=config,
    dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="cpu",
)
model.eval()  # inference mode for validation and export
print(" Model loaded.")
61
+
62
# ── 2. Create Vocoder Wrapper ────────────────────────────────────────

print("\n[2/5] Creating vocoder wrapper...")

# The decoder has dynamic padding calculations that depend on input length.
# With a FIXED input length, these become constants. We wrap the original
# decoder directly and let torch.export trace through the fixed-size logic.


class VocoderForExport(nn.Module):
    """
    Wraps the speech tokenizer decoder for export.

    Bypasses chunked_decode and calls forward() directly.
    Input: codes [1, num_quantizers, code_len] — all int64
    Output: waveform [1, 1, code_len * decode_upsample_rate]
    """

    def __init__(self, original_decoder: nn.Module):
        """
        Args:
            original_decoder: decoder module to wrap. It is deep-copied, so
                the wrapper never shares parameters or buffers with the
                module that was passed in.
        """
        super().__init__()
        self.decoder = copy.deepcopy(original_decoder)

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Args:
            codes: [1, 16, FIXED_CODE_LEN] — LongTensor of codec indices
        Returns:
            waveform: [1, 1, FIXED_CODE_LEN * upsample] — float waveform in [-1, 1]
        """
        # Plain delegation: the wrapped decoder is called directly on the codes.
        return self.decoder(codes)
91
+
92
+
93
# Wrap (a deep copy of) the loaded model's speech-tokenizer decoder and put
# the wrapper in eval mode.
vocoder = VocoderForExport(model.speech_tokenizer.model.decoder)
vocoder.eval()

# Report the wrapper's parameter count in millions.
param_count = sum(p.numel() for p in vocoder.parameters())
print(f" Vocoder parameters: {param_count / 1e6:.1f}M")
98
+
99
# ── 3. Validate ─────────────────────────────────────────────────────

print("\n[3/5] Validating vocoder wrapper...")

# Random codec indices of the fixed export shape.
# NOTE(review): 2048 is presumably the codebook size — confirm against the
# tokenizer config.
test_codes = torch.randint(0, 2048, (1, NUM_QUANTIZERS, FIXED_CODE_LEN))

with torch.no_grad():
    # Test original decoder
    orig_wav = model.speech_tokenizer.model.decoder(test_codes)
    # Test wrapper
    wrap_wav = vocoder(test_codes)

print(f" Input codes shape: {list(test_codes.shape)}")
print(f" Original output shape: {list(orig_wav.shape)}")
print(f" Wrapper output shape: {list(wrap_wav.shape)}")

# Compare the two waveforms as flat vectors.
cos_sim = F.cosine_similarity(
    orig_wav.flatten().unsqueeze(0),
    wrap_wav.flatten().unsqueeze(0),
).item()
max_diff = (orig_wav - wrap_wav).abs().max().item()
print(f" Cosine similarity: {cos_sim:.6f}")
print(f" Max abs difference: {max_diff:.2e}")
# Raised explicitly (instead of a bare `assert`) so the check is not stripped
# when the script runs under `python -O`; the exception type is unchanged.
if not cos_sim > 0.999:
    raise AssertionError(f"Mismatch! cos_sim={cos_sim}")
print(" PASS — vocoder validated")

# Output samples produced per input code frame.
upsample_rate = wrap_wav.shape[-1] // FIXED_CODE_LEN
print(f" Upsample rate: {upsample_rate}x")
print(f" Output duration: {wrap_wav.shape[-1] / 24000:.1f}s at 24kHz")
126
+
127
# ── 4. torch.export ─────────────────────────────────────────────────

print("\n[4/5] Running torch.export...")
t0 = time.time()

# torch.export takes the example inputs as a tuple of positional arguments.
example_input = (test_codes,)

try:
    exported = torch.export.export(
        vocoder,
        example_input,
        strict=False,
    )
    print(f" torch.export succeeded in {time.time() - t0:.1f}s")
    print(f" Graph nodes: {len(exported.graph.nodes)}")
except Exception as e:
    # Best effort: report the failure and leave `exported` as None so the
    # lowering step can take its fallback path instead of crashing here.
    print(f" torch.export FAILED: {e}")
    exported = None
145
+
146
# ── 5. Lower to .pte ────────────────────────────────────────────────

print("\n[5/5] Lowering to ExecuTorch .pte...")
t0 = time.time()

if exported is not None:
    # Computed once up front: both the write and the validation step use it,
    # and validation must be able to refer to it even if lowering fails early.
    pte_path = os.path.join(OUTPUT_DIR, "vocoder.pte")

    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        edge = to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()

        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)

        pte_size = os.path.getsize(pte_path) / 1e6
        print(f" .pte saved: {pte_path}")
        print(f" .pte size: {pte_size:.1f} MB")
        print(f" Lowered in {time.time() - t0:.1f}s")

    except Exception as e:
        # Best effort: report the failure and keep the torch.export artifact
        # as a fallback. `exported` is known to be non-None in this branch,
        # so no further check is needed.
        print(f" ExecuTorch lowering failed: {e}")
        pt2_path = os.path.join(OUTPUT_DIR, "vocoder.pt2")
        torch.export.save(exported, pt2_path)
        print(f" Saved exported program: {pt2_path}")

    # Validate .pte
    # Runs whenever the file exists on disk, including one left over from an
    # earlier run if lowering failed this time.
    if os.path.exists(pte_path):
        print("\n Validating .pte execution...")
        try:
            from executorch.runtime import Runtime

            runtime = Runtime.get()
            # Read the program through a context manager so the file handle
            # is closed promptly instead of being leaked.
            with open(pte_path, "rb") as f:
                program = runtime.load_program(f.read())
            method = program.load_method("forward")
            pte_out = method.execute([test_codes])
            # execute() may hand back a sequence of outputs; take the first.
            if isinstance(pte_out, (list, tuple)):
                pte_out = pte_out[0]
            with torch.no_grad():
                ref_out = vocoder(test_codes)
            cos_pte = F.cosine_similarity(
                ref_out.flatten().unsqueeze(0),
                pte_out.flatten().unsqueeze(0),
            ).item()
            print(f" .pte vs PyTorch cosine sim: {cos_pte:.6f}")
        except Exception as e:
            # Validation is informational only; report and continue.
            print(f" .pte validation: {e}")
else:
    print(" No exported program to lower.")
    # Save state dict as fallback
    torch.save(vocoder.state_dict(), os.path.join(OUTPUT_DIR, "vocoder_state_dict.pt"))
    print(f" Saved state dict: {OUTPUT_DIR}/vocoder_state_dict.pt")
207
+
208
# Final summary: the fixed input length and the resulting output size,
# expressed in samples and in seconds (the duration divides by 24000).
print("\n" + "=" * 70)
print("Phase 5 complete!")
print(f" Fixed code length: {FIXED_CODE_LEN} frames")
print(f" Output: {FIXED_CODE_LEN * upsample_rate} samples ({FIXED_CODE_LEN * upsample_rate / 24000:.1f}s)")
print("=" * 70)