#!/usr/bin/env python3
"""
Phase 4: Export Code Predictor to ExecuTorch .pte
==================================================
The code predictor is a smaller 5-layer transformer (175M params) that
takes the talker's hidden state + first codebook token and autoregressively
generates the remaining 15 codebook tokens.

Architecture:
  - hidden_size=1024, 5 layers, 16 heads, 8 kv_heads, head_dim=128
  - small_to_mtp_projection: Linear(2048→1024) — projects talker hidden → predictor
  - 15 lm_heads: Linear(1024→2048) each (one per code group)
  - 15 codec_embeddings: Embedding(2048, 2048) each

During inference (called once per talker decode step):
  Step 0 (prefill): concat(projected_talker_hidden, codec_embed_0(first_token)) → 2 tokens
  Steps 1-14: predict next code group token → embed it → feed back

We export this as a static-KV-cache transformer similar to the talker.
"""
import sys
import os
import copy
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
# Filesystem locations, all resolved relative to the current user's home.
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Put the venv's site-packages and the Qwen3-TTS checkout at the front of the
# import path.  Each insert goes to index 0, so the entry handled last
# (QWEN_TTS_SRC) ends up ahead of the one handled first (VENV_SITE).
for _extra_path in (VENV_SITE, QWEN_TTS_SRC):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)

# Make sure the export directory exists before any stage writes into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ── Configuration ──────────────────────────────────────────────────────
# Length of the static KV cache allocated per layer.
MAX_SEQ_LEN = 17  # prefill=2, then 15 decode steps
BATCH_SIZE = 1

# Code predictor ("CP") transformer dimensions.  These are hard-coded here and
# used by the export wrappers below; they are not read from the loaded config.
CP_NUM_LAYERS = 5
CP_NUM_KV_HEADS = 8
CP_HEAD_DIM = 128
CP_NUM_HEADS = 16
CP_HIDDEN_SIZE = 1024
CP_INTERMEDIATE_SIZE = 3072
CP_VOCAB_SIZE = 2048
CP_NUM_CODE_GROUPS = 16  # total groups (predict 15, first comes from talker)

# Width of the embeddings fed to the predictor before its input projection.
TALKER_HIDDEN_SIZE = 2048

print("=" * 70)
# The arrow was mis-encoded in the banner; restored to the intended glyph.
print("PHASE 4: Export Code Predictor → .pte")
print("=" * 70)
# ── 1. Load Model ──────────────────────────────────────────────────────
print("\n[1/5] Loading model...")

# These imports sit here rather than at the top of the file; presumably the
# qwen_tts package is only importable once the sys.path entries configured
# earlier in this script are in place.
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)

# Load in float32 on CPU with the SDPA attention implementation.
_load_options = dict(
    config=config,
    dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="cpu",
)
model = Qwen3TTSForConditionalGeneration.from_pretrained(MODEL_PATH, **_load_options)
model.eval()
print(" Model loaded.")

# ── 2. Build Export-Ready Code Predictor ───────────────────────────────
print("\n[2/5] Building export-ready code predictor wrapper...")
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    The statistics are computed in float32 and the result is cast back to the
    dtype of the incoming tensor.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        input_dtype = x.dtype
        x_fp32 = x.float()
        # Mean of squares along the last axis; keepdim lets it broadcast back.
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        normalised = x_fp32 * torch.rsqrt(mean_square + self.eps)
        return (self.weight * normalised).to(input_dtype)
def rotate_half(x):
    """Return ``cat(-second_half, first_half)`` of ``x`` along its last axis.

    The split point is ``x.shape[-1] // 2``; this is the helper used when
    applying rotary position embeddings to queries and keys.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
class CPAttentionForExport(nn.Module):
    """Code predictor attention layer with static KV cache.

    The q/k/v/o projections and the per-head q/k norm weights are deep-copied
    from ``original_attn``.  Head counts and the head dimension come from the
    module-level ``CP_*`` constants, not from ``original_attn``.
    """

    def __init__(self, original_attn, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = CP_HEAD_DIM
        self.num_heads = CP_NUM_HEADS
        self.num_kv_heads = CP_NUM_KV_HEADS
        # Number of query heads that share each key/value head.
        self.num_kv_groups = CP_NUM_HEADS // CP_NUM_KV_HEADS
        self.scaling = CP_HEAD_DIM ** -0.5
        self.q_proj = copy.deepcopy(original_attn.q_proj)
        self.k_proj = copy.deepcopy(original_attn.k_proj)
        self.v_proj = copy.deepcopy(original_attn.v_proj)
        self.o_proj = copy.deepcopy(original_attn.o_proj)
        self.q_norm = RMSNorm(CP_HEAD_DIM, eps=1e-6)
        self.q_norm.weight = copy.deepcopy(original_attn.q_norm.weight)
        self.k_norm = RMSNorm(CP_HEAD_DIM, eps=1e-6)
        self.k_norm.weight = copy.deepcopy(original_attn.k_norm.weight)

    def _expand_kv(self, cache, bsz):
        """Repeat each KV head ``num_kv_groups`` times to match the query heads.

        The cache length is read from the tensor itself rather than from a
        module-level constant, so any cache size works.
        """
        cache_len = cache.shape[2]
        return cache.unsqueeze(2).repeat(
            1, 1, self.num_kv_groups, 1, 1
        ).reshape(bsz, self.num_heads, cache_len, self.head_dim)

    def forward(self, hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask):
        """Run attention for ``hidden_states`` against the updated KV cache.

        Args:
            hidden_states: [B, seq_len, hidden] input to the projections.
            cos, sin: rotary tables; must broadcast against
                [B, heads, seq_len, head_dim].
            cache_position: indices along the cache's sequence axis where the
                new keys/values are written.
            k_cache, v_cache: [B, kv_heads, cache_len, head_dim] caches.
            attn_mask: mask passed straight to scaled_dot_product_attention.

        Returns:
            (attn_output, k_cache, v_cache), where the caches are updated
            copies; the tensors passed in are not modified.
        """
        bsz, seq_len, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        q = self.q_norm(q).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim)
        k = self.k_norm(k).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary position embedding on queries and keys.
        q = (q * cos) + (rotate_half(q) * sin)
        k = (k * cos) + (rotate_half(k) * sin)

        # Clone before writing so the caller's cache tensors stay untouched.
        k_cache = k_cache.clone()
        v_cache = v_cache.clone()
        k_cache[:, :, cache_position, :] = k
        v_cache[:, :, cache_position, :] = v

        k_expanded = self._expand_kv(k_cache, bsz)
        v_expanded = self._expand_kv(v_cache, bsz)

        attn_output = F.scaled_dot_product_attention(
            q, k_expanded, v_expanded,
            attn_mask=attn_mask,
            scale=self.scaling,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, k_cache, v_cache
class CPMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    The three projections are deep copies of the ones on ``original_mlp``, so
    this wrapper does not share modules with the source model.
    """

    def __init__(self, original_mlp):
        super().__init__()
        for proj_name in ("gate_proj", "up_proj", "down_proj"):
            setattr(self, proj_name, copy.deepcopy(getattr(original_mlp, proj_name)))

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
class CPLayerForExport(nn.Module):
    """One decoder layer: norm → attention → residual, then norm → MLP → residual.

    Norm weights are deep-copied from ``original_layer``; the attention and
    MLP sub-blocks are rebuilt with the export wrappers.
    """

    def __init__(self, original_layer, layer_idx):
        super().__init__()
        self.attn = CPAttentionForExport(original_layer.self_attn, layer_idx)
        self.mlp = CPMLP(original_layer.mlp)
        self.input_norm = self._clone_norm(original_layer.input_layernorm)
        self.post_attn_norm = self._clone_norm(original_layer.post_attention_layernorm)

    @staticmethod
    def _clone_norm(source_norm):
        """Build a hidden-size RMSNorm carrying a copy of ``source_norm``'s weight."""
        norm = RMSNorm(CP_HIDDEN_SIZE, eps=1e-6)
        norm.weight = copy.deepcopy(source_norm.weight)
        return norm

    def forward(self, hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask):
        """Return ``(hidden_states, k_cache, v_cache)`` after this layer."""
        # Attention sub-block with residual connection.
        attn_out, k_cache, v_cache = self.attn(
            self.input_norm(hidden_states), cos, sin, cache_position,
            k_cache, v_cache, attn_mask
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with residual connection.
        mlp_out = self.mlp(self.post_attn_norm(hidden_states))
        hidden_states = hidden_states + mlp_out
        return hidden_states, k_cache, v_cache
class CodePredictorForExport(nn.Module):
    """
    Export-ready code predictor backbone.

    Input:  talker-sized inputs_embeds; this module applies
            small_to_mtp_projection itself as the first step of forward().
    Output: hidden states (caller applies the appropriate lm_head externally)

    For the full 16-codebook prediction:
      1. Python builds inputs_embeds from talker hidden + codec embeddings
      2. This module projects them and runs the transformer
      3. Python takes the hidden state at the position of interest and applies
         the lm_head for the current step

    The lm_heads are copied onto this module for convenience but are not used
    inside forward().
    """

    def __init__(self, original_cp):
        super().__init__()
        # Transformer layers
        self.layers = nn.ModuleList()
        for i, layer in enumerate(original_cp.model.layers):
            self.layers.append(CPLayerForExport(layer, i))
        # Final norm
        self.norm = RMSNorm(CP_HIDDEN_SIZE, eps=1e-6)
        self.norm.weight = copy.deepcopy(original_cp.model.norm.weight)
        # Projection from talker hidden to code predictor hidden
        self.small_to_mtp_projection = copy.deepcopy(original_cp.small_to_mtp_projection)
        # LM heads, one per entry of original_cp.lm_head
        self.lm_heads = nn.ModuleList()
        for head in original_cp.lm_head:
            self.lm_heads.append(copy.deepcopy(head))
        # Rotary embedding: keep only the inverse frequencies and the scaling
        # factor (1.0 when the source rotary module does not define one).
        orig_rope = original_cp.model.rotary_emb
        self.register_buffer("inv_freq", orig_rope.inv_freq.clone())
        self.rope_scaling = getattr(orig_rope, 'attention_scaling', 1.0)

    def _compute_rope(self, position_ids, device, dtype):
        """Return (cos, sin) rotary tables for ``position_ids``.

        Each table has a singleton axis inserted at dim 1 so it broadcasts
        over attention heads.
        """
        pos = position_ids.float()  # [B, seq_len]
        inv_freq = self.inv_freq.float().to(device)
        freqs = pos.unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)
        # Duplicate the frequencies to cover both halves used by rotate_half.
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = (emb.cos() * self.rope_scaling).to(dtype)
        sin = (emb.sin() * self.rope_scaling).to(dtype)
        return cos.unsqueeze(1), sin.unsqueeze(1)

    def forward(self, inputs_embeds, position_ids, cache_position, attn_mask,
                *kv_cache_flat):
        """
        Args:
            inputs_embeds: [B, seq_len, talker_hidden_size] — NOT YET projected
            position_ids: [B, seq_len]
            cache_position: [seq_len]
            attn_mask: [B, 1, seq_len, MAX_SEQ_LEN]
            *kv_cache_flat: 2 tensors per layer (k then v), each
                [B, kv_heads, MAX_SEQ_LEN, head_dim]
        Returns:
            hidden_states: [B, seq_len, CP_HIDDEN_SIZE] — apply lm_head externally
            *updated_kv_cache: same layout and order as kv_cache_flat
        Raises:
            ValueError: if the number of cache tensors is not 2 * num_layers.
        """
        expected_caches = 2 * len(self.layers)
        if len(kv_cache_flat) != expected_caches:
            raise ValueError(
                f"expected {expected_caches} KV cache tensors "
                f"(k and v for each of {len(self.layers)} layers), "
                f"got {len(kv_cache_flat)}"
            )

        # Project from talker hidden → code predictor hidden
        hidden_states = self.small_to_mtp_projection(inputs_embeds)
        cos, sin = self._compute_rope(position_ids, hidden_states.device, hidden_states.dtype)
        updated_kv = []
        for i, layer in enumerate(self.layers):
            k_cache = kv_cache_flat[i * 2]
            v_cache = kv_cache_flat[i * 2 + 1]
            hidden_states, new_k, new_v = layer(
                hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask
            )
            updated_kv.append(new_k)
            updated_kv.append(new_v)
        hidden_states = self.norm(hidden_states)
        return (hidden_states, *updated_kv)
# Wrap the loaded model's code predictor in the export-friendly module and
# report how long the deep copies took and how many parameters it holds.
print(" Constructing CodePredictorForExport...")
build_started = time.time()
export_cp = CodePredictorForExport(model.talker.code_predictor)
export_cp.eval()
build_seconds = time.time() - build_started
print(f" Done in {build_seconds:.1f}s")

param_count = sum(param.numel() for param in export_cp.parameters())
print(f" Parameters: {param_count / 1e6:.1f}M")
# ── 3. Validate ────────────────────────────────────────────────────────
print("\n[3/5] Validating wrapper...")

# Prefill: 2 tokens (projected_talker_hidden + first_codec_embed)
seq_len = 2
test_embeds = torch.randn(BATCH_SIZE, seq_len, TALKER_HIDDEN_SIZE)
test_pos = torch.arange(seq_len).unsqueeze(0).expand(BATCH_SIZE, -1)
test_cache_pos = torch.arange(seq_len)

# Additive mask: start fully blocked, then open columns 0..row for each row.
causal_mask = torch.full((BATCH_SIZE, 1, seq_len, MAX_SEQ_LEN), float('-inf'))
for query_idx in range(seq_len):
    causal_mask[:, :, query_idx, : query_idx + 1] = 0.0

# Zero-filled caches in flat order: k, v for layer 0, then k, v for layer 1, ...
kv_cache = [
    torch.zeros(BATCH_SIZE, CP_NUM_KV_HEADS, MAX_SEQ_LEN, CP_HEAD_DIM)
    for _ in range(2 * CP_NUM_LAYERS)
]

with torch.no_grad():
    hidden, *updated_kv = export_cp(
        test_embeds, test_pos, test_cache_pos, causal_mask, *kv_cache
    )
print(f" Hidden states shape: {list(hidden.shape)}")  # [1, 2, 1024]
assert hidden.shape == (BATCH_SIZE, seq_len, CP_HIDDEN_SIZE)

# Apply lm_head to get logits for the first prediction step
logits_0 = export_cp.lm_heads[0](hidden[:, -1:, :])
print(f" Logits[0] shape: {list(logits_0.shape)}")  # [1, 1, 2048]
assert logits_0.shape[-1] == CP_VOCAB_SIZE

# Decode step: one new token written at cache slot `seq_len`, attending to
# every slot up to and including its own.
decode_embeds = torch.randn(BATCH_SIZE, 1, TALKER_HIDDEN_SIZE)
decode_pos = torch.tensor([[seq_len]])
decode_cache_pos = torch.tensor([seq_len])
decode_mask = torch.full((BATCH_SIZE, 1, 1, MAX_SEQ_LEN), float('-inf'))
decode_mask[:, :, :, : seq_len + 1] = 0.0

with torch.no_grad():
    decode_hidden, *_ = export_cp(
        decode_embeds, decode_pos, decode_cache_pos, decode_mask, *updated_kv
    )
print(f" Decode hidden shape: {list(decode_hidden.shape)}")
print(" PASS β code predictor validated")
# ── 4. torch.export ────────────────────────────────────────────────────
print("\n[4/5] Running torch.export...")
export_started = time.time()

# Trace with the 2-token prefill example inputs built during validation.
# NOTE(review): no dynamic_shapes are passed, so the exported graph is
# presumably specialized to these prefill shapes (seq_len=2) — confirm how the
# single-token decode step is meant to run against the resulting program.
prefill_args = (test_embeds, test_pos, test_cache_pos, causal_mask, *kv_cache)
try:
    exported = torch.export.export(export_cp, prefill_args, strict=False)
    export_seconds = time.time() - export_started
    print(f" torch.export succeeded in {export_seconds:.1f}s")
    print(f" Graph nodes: {len(exported.graph.nodes)}")
except Exception as export_error:
    # Best effort: report the failure and let the lowering stage see None.
    print(f" torch.export FAILED: {export_error}")
    exported = None
# ── 5. Lower to .pte ───────────────────────────────────────────────────
print("\n[5/5] Lowering to ExecuTorch .pte...")
t0 = time.time()
if exported is None:
    # torch.export failed in the previous stage; say so instead of printing
    # the stage header and then silently doing nothing.
    print(" Skipped: torch.export did not produce a program to lower.")
else:
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        # Convert to the Edge dialect and delegate what XNNPACK can take.
        edge = to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()

        pte_path = os.path.join(OUTPUT_DIR, "code_predictor.pte")
        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)
        pte_size = os.path.getsize(pte_path) / 1e6
        print(f" .pte saved: {pte_path}")
        print(f" .pte size: {pte_size:.1f} MB")
        print(f" Lowered in {time.time() - t0:.1f}s")
    except Exception as e:
        # Fallback: keep the exported program as .pt2 so the torch.export
        # result is not lost when ExecuTorch is missing or lowering fails.
        print(f" ExecuTorch lowering failed: {e}")
        pt2_path = os.path.join(OUTPUT_DIR, "code_predictor.pt2")
        torch.export.save(exported, pt2_path)
        print(f" Saved: {pt2_path}")
# Also save the codec embeddings and lm_heads for the orchestration layer.
# The codec embeddings are read from the original model; the heads and the
# input projection are read from the export wrapper's copies.
_codec_embedding_states = [
    emb.state_dict() for emb in model.talker.code_predictor.model.codec_embedding
]
_lm_head_states = [head.state_dict() for head in export_cp.lm_heads]
_extras = {
    "codec_embeddings": _codec_embedding_states,
    "lm_heads": _lm_head_states,
    "small_to_mtp_projection": export_cp.small_to_mtp_projection.state_dict(),
}
torch.save(_extras, os.path.join(OUTPUT_DIR, "code_predictor_extras.pt"))
print(f" Saved codec embeddings + lm_heads: {OUTPUT_DIR}/code_predictor_extras.pt")
# Closing summary banner with the key export settings.
_rule = "=" * 70
print("\n" + _rule)
print("Phase 4 complete!")
for _summary_line in (
    f" Max seq len: {MAX_SEQ_LEN}",
    f" Parameters: {param_count / 1e6:.1f}M",
    f" Vocab: {CP_VOCAB_SIZE}, Code groups: {CP_NUM_CODE_GROUPS}",
):
    print(_summary_line)
print(_rule)