acul3 committed on
Commit
4005d54
·
verified ·
1 Parent(s): f2deaf5

Upload scripts/export_code_predictor.py with huggingface_hub

Files changed (1)
  1. scripts/export_code_predictor.py +390 -0
scripts/export_code_predictor.py ADDED
@@ -0,0 +1,390 @@
+ #!/usr/bin/env python3
+ """
+ Phase 4: Export Code Predictor to ExecuTorch .pte
+ ==================================================
+ The code predictor is a smaller 5-layer transformer (175M params) that
+ takes the talker's hidden state + first codebook token and autoregressively
+ generates the remaining 15 codebook tokens.
+
+ Architecture:
+ - hidden_size=1024, 5 layers, 16 heads, 8 kv_heads, head_dim=128
+ - small_to_mtp_projection: Linear(2048→1024), projects talker hidden → predictor
+ - 15 lm_heads: Linear(1024→2048) each (one per code group)
+ - 15 codec_embeddings: Embedding(2048, 2048) each
+
+ During inference (called once per talker decode step):
+ Step 0 (prefill): concat(projected_talker_hidden, codec_embed_0(first_token)) → 2 tokens
+ Steps 1-14: predict next code group token → embed it → feed back
+
+ We export this as a static-KV-cache transformer, similar to the talker.
+ """
+
+ import sys
+ import os
+ import copy
+ import time
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
+ VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
+ QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
+ OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")
+
+ if VENV_SITE not in sys.path:
+     sys.path.insert(0, VENV_SITE)
+ if QWEN_TTS_SRC not in sys.path:
+     sys.path.insert(0, QWEN_TTS_SRC)
+
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ # ── Configuration ────────────────────────────────────────────────────
+ MAX_SEQ_LEN = 17  # 2 prefill tokens + up to 15 decode steps
+ BATCH_SIZE = 1
+ CP_NUM_LAYERS = 5
+ CP_NUM_KV_HEADS = 8
+ CP_HEAD_DIM = 128
+ CP_NUM_HEADS = 16
+ CP_HIDDEN_SIZE = 1024
+ CP_INTERMEDIATE_SIZE = 3072
+ CP_VOCAB_SIZE = 2048
+ CP_NUM_CODE_GROUPS = 16  # total groups (predict 15, the first comes from the talker)
+ TALKER_HIDDEN_SIZE = 2048
+
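+ # A quick size check (illustrative arithmetic, not used by the script): each
+ # layer's static K or V cache is [1, 8, 17, 128] = 17,408 floats, so all
+ # 5 layers * 2 tensors * 17,408 * 4 bytes (fp32) come to roughly 0.7 MB.
+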
+ print("=" * 70)
+ print("PHASE 4: Export Code Predictor → .pte")
+ print("=" * 70)
+
+ # ── 1. Load Model ───────────────────────────────────────────────────
+
+ print("\n[1/5] Loading model...")
+ from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
+ from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
+
+ config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
+ model = Qwen3TTSForConditionalGeneration.from_pretrained(
+     MODEL_PATH, config=config, dtype=torch.float32,
+     attn_implementation="sdpa", device_map="cpu",
+ )
+ model.eval()
+ print(" Model loaded.")
+
+ # ── 2. Build Export-Ready Code Predictor ─────────────────────────────
+
+ print("\n[2/5] Building export-ready code predictor wrapper...")
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.eps = eps
+
+     def forward(self, x):
+         dtype = x.dtype
+         x = x.float()
+         v = x.pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(v + self.eps)
+         return (self.weight * x).to(dtype)
+
+
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ class CPAttentionForExport(nn.Module):
+     """Code predictor attention layer with static KV cache."""
+
+     def __init__(self, original_attn, layer_idx):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.head_dim = CP_HEAD_DIM
+         self.num_heads = CP_NUM_HEADS
+         self.num_kv_heads = CP_NUM_KV_HEADS
+         self.num_kv_groups = CP_NUM_HEADS // CP_NUM_KV_HEADS
+         self.scaling = CP_HEAD_DIM ** -0.5
+
+         self.q_proj = copy.deepcopy(original_attn.q_proj)
+         self.k_proj = copy.deepcopy(original_attn.k_proj)
+         self.v_proj = copy.deepcopy(original_attn.v_proj)
+         self.o_proj = copy.deepcopy(original_attn.o_proj)
+         self.q_norm = RMSNorm(CP_HEAD_DIM, eps=1e-6)
+         self.q_norm.weight = copy.deepcopy(original_attn.q_norm.weight)
+         self.k_norm = RMSNorm(CP_HEAD_DIM, eps=1e-6)
+         self.k_norm.weight = copy.deepcopy(original_attn.k_norm.weight)
+
+     def forward(self, hidden_states, cos, sin, cache_position,
+                 k_cache, v_cache, attn_mask):
+         bsz, seq_len, _ = hidden_states.shape
+
+         q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
+         q = self.q_norm(q).transpose(1, 2)
+         k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim)
+         k = self.k_norm(k).transpose(1, 2)
+         v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+         q = (q * cos) + (rotate_half(q) * sin)
+         k = (k * cos) + (rotate_half(k) * sin)
+
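+         # Clone before writing: the caches arrive as graph inputs, and the
+         # update is expressed functionally so torch.export does not see an
+         # in-place mutation of its inputs.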
+         k_cache = k_cache.clone()
+         v_cache = v_cache.clone()
+         k_cache[:, :, cache_position, :] = k
+         v_cache[:, :, cache_position, :] = v
+
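+         # Grouped-query attention: each of the 8 KV heads serves
+         # num_kv_groups = 16 // 8 = 2 query heads, so expand K/V to 16 heads.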
+         k_expanded = k_cache.unsqueeze(2).repeat(
+             1, 1, self.num_kv_groups, 1, 1
+         ).reshape(bsz, self.num_heads, MAX_SEQ_LEN, self.head_dim)
+         v_expanded = v_cache.unsqueeze(2).repeat(
+             1, 1, self.num_kv_groups, 1, 1
+         ).reshape(bsz, self.num_heads, MAX_SEQ_LEN, self.head_dim)
+
+         attn_output = F.scaled_dot_product_attention(
+             q, k_expanded, v_expanded,
+             attn_mask=attn_mask,
+             scale=self.scaling,
+         )
+
+         attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, -1)
+         attn_output = self.o_proj(attn_output)
+         return attn_output, k_cache, v_cache
+
+
+ class CPMLP(nn.Module):
+     def __init__(self, original_mlp):
+         super().__init__()
+         self.gate_proj = copy.deepcopy(original_mlp.gate_proj)
+         self.up_proj = copy.deepcopy(original_mlp.up_proj)
+         self.down_proj = copy.deepcopy(original_mlp.down_proj)
+
+     def forward(self, x):
+         return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class CPLayerForExport(nn.Module):
+     def __init__(self, original_layer, layer_idx):
+         super().__init__()
+         self.attn = CPAttentionForExport(original_layer.self_attn, layer_idx)
+         self.mlp = CPMLP(original_layer.mlp)
+         self.input_norm = RMSNorm(CP_HIDDEN_SIZE, eps=1e-6)
+         self.input_norm.weight = copy.deepcopy(original_layer.input_layernorm.weight)
+         self.post_attn_norm = RMSNorm(CP_HIDDEN_SIZE, eps=1e-6)
+         self.post_attn_norm.weight = copy.deepcopy(original_layer.post_attention_layernorm.weight)
+
+     def forward(self, hidden_states, cos, sin, cache_position,
+                 k_cache, v_cache, attn_mask):
+         residual = hidden_states
+         hidden_states = self.input_norm(hidden_states)
+         attn_out, k_cache, v_cache = self.attn(
+             hidden_states, cos, sin, cache_position,
+             k_cache, v_cache, attn_mask
+         )
+         hidden_states = residual + attn_out
+
+         residual = hidden_states
+         hidden_states = self.post_attn_norm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         return hidden_states, k_cache, v_cache
+
+
+ class CodePredictorForExport(nn.Module):
+     """
+     Export-ready code predictor backbone.
+
+     Input: unprojected inputs_embeds at the talker hidden size (this module
+     applies small_to_mtp_projection internally)
+     Output: hidden states (caller applies the appropriate lm_head externally)
+
+     For the full 16-codebook prediction:
+     1. Python builds inputs_embeds from talker hidden + codec embeddings
+     2. This module runs the transformer
+     3. Python takes hidden[:, step_idx, :] and applies lm_head[step_idx]
+     """
+
+     def __init__(self, original_cp):
+         super().__init__()
+
+         # Transformer layers
+         self.layers = nn.ModuleList()
+         for i, layer in enumerate(original_cp.model.layers):
+             self.layers.append(CPLayerForExport(layer, i))
+
+         # Final norm
+         self.norm = RMSNorm(CP_HIDDEN_SIZE, eps=1e-6)
+         self.norm.weight = copy.deepcopy(original_cp.model.norm.weight)
+
+         # Projection from talker hidden to code predictor hidden
+         self.small_to_mtp_projection = copy.deepcopy(original_cp.small_to_mtp_projection)
+
+         # LM heads (15 heads, one per code group 1..15)
+         self.lm_heads = nn.ModuleList()
+         for head in original_cp.lm_head:
+             self.lm_heads.append(copy.deepcopy(head))
+
+         # Rotary embedding
+         orig_rope = original_cp.model.rotary_emb
+         self.register_buffer("inv_freq", orig_rope.inv_freq.clone())
+         self.rope_scaling = getattr(orig_rope, 'attention_scaling', 1.0)
+
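+     # RoPE is recomputed inline from inv_freq; this keeps the exported graph
+     # to plain tensor ops instead of a nested HF rotary-embedding module.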
+     def _compute_rope(self, position_ids, device, dtype):
+         pos = position_ids.float()  # [B, seq_len]
+         inv_freq = self.inv_freq.float().to(device)
+         freqs = pos.unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)
+         emb = torch.cat([freqs, freqs], dim=-1)
+         cos = (emb.cos() * self.rope_scaling).to(dtype)
+         sin = (emb.sin() * self.rope_scaling).to(dtype)
+         return cos.unsqueeze(1), sin.unsqueeze(1)
+
+     def forward(self, inputs_embeds, position_ids, cache_position, attn_mask,
+                 *kv_cache_flat):
+         """
+         Args:
+             inputs_embeds: [B, seq_len, talker_hidden_size], NOT yet projected
+             position_ids: [B, seq_len]
+             cache_position: [seq_len]
+             attn_mask: [B, 1, seq_len, MAX_SEQ_LEN]
+             *kv_cache_flat: 5 * 2 tensors, each [B, kv_heads, MAX_SEQ_LEN, head_dim]
+
+         Returns:
+             hidden_states: [B, seq_len, CP_HIDDEN_SIZE]; apply lm_head externally
+             *updated_kv_cache
+         """
+         # Project from talker hidden → code predictor hidden
+         hidden_states = self.small_to_mtp_projection(inputs_embeds)
+
+         cos, sin = self._compute_rope(position_ids, hidden_states.device, hidden_states.dtype)
+
+         updated_kv = []
+         for i, layer in enumerate(self.layers):
+             k_cache = kv_cache_flat[i * 2]
+             v_cache = kv_cache_flat[i * 2 + 1]
+             hidden_states, new_k, new_v = layer(
+                 hidden_states, cos, sin, cache_position,
+                 k_cache, v_cache, attn_mask
+             )
+             updated_kv.append(new_k)
+             updated_kv.append(new_v)
+
+         hidden_states = self.norm(hidden_states)
+
+         return (hidden_states, *updated_kv)
+
+
+ print(" Constructing CodePredictorForExport...")
+ t0 = time.time()
+ export_cp = CodePredictorForExport(model.talker.code_predictor)
+ export_cp.eval()
+ print(f" Done in {time.time() - t0:.1f}s")
+
+ param_count = sum(p.numel() for p in export_cp.parameters())
+ print(f" Parameters: {param_count / 1e6:.1f}M")
+
+ # ── 3. Validate ─────────────────────────────────────────────────────
+
+ print("\n[3/5] Validating wrapper...")
+
+ # Prefill: 2 tokens (projected_talker_hidden + first_codec_embed)
+ seq_len = 2
+ test_embeds = torch.randn(BATCH_SIZE, seq_len, TALKER_HIDDEN_SIZE)
+ test_pos = torch.arange(seq_len).unsqueeze(0).expand(BATCH_SIZE, -1)
+ test_cache_pos = torch.arange(seq_len)
+
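+ # Additive causal mask over the full static cache: row i admits positions
+ # 0..i and leaves everything else at -inf.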
+ causal_mask = torch.full((BATCH_SIZE, 1, seq_len, MAX_SEQ_LEN), float('-inf'))
+ for i in range(seq_len):
+     causal_mask[:, :, i, :i + 1] = 0.0
+
+ kv_cache = []
+ for _ in range(CP_NUM_LAYERS):
+     kv_cache.append(torch.zeros(BATCH_SIZE, CP_NUM_KV_HEADS, MAX_SEQ_LEN, CP_HEAD_DIM))
+     kv_cache.append(torch.zeros(BATCH_SIZE, CP_NUM_KV_HEADS, MAX_SEQ_LEN, CP_HEAD_DIM))
+
+ with torch.no_grad():
+     outputs = export_cp(test_embeds, test_pos, test_cache_pos, causal_mask, *kv_cache)
+
+ hidden = outputs[0]
+ print(f" Hidden states shape: {list(hidden.shape)}")  # [1, 2, 1024]
+ assert hidden.shape == (BATCH_SIZE, seq_len, CP_HIDDEN_SIZE)
+
+ # Apply lm_head to get logits for the first prediction step
+ logits_0 = export_cp.lm_heads[0](hidden[:, -1:, :])
+ print(f" Logits[0] shape: {list(logits_0.shape)}")  # [1, 1, 2048]
+ assert logits_0.shape[-1] == CP_VOCAB_SIZE
+
+ # Decode step
+ decode_embeds = torch.randn(BATCH_SIZE, 1, TALKER_HIDDEN_SIZE)
+ decode_pos = torch.tensor([[seq_len]])
+ decode_cache_pos = torch.tensor([seq_len])
+ decode_mask = torch.full((BATCH_SIZE, 1, 1, MAX_SEQ_LEN), float('-inf'))
+ decode_mask[:, :, :, :seq_len + 1] = 0.0
+
+ updated_kv = list(outputs[1:])
+ with torch.no_grad():
+     decode_out = export_cp(decode_embeds, decode_pos, decode_cache_pos, decode_mask, *updated_kv)
+
+ print(f" Decode hidden shape: {list(decode_out[0].shape)}")
+ print(" PASS: code predictor validated")
+
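+ # Illustrative sketch (defined but never called): one way an orchestration
+ # layer could drive the wrapper to produce all 15 predicted code groups for a
+ # single talker step. Greedy argmax and the random stand-in embeddings are
+ # assumptions for the sketch, not the behavior of the real pipeline.
+ def _sketch_generate_code_groups(cp, talker_hidden, first_codec_embed):
+     """talker_hidden, first_codec_embed: [1, 1, TALKER_HIDDEN_SIZE] each."""
+     kv = [torch.zeros(BATCH_SIZE, CP_NUM_KV_HEADS, MAX_SEQ_LEN, CP_HEAD_DIM)
+           for _ in range(2 * CP_NUM_LAYERS)]
+     # Prefill with 2 tokens, exactly as in the validation above.
+     embeds = torch.cat([talker_hidden, first_codec_embed], dim=1)
+     mask = torch.full((BATCH_SIZE, 1, 2, MAX_SEQ_LEN), float('-inf'))
+     for i in range(2):
+         mask[:, :, i, :i + 1] = 0.0
+     tokens = []
+     with torch.no_grad():
+         out = cp(embeds, torch.arange(2).unsqueeze(0), torch.arange(2), mask, *kv)
+         hidden, kv = out[0], list(out[1:])
+         for step in range(15):  # predicted code groups 1..15
+             tok = cp.lm_heads[step](hidden[:, -1:, :]).argmax(-1)  # [1, 1]
+             tokens.append(tok)
+             if step == 14:
+                 break
+             # The real pipeline embeds `tok` with the matching codec embedding
+             # (saved below in code_predictor_extras.pt); a random vector stands
+             # in here so the sketch stays self-contained.
+             nxt = torch.randn(BATCH_SIZE, 1, TALKER_HIDDEN_SIZE)
+             p = 2 + step
+             mask = torch.full((BATCH_SIZE, 1, 1, MAX_SEQ_LEN), float('-inf'))
+             mask[:, :, :, :p + 1] = 0.0
+             out = cp(nxt, torch.tensor([[p]]), torch.tensor([p]), mask, *kv)
+             hidden, kv = out[0], list(out[1:])
+     return torch.cat(tokens, dim=1)  # [1, 15]
+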
+ # ── 4. torch.export ─────────────────────────────────────────────────
+
+ print("\n[4/5] Running torch.export...")
+ t0 = time.time()
+
+ prefill_args = (test_embeds, test_pos, test_cache_pos, causal_mask, *kv_cache)
+
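+ # Note: without a dynamic_shapes argument, torch.export specializes the graph
+ # to these example shapes (the seq_len=2 prefill); the seq_len=1 decode call
+ # would need its own export or a dynamic sequence dimension.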
+ try:
+     exported = torch.export.export(export_cp, prefill_args, strict=False)
+     print(f" torch.export succeeded in {time.time() - t0:.1f}s")
+     print(f" Graph nodes: {len(exported.graph.nodes)}")
+ except Exception as e:
+     print(f" torch.export FAILED: {e}")
+     exported = None
+
+ # ── 5. Lower to .pte ────────────────────────────────────────────────
+
+ print("\n[5/5] Lowering to ExecuTorch .pte...")
+ t0 = time.time()
+
+ if exported is not None:
+     try:
+         from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
+         from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+
+         edge = to_edge_transform_and_lower(
+             exported,
+             compile_config=EdgeCompileConfig(_check_ir_validity=False),
+             partitioner=[XnnpackPartitioner()],
+         )
+         et_program = edge.to_executorch()
+
+         pte_path = os.path.join(OUTPUT_DIR, "code_predictor.pte")
+         with open(pte_path, "wb") as f:
+             f.write(et_program.buffer)
+
+         pte_size = os.path.getsize(pte_path) / 1e6
+         print(f" .pte saved: {pte_path}")
+         print(f" .pte size: {pte_size:.1f} MB")
+         print(f" Lowered in {time.time() - t0:.1f}s")
+
+     except Exception as e:
+         print(f" ExecuTorch lowering failed: {e}")
+         pt2_path = os.path.join(OUTPUT_DIR, "code_predictor.pt2")
+         torch.export.save(exported, pt2_path)
+         print(f" Saved: {pt2_path}")
+
+ # Also save the codec embeddings and lm_heads for the orchestration layer
+ torch.save({
+     "codec_embeddings": [emb.state_dict() for emb in model.talker.code_predictor.model.codec_embedding],
+     "lm_heads": [head.state_dict() for head in export_cp.lm_heads],
+     "small_to_mtp_projection": export_cp.small_to_mtp_projection.state_dict(),
+ }, os.path.join(OUTPUT_DIR, "code_predictor_extras.pt"))
+ print(f" Saved codec embeddings + lm_heads: {OUTPUT_DIR}/code_predictor_extras.pt")
+
+ print("\n" + "=" * 70)
+ print("Phase 4 complete!")
+ print(f" Max seq len: {MAX_SEQ_LEN}")
+ print(f" Parameters: {param_count / 1e6:.1f}M")
+ print(f" Vocab: {CP_VOCAB_SIZE}, Code groups: {CP_NUM_CODE_GROUPS}")
+ print("=" * 70)