"""
Phase 4: Export Code Predictor to ExecuTorch .pte
==================================================
The code predictor is a smaller 5-layer transformer (175M params) that
takes the talker's hidden state + first codebook token and autoregressively
generates the remaining 15 codebook tokens.

Architecture:
  - hidden_size=1024, 5 layers, 16 heads, 8 kv_heads, head_dim=128
  - small_to_mtp_projection: Linear(2048→1024) — projects talker hidden → predictor
  - 15 lm_heads: Linear(1024→2048) each (one per code group)
  - 15 codec_embeddings: Embedding(2048, 2048) each

During inference (called once per talker decode step):
  Step 0 (prefill): concat(talker_hidden, codec_embed_0(first_token))
                    → small_to_mtp_projection → 2 tokens
  Steps 1-14: predict next code group token → embed it → feed back

We export this as a static-KV-cache transformer similar to the talker.
"""
import copy
import os
import sys
import time
import traceback

import torch
import torch.nn as nn
import torch.nn.functional as F
# Host-specific locations; adjust these for the machine running the export.
MODEL_PATH, VENV_SITE, QWEN_TTS_SRC, OUTPUT_DIR = (
    os.path.expanduser(raw)
    for raw in (
        "~/Documents/Qwen3-TTS/models/1.7B-Base",
        "~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages",
        "~/Documents/Qwen3-TTS",
        "~/Documents/Qwen3-TTS-ExecuTorch/exported",
    )
)

# Make the Qwen3-TTS sources and its virtualenv packages importable.
# VENV_SITE is inserted first and QWEN_TTS_SRC second, which leaves
# QWEN_TTS_SRC at the front when neither path was already present.
# NOTE(review): torch is imported above this point, so VENV_SITE cannot change
# which torch build is used — confirm that is intended.
for _extra_path in (VENV_SITE, QWEN_TTS_SRC):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)

os.makedirs(OUTPUT_DIR, exist_ok=True)
# --- Static export geometry -------------------------------------------------
# Prefill writes cache slots 0-1 and decode steps 1-14 write slots 2-15, so a
# full pass uses 16 slots; 17 presumably leaves one spare (TODO confirm).
MAX_SEQ_LEN = 17
BATCH_SIZE = 1

# --- Code-predictor transformer dimensions (must match the checkpoint) -------
CP_NUM_LAYERS = 5
CP_NUM_HEADS = 16
CP_NUM_KV_HEADS = 8
CP_HEAD_DIM = 128
CP_HIDDEN_SIZE = 1024
CP_INTERMEDIATE_SIZE = 3072

# --- Output vocabulary / codebooks -------------------------------------------
CP_VOCAB_SIZE = 2048
CP_NUM_CODE_GROUPS = 16

# --- Width of the talker hidden state fed to small_to_mtp_projection --------
TALKER_HIDDEN_SIZE = 2048
print("=" * 70)
print("PHASE 4: Export Code Predictor → .pte")
print("=" * 70)


# ---------------------------------------------------------------------------
# [1/5] Load the full Qwen3-TTS model (fp32, CPU). Only its code-predictor
# submodule is used; it is deep-copied into the export wrapper below.
# ---------------------------------------------------------------------------
print("\n[1/5] Loading model...")
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig  # noqa: E402
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration  # noqa: E402

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    config=config,
    dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="cpu",
)
model.eval()
print(" Model loaded.")

# Fail early if the hard-coded geometry does not match the checkpoint. The
# export wrappers size their RMSNorms and KV caches from these constants, so a
# mismatch would otherwise surface later as an opaque shape error.
_cp_loaded = model.talker.code_predictor
if len(_cp_loaded.model.layers) != CP_NUM_LAYERS:
    raise ValueError(
        f"checkpoint code predictor has {len(_cp_loaded.model.layers)} layers, "
        f"but CP_NUM_LAYERS = {CP_NUM_LAYERS}"
    )
if _cp_loaded.model.norm.weight.shape[-1] != CP_HIDDEN_SIZE:
    raise ValueError(
        f"checkpoint code predictor hidden size is {_cp_loaded.model.norm.weight.shape[-1]}, "
        f"but CP_HIDDEN_SIZE = {CP_HIDDEN_SIZE}"
    )


print("\n[2/5] Building export-ready code predictor wrapper...")
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel gain.

    The mean of squares is computed in float32, and the result is cast back
    to the input's dtype.

    Args:
        dim: size of the last (normalized) dimension; the gain starts at ones.
        eps: added to the mean of squares before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        orig_dtype = x.dtype
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (self.weight * (x32 * inv_rms)).to(orig_dtype)
def rotate_half(x):
    """Split the last dimension into halves ``[a, b]`` and return ``[-b, a]``.

    This is the rotation used by rotary position embeddings, applied as
    ``x * cos + rotate_half(x) * sin``.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
class CPAttentionForExport(nn.Module):
    """Code predictor attention layer with a static, functional KV cache.

    The q/k/v/o projections and the per-head q/k RMSNorm gains are deep-copied
    from ``original_attn``. The caller owns the KV cache: ``forward`` receives
    the cache tensors, writes the new keys/values into clones and returns the
    clones, so the input tensors are never written in place.
    """

    def __init__(self, original_attn, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = CP_HEAD_DIM
        self.num_heads = CP_NUM_HEADS
        self.num_kv_heads = CP_NUM_KV_HEADS
        # Grouped-query attention: each KV head is shared by this many query heads.
        self.num_kv_groups = CP_NUM_HEADS // CP_NUM_KV_HEADS
        self.scaling = CP_HEAD_DIM ** -0.5

        self.q_proj = copy.deepcopy(original_attn.q_proj)
        self.k_proj = copy.deepcopy(original_attn.k_proj)
        self.v_proj = copy.deepcopy(original_attn.v_proj)
        self.o_proj = copy.deepcopy(original_attn.o_proj)

        # Carry over each source norm's epsilon when it exposes
        # `variance_epsilon` (assumed Hugging Face RMSNorm naming — verify for
        # this checkpoint); otherwise fall back to 1e-6.
        self.q_norm = RMSNorm(
            CP_HEAD_DIM, eps=getattr(original_attn.q_norm, "variance_epsilon", 1e-6)
        )
        self.q_norm.weight = copy.deepcopy(original_attn.q_norm.weight)
        self.k_norm = RMSNorm(
            CP_HEAD_DIM, eps=getattr(original_attn.k_norm, "variance_epsilon", 1e-6)
        )
        self.k_norm.weight = copy.deepcopy(original_attn.k_norm.weight)

    def forward(self, hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask):
        """Write the new keys/values into the cache and attend over it.

        Args:
            hidden_states: [B, seq_len, hidden] activations for the new tokens.
            cos, sin: rotary tables that must broadcast against
                [B, num_heads, seq_len, head_dim]
                (CodePredictorForExport._compute_rope builds them as
                [B, 1, seq_len, head_dim]).
            cache_position: index tensor of the cache slots (dim 2) that the
                seq_len new keys/values are written to.
            k_cache, v_cache: [B, num_kv_heads, cache_len, head_dim] caches.
            attn_mask: passed straight to scaled_dot_product_attention. This
                script builds an additive float mask [B, 1, seq_len, cache_len].

        Returns:
            (attn_output [B, seq_len, hidden], new_k_cache, new_v_cache)
        """
        bsz, seq_len, _ = hidden_states.shape

        # Project, split into heads, apply per-head RMSNorm, then move heads to dim 1.
        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        q = self.q_norm(q).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim)
        k = self.k_norm(k).transpose(1, 2)
        v = (
            self.v_proj(hidden_states)
            .view(bsz, seq_len, self.num_kv_heads, self.head_dim)
            .transpose(1, 2)
        )

        # Rotary position embedding.
        q = (q * cos) + (rotate_half(q) * sin)
        k = (k * cos) + (rotate_half(k) * sin)

        # Functional cache update: write into copies, never into the inputs.
        k_cache = k_cache.clone()
        v_cache = v_cache.clone()
        k_cache[:, :, cache_position, :] = k
        v_cache[:, :, cache_position, :] = v

        attn_output = F.scaled_dot_product_attention(
            q,
            self._repeat_kv(k_cache),
            self._repeat_kv(v_cache),
            attn_mask=attn_mask,
            scale=self.scaling,
        )

        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, k_cache, v_cache

    def _repeat_kv(self, cache):
        """Expand [B, num_kv_heads, S, D] to [B, num_heads, S, D] for GQA.

        Query head h reads KV head h // num_kv_groups. S is taken from the
        cache itself rather than assumed to equal MAX_SEQ_LEN.
        """
        b, _, s, d = cache.shape
        return (
            cache.unsqueeze(2)
            .repeat(1, 1, self.num_kv_groups, 1, 1)
            .reshape(b, self.num_heads, s, d)
        )
class CPMLP(nn.Module):
    """Gated feed-forward block computing ``down(silu(gate(x)) * up(x))``.

    The three projections are deep copies of those on ``original_mlp``, so the
    wrapper does not share parameters with the source model.
    """

    def __init__(self, original_mlp):
        super().__init__()
        for proj_name in ("gate_proj", "up_proj", "down_proj"):
            setattr(self, proj_name, copy.deepcopy(getattr(original_mlp, proj_name)))

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
class CPLayerForExport(nn.Module):
    """One pre-norm code-predictor decoder layer.

    Applies RMSNorm → attention → residual add, then RMSNorm → gated MLP →
    residual add. KV caches are threaded through and returned updated.
    """

    def __init__(self, original_layer, layer_idx):
        super().__init__()
        self.attn = CPAttentionForExport(original_layer.self_attn, layer_idx)
        self.mlp = CPMLP(original_layer.mlp)
        self.input_norm = self._copy_rms_norm(original_layer.input_layernorm)
        self.post_attn_norm = self._copy_rms_norm(original_layer.post_attention_layernorm)

    @staticmethod
    def _copy_rms_norm(original_norm):
        """Return an RMSNorm holding a deep copy of ``original_norm``'s gain.

        The epsilon comes from ``original_norm.variance_epsilon`` when that
        attribute exists (assumed Hugging Face naming — verify), else 1e-6.
        """
        norm = RMSNorm(CP_HIDDEN_SIZE, eps=getattr(original_norm, "variance_epsilon", 1e-6))
        norm.weight = copy.deepcopy(original_norm.weight)
        return norm

    def forward(self, hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask):
        """Run the layer; return (hidden_states, new_k_cache, new_v_cache)."""
        # Attention sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.input_norm(hidden_states)
        attn_out, k_cache, v_cache = self.attn(
            hidden_states, cos, sin, cache_position,
            k_cache, v_cache, attn_mask
        )
        hidden_states = residual + attn_out

        # MLP sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.post_attn_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, k_cache, v_cache
class CodePredictorForExport(nn.Module):
    """
    Export-ready code predictor backbone.

    Input: talker-space inputs_embeds (TALKER_HIDDEN_SIZE wide, NOT yet
    projected); forward() applies small_to_mtp_projection itself.
    Output: final-norm hidden states [B, seq_len, CP_HIDDEN_SIZE] followed by
    the updated KV caches. The lm_heads are copied onto this module but are
    not applied in forward(); the caller applies them externally.

    For the full 16-codebook prediction:
      1. Python builds inputs_embeds from talker hidden + codec embeddings
      2. This module projects them and runs the transformer
      3. Python takes the last position, hidden[:, -1, :], and applies the
         lm_head for the code group being predicted (lm_heads[0] right after
         the 2-token prefill)
    """

    def __init__(self, original_cp):
        super().__init__()

        # Transformer layers, each wrapping a copy of the corresponding original layer.
        self.layers = nn.ModuleList(
            CPLayerForExport(layer, i) for i, layer in enumerate(original_cp.model.layers)
        )

        # Final norm. The epsilon is taken from the source norm when it exposes
        # `variance_epsilon` (assumed Hugging Face naming — verify), else 1e-6.
        self.norm = RMSNorm(
            CP_HIDDEN_SIZE, eps=getattr(original_cp.model.norm, "variance_epsilon", 1e-6)
        )
        self.norm.weight = copy.deepcopy(original_cp.model.norm.weight)

        # Talker hidden size -> code-predictor hidden size.
        self.small_to_mtp_projection = copy.deepcopy(original_cp.small_to_mtp_projection)

        # Per-code-group output heads (applied by the caller, not in forward).
        self.lm_heads = nn.ModuleList(copy.deepcopy(head) for head in original_cp.lm_head)

        # Rotary embedding: inverse frequencies plus the optional output scaling factor.
        orig_rope = original_cp.model.rotary_emb
        self.register_buffer("inv_freq", orig_rope.inv_freq.clone())
        self.rope_scaling = getattr(orig_rope, 'attention_scaling', 1.0)

    def _compute_rope(self, position_ids, device, dtype):
        """Return (cos, sin), each [B, 1, seq_len, 2 * len(inv_freq)].

        position_ids is [B, seq_len]. Angles are computed in float32 and then
        cast to ``dtype``. The extra head axis lets the tables broadcast over
        attention heads.
        """
        pos = position_ids.float()
        inv_freq = self.inv_freq.float().to(device)
        freqs = pos.unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = (emb.cos() * self.rope_scaling).to(dtype)
        sin = (emb.sin() * self.rope_scaling).to(dtype)
        return cos.unsqueeze(1), sin.unsqueeze(1)

    def forward(self, inputs_embeds, position_ids, cache_position, attn_mask,
                *kv_cache_flat):
        """
        Args:
            inputs_embeds: [B, seq_len, talker_hidden_size] — NOT YET projected
            position_ids: [B, seq_len]
            cache_position: [seq_len]
            attn_mask: [B, 1, seq_len, MAX_SEQ_LEN]
            *kv_cache_flat: 5 * 2 tensors (k0, v0, k1, v1, ...),
                each [B, kv_heads, MAX_SEQ_LEN, head_dim]

        Returns:
            hidden_states: [B, seq_len, CP_HIDDEN_SIZE] — apply lm_head externally
            *updated_kv_cache: in the same (k0, v0, k1, v1, ...) order

        Raises:
            ValueError: if kv_cache_flat does not hold exactly one (k, v) pair
                per layer.
        """
        expected = 2 * len(self.layers)
        if len(kv_cache_flat) != expected:
            raise ValueError(
                f"expected {expected} KV cache tensors (k, v per layer), "
                f"got {len(kv_cache_flat)}"
            )

        # Talker space -> code-predictor space.
        hidden_states = self.small_to_mtp_projection(inputs_embeds)

        cos, sin = self._compute_rope(position_ids, hidden_states.device, hidden_states.dtype)

        updated_kv = []
        for layer, k_cache, v_cache in zip(self.layers, kv_cache_flat[0::2], kv_cache_flat[1::2]):
            hidden_states, new_k, new_v = layer(
                hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask
            )
            updated_kv.extend((new_k, new_v))

        hidden_states = self.norm(hidden_states)

        return (hidden_states, *updated_kv)
print(" Constructing CodePredictorForExport...")
_build_started = time.time()
export_cp = CodePredictorForExport(model.talker.code_predictor).eval()
print(f" Done in {time.time() - _build_started:.1f}s")

# Every registered parameter of the wrapper: layers, final norm, projection
# and the copied lm_heads.
param_count = sum(torch.numel(p) for p in export_cp.parameters())
print(f" Parameters: {param_count / 1e6:.1f}M")
print("\n[3/5] Validating wrapper...")


def _check(condition, message):
    """Raise AssertionError(message) unless condition holds (survives python -O)."""
    if not condition:
        raise AssertionError(message)


def _causal_mask(start, length):
    """Additive mask for `length` new tokens written at cache slots start..start+length-1.

    Row i may attend to cache slots [0, start + i]; every other slot is -inf.
    """
    mask = torch.full((BATCH_SIZE, 1, length, MAX_SEQ_LEN), float('-inf'))
    for row in range(length):
        mask[:, :, row, : start + row + 1] = 0.0
    return mask


def _zero_kv_cache():
    """Fresh zeroed caches for every layer, flattened as k0, v0, k1, v1, ..."""
    return [
        torch.zeros(BATCH_SIZE, CP_NUM_KV_HEADS, MAX_SEQ_LEN, CP_HEAD_DIM)
        for _ in range(2 * CP_NUM_LAYERS)
    ]


# --- Prefill: 2 tokens written to cache slots 0..1 ---------------------------
seq_len = 2
test_embeds = torch.randn(BATCH_SIZE, seq_len, TALKER_HIDDEN_SIZE)
test_pos = torch.arange(seq_len).unsqueeze(0).expand(BATCH_SIZE, -1)
test_cache_pos = torch.arange(seq_len)
causal_mask = _causal_mask(0, seq_len)
kv_cache = _zero_kv_cache()

with torch.no_grad():
    outputs = export_cp(test_embeds, test_pos, test_cache_pos, causal_mask, *kv_cache)

hidden = outputs[0]
print(f" Hidden states shape: {list(hidden.shape)}")
_check(hidden.shape == (BATCH_SIZE, seq_len, CP_HIDDEN_SIZE),
       f"unexpected prefill hidden shape {tuple(hidden.shape)}")
_check(len(outputs) == 1 + 2 * CP_NUM_LAYERS,
       f"expected {1 + 2 * CP_NUM_LAYERS} outputs, got {len(outputs)}")

# The first predicted code group's logits come from the last prefill position.
with torch.no_grad():
    logits_0 = export_cp.lm_heads[0](hidden[:, -1:, :])
print(f" Logits[0] shape: {list(logits_0.shape)}")
_check(logits_0.shape[-1] == CP_VOCAB_SIZE,
       f"unexpected vocab size {logits_0.shape[-1]}")

# --- Decode: 1 token at slot seq_len, reusing the caches returned by prefill -
decode_embeds = torch.randn(BATCH_SIZE, 1, TALKER_HIDDEN_SIZE)
decode_pos = torch.full((BATCH_SIZE, 1), seq_len, dtype=torch.long)
decode_cache_pos = torch.tensor([seq_len])
decode_mask = _causal_mask(seq_len, 1)

updated_kv = list(outputs[1:])
with torch.no_grad():
    decode_out = export_cp(decode_embeds, decode_pos, decode_cache_pos, decode_mask, *updated_kv)

print(f" Decode hidden shape: {list(decode_out[0].shape)}")
_check(decode_out[0].shape == (BATCH_SIZE, 1, CP_HIDDEN_SIZE),
       f"unexpected decode hidden shape {tuple(decode_out[0].shape)}")

# --- KV-cache consistency -----------------------------------------------------
# Running the same 3 tokens as one uncached prefill must reproduce the
# incremental result for the last token. A mismatch points at the cache
# write, the mask or the RoPE positions.
full_len = seq_len + 1
with torch.no_grad():
    full_out = export_cp(
        torch.cat([test_embeds, decode_embeds], dim=1),
        torch.arange(full_len).unsqueeze(0).expand(BATCH_SIZE, -1),
        torch.arange(full_len),
        _causal_mask(0, full_len),
        *_zero_kv_cache(),
    )
reference_last = full_out[0][:, -1:, :]
max_abs_diff = (decode_out[0] - reference_last).abs().max().item()
print(f" KV-cache consistency max |diff|: {max_abs_diff:.2e}")
_check(torch.allclose(decode_out[0], reference_last, rtol=1e-3, atol=1e-3),
       f"incremental decode disagrees with full prefill (max |diff| {max_abs_diff:.2e})")

print(" PASS — code predictor validated")
print("\n[4/5] Running torch.export...")


def _export_program(args, label):
    """torch.export `export_cp` on example `args`; return the program, or None on failure.

    No dynamic shapes are declared, so the program only accepts inputs shaped
    like `args`.
    """
    start = time.time()
    try:
        program = torch.export.export(export_cp, args, strict=False)
    except Exception as e:
        print(f" torch.export ({label}) FAILED: {e}")
        traceback.print_exc()
        return None
    print(f" torch.export ({label}) succeeded in {time.time() - start:.1f}s")
    print(f" Graph nodes: {len(program.graph.nodes)}")
    return program


# Prefill: 2 new tokens at cache slots 0..1 (the validation inputs).
prefill_args = (test_embeds, test_pos, test_cache_pos, causal_mask, *kv_cache)

# Decode: 1 new token. The prefill program is fixed to seq_len=2, so it cannot
# run the seq_len=1 steps 1-14; export those separately. The write slot comes
# from the *values* of cache_position/position_ids (and the mask input), so a
# single decode program serves every step.
_decode_slot = 2
_decode_mask = torch.full((BATCH_SIZE, 1, 1, MAX_SEQ_LEN), float('-inf'))
_decode_mask[:, :, :, : _decode_slot + 1] = 0.0
decode_args = (
    torch.randn(BATCH_SIZE, 1, TALKER_HIDDEN_SIZE),
    torch.full((BATCH_SIZE, 1), _decode_slot, dtype=torch.long),
    torch.tensor([_decode_slot]),
    _decode_mask,
    *kv_cache,
)

exported = _export_program(prefill_args, "prefill, seq_len=2")
exported_decode = _export_program(decode_args, "decode, seq_len=1")


print("\n[5/5] Lowering to ExecuTorch .pte...")


def _lower_and_save(program, stem):
    """Lower `program` with XNNPACK and write OUTPUT_DIR/<stem>.pte.

    If importing ExecuTorch or lowering fails, the exported program is saved as
    OUTPUT_DIR/<stem>.pt2 instead, so the torch.export result is not lost.
    """
    start = time.time()
    pte_path = os.path.join(OUTPUT_DIR, f"{stem}.pte")
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        edge = to_edge_transform_and_lower(
            program,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()
        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)
    except Exception as e:
        print(f" ExecuTorch lowering failed: {e}")
        traceback.print_exc()
        pt2_path = os.path.join(OUTPUT_DIR, f"{stem}.pt2")
        torch.export.save(program, pt2_path)
        print(f" Saved: {pt2_path}")
        return

    pte_size = os.path.getsize(pte_path) / 1e6
    print(f" .pte saved: {pte_path}")
    print(f" .pte size: {pte_size:.1f} MB")
    print(f" Lowered in {time.time() - start:.1f}s")


# code_predictor.pte keeps its original meaning (the prefill graph);
# code_predictor_decode.pte holds the seq_len=1 decode graph.
for program, stem in (
    (exported, "code_predictor"),
    (exported_decode, "code_predictor_decode"),
):
    if program is None:
        print(f" Skipping {stem}: torch.export did not succeed")
        continue
    _lower_and_save(program, stem)


# Save the codec embeddings, lm_heads and projection weights whether or not
# export/lowering succeeded. The caller applies the embeddings and lm_heads
# outside the .pte.
torch.save({
    "codec_embeddings": [emb.state_dict() for emb in model.talker.code_predictor.model.codec_embedding],
    "lm_heads": [head.state_dict() for head in export_cp.lm_heads],
    "small_to_mtp_projection": export_cp.small_to_mtp_projection.state_dict(),
}, os.path.join(OUTPUT_DIR, "code_predictor_extras.pt"))
print(f" Saved codec embeddings + lm_heads: {OUTPUT_DIR}/code_predictor_extras.pt")
# Final report. Printing the lines with sep="\n" emits exactly the same text
# as one print() call per line.
_summary_lines = (
    "\n" + "=" * 70,
    "Phase 4 complete!",
    f" Max seq len: {MAX_SEQ_LEN}",
    f" Parameters: {param_count / 1e6:.1f}M",
    f" Vocab: {CP_VOCAB_SIZE}, Code groups: {CP_NUM_CODE_GROUPS}",
    "=" * 70,
)
print(*_summary_lines, sep="\n")
|