#!/usr/bin/env python3
"""
Phase 4: Export Code Predictor to ExecuTorch .pte
==================================================

The code predictor is a smaller 5-layer transformer (175M params) that takes
the talker's hidden state + first codebook token and autoregressively
generates the remaining 15 codebook tokens.

Architecture:
- hidden_size=1024, 5 layers, 16 heads, 8 kv_heads, head_dim=128
- small_to_mtp_projection: Linear(2048→1024) — projects talker hidden → predictor
- 15 lm_heads: Linear(1024→2048) each (one per code group)
- 15 codec_embeddings: Embedding(2048, 2048) each

During inference (called once per talker decode step):
  Step 0 (prefill): concat(projected_talker_hidden, codec_embed_0(first_token)) → 2 tokens
  Steps 1-14: predict next code group token → embed it → feed back

We export this as a static-KV-cache transformer similar to the talker.
"""

import sys
import os
import copy
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

# Filesystem layout: model checkpoint, vendored sources, and export target.
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Make the vendored qwen_tts package and its virtualenv importable.
# QWEN_TTS_SRC is inserted second, so it ends up *ahead* of VENV_SITE
# on sys.path — same precedence as the original two separate inserts.
for _entry in (VENV_SITE, QWEN_TTS_SRC):
    if _entry not in sys.path:
        sys.path.insert(0, _entry)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ── Configuration ────────────────────────────────────────────────────
MAX_SEQ_LEN = 17  # prefill=2, then 15 decode steps
BATCH_SIZE = 1
CP_NUM_LAYERS = 5
CP_NUM_KV_HEADS = 8
CP_HEAD_DIM = 128
CP_NUM_HEADS = 16
CP_HIDDEN_SIZE = 1024
CP_INTERMEDIATE_SIZE = 3072
CP_VOCAB_SIZE = 2048
CP_NUM_CODE_GROUPS = 16  # total groups (predict 15, first comes from talker)
TALKER_HIDDEN_SIZE = 2048

print("=" * 70)
print("PHASE 4: Export Code Predictor → .pte")
print("=" * 70)

# ── 1. Load Model ────────────────────────────────────────────────────
print("\n[1/5] Loading model...")

# Imported here (after the sys.path setup above) so the vendored package
# resolves; this mirrors the script's original import placement.
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    config=config,
    dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="cpu",
)
model.eval()
print(" Model loaded.")

# ── 2. Build Export-Ready Code Predictor ─────────────────────────────
print("\n[2/5] Building export-ready code predictor wrapper...")


class RMSNorm(nn.Module):
    """Root-mean-square layer norm, computed in fp32 and cast back.

    Matches the usual Qwen-style RMSNorm: normalize by sqrt(mean(x^2) + eps),
    then scale by a learned per-channel weight.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        dtype = x.dtype
        x = x.float()
        v = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(v + self.eps)
        return (self.weight * x).to(dtype)


def rotate_half(x):
    """Rotate half the last dim: (x1, x2) -> (-x2, x1), for rotary embedding."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


class CPAttentionForExport(nn.Module):
    """Code predictor attention layer with static KV cache.

    The caches are threaded through functionally (clone + index-assign and
    returned) so torch.export traces pure tensor ops with no hidden state.
    """

    def __init__(self, original_attn, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = CP_HEAD_DIM
        self.num_heads = CP_NUM_HEADS
        self.num_kv_heads = CP_NUM_KV_HEADS
        self.num_kv_groups = CP_NUM_HEADS // CP_NUM_KV_HEADS
        self.scaling = CP_HEAD_DIM ** -0.5
        # Deep-copy projections so the export module owns its weights.
        self.q_proj = copy.deepcopy(original_attn.q_proj)
        self.k_proj = copy.deepcopy(original_attn.k_proj)
        self.v_proj = copy.deepcopy(original_attn.v_proj)
        self.o_proj = copy.deepcopy(original_attn.o_proj)
        # Per-head QK norms (Qwen3-style), weights copied from the original.
        self.q_norm = RMSNorm(CP_HEAD_DIM, eps=1e-6)
        self.q_norm.weight = copy.deepcopy(original_attn.q_norm.weight)
        self.k_norm = RMSNorm(CP_HEAD_DIM, eps=1e-6)
        self.k_norm.weight = copy.deepcopy(original_attn.k_norm.weight)

    def forward(self, hidden_states, cos, sin, cache_position, k_cache, v_cache, attn_mask):
        """Run one attention step.

        Args:
            hidden_states: [B, seq_len, CP_HIDDEN_SIZE]
            cos, sin: rotary tables, broadcastable to [B, heads, seq_len, head_dim]
            cache_position: [seq_len] slot indices into the static cache
            k_cache, v_cache: [B, kv_heads, MAX_SEQ_LEN, head_dim]
            attn_mask: additive mask [B, 1, seq_len, MAX_SEQ_LEN]

        Returns:
            (attn_output, updated k_cache, updated v_cache)
        """
        bsz, seq_len, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        q = self.q_norm(q).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim)
        k = self.k_norm(k).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Rotary position embedding.
        q = (q * cos) + (rotate_half(q) * sin)
        k = (k * cos) + (rotate_half(k) * sin)
        # Functional cache update: clone before scatter so the inputs stay pure.
        k_cache = k_cache.clone()
        v_cache = v_cache.clone()
        k_cache[:, :, cache_position, :] = k
        v_cache[:, :, cache_position, :] = v
        # GQA: repeat each kv head num_kv_groups times so head h reads
        # kv head h // num_kv_groups (standard repeat_kv layout).
        k_expanded = k_cache.unsqueeze(2).repeat(
            1, 1, self.num_kv_groups, 1, 1
        ).reshape(bsz, self.num_heads, MAX_SEQ_LEN, self.head_dim)
        v_expanded = v_cache.unsqueeze(2).repeat(
            1, 1, self.num_kv_groups, 1, 1
        ).reshape(bsz, self.num_heads, MAX_SEQ_LEN, self.head_dim)
        attn_output = F.scaled_dot_product_attention(
            q, k_expanded, v_expanded,
            attn_mask=attn_mask,
            scale=self.scaling,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, k_cache, v_cache


class CPMLP(nn.Module):
    """SwiGLU feed-forward: down(silu(gate(x)) * up(x))."""

    def __init__(self, original_mlp):
        super().__init__()
        self.gate_proj = copy.deepcopy(original_mlp.gate_proj)
        self.up_proj = copy.deepcopy(original_mlp.up_proj)
        self.down_proj = copy.deepcopy(original_mlp.down_proj)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class CPLayerForExport(nn.Module):
    """One pre-norm transformer block: attention + MLP, each with residual."""

    def __init__(self, original_layer, layer_idx):
        super().__init__()
        self.attn = CPAttentionForExport(original_layer.self_attn, layer_idx)
        self.mlp = CPMLP(original_layer.mlp)
        self.input_norm = RMSNorm(CP_HIDDEN_SIZE, eps=1e-6)
        self.input_norm.weight = copy.deepcopy(original_layer.input_layernorm.weight)
        self.post_attn_norm = RMSNorm(CP_HIDDEN_SIZE, eps=1e-6)
        self.post_attn_norm.weight = copy.deepcopy(original_layer.post_attention_layernorm.weight)

    def forward(self, hidden_states, cos, sin, cache_position, k_cache, v_cache, attn_mask):
        residual = hidden_states
        hidden_states = self.input_norm(hidden_states)
        attn_out, k_cache, v_cache = self.attn(
            hidden_states, cos, sin, cache_position, k_cache, v_cache, attn_mask
        )
        hidden_states = residual + attn_out
        residual = hidden_states
        hidden_states = self.post_attn_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, k_cache, v_cache


class CodePredictorForExport(nn.Module):
    """
    Export-ready code predictor backbone.

    Input: inputs_embeds in the *talker* hidden size — NOT yet projected.
    The first op of forward() applies small_to_mtp_projection to map them
    into the predictor's hidden size.
    Output: hidden states (caller applies the appropriate lm_head externally)

    For the full 16-codebook prediction:
    1. Python builds inputs_embeds from talker hidden + codec embeddings
    2. This module projects them and runs the transformer
    3. Python takes hidden[:, step_idx, :] and applies lm_head[step_idx]
    """

    def __init__(self, original_cp):
        super().__init__()
        # Transformer layers
        self.layers = nn.ModuleList(
            CPLayerForExport(layer, i)
            for i, layer in enumerate(original_cp.model.layers)
        )
        # Final norm
        self.norm = RMSNorm(CP_HIDDEN_SIZE, eps=1e-6)
        self.norm.weight = copy.deepcopy(original_cp.model.norm.weight)
        # Projection from talker hidden to code predictor hidden
        self.small_to_mtp_projection = copy.deepcopy(original_cp.small_to_mtp_projection)
        # LM heads — one per predicted code group (15 in this checkpoint;
        # we copy however many the original module carries).
        self.lm_heads = nn.ModuleList(
            copy.deepcopy(head) for head in original_cp.lm_head
        )
        # Rotary embedding: keep only the inverse frequencies + scaling so
        # cos/sin can be recomputed from position_ids at trace time.
        orig_rope = original_cp.model.rotary_emb
        self.register_buffer("inv_freq", orig_rope.inv_freq.clone())
        self.rope_scaling = getattr(orig_rope, 'attention_scaling', 1.0)

    def _compute_rope(self, position_ids, device, dtype):
        """Build cos/sin tables [B, 1, seq_len, head_dim] for the given positions."""
        pos = position_ids.float()  # [B, seq_len]
        inv_freq = self.inv_freq.float().to(device)
        freqs = pos.unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = (emb.cos() * self.rope_scaling).to(dtype)
        sin = (emb.sin() * self.rope_scaling).to(dtype)
        return cos.unsqueeze(1), sin.unsqueeze(1)

    def forward(self, inputs_embeds, position_ids, cache_position, attn_mask, *kv_cache_flat):
        """
        Args:
            inputs_embeds: [B, seq_len, talker_hidden_size] — NOT YET projected
            position_ids: [B, seq_len]
            cache_position: [seq_len]
            attn_mask: [B, 1, seq_len, MAX_SEQ_LEN]
            *kv_cache_flat: 5 * 2 tensors, each [B, kv_heads, MAX_SEQ_LEN, head_dim]
        Returns:
            hidden_states: [B, seq_len, CP_HIDDEN_SIZE] — apply lm_head externally
            *updated_kv_cache
        """
        # Project from talker hidden → code predictor hidden
        hidden_states = self.small_to_mtp_projection(inputs_embeds)
        cos, sin = self._compute_rope(position_ids, hidden_states.device, hidden_states.dtype)
        updated_kv = []
        for i, layer in enumerate(self.layers):
            k_cache = kv_cache_flat[i * 2]
            v_cache = kv_cache_flat[i * 2 + 1]
            hidden_states, new_k, new_v = layer(
                hidden_states, cos, sin, cache_position, k_cache, v_cache, attn_mask
            )
            updated_kv.append(new_k)
            updated_kv.append(new_v)
        hidden_states = self.norm(hidden_states)
        return (hidden_states, *updated_kv)


print(" Constructing CodePredictorForExport...")
t0 = time.time()
export_cp = CodePredictorForExport(model.talker.code_predictor)
export_cp.eval()
print(f" Done in {time.time() - t0:.1f}s")

param_count = sum(p.numel() for p in export_cp.parameters())
print(f" Parameters: {param_count / 1e6:.1f}M")
# ── 3. Validate ──────────────────────────────────────────────────────
print("\n[3/5] Validating wrapper...")

# Prefill: 2 tokens (projected_talker_hidden + first_codec_embed)
seq_len = 2
test_embeds = torch.randn(BATCH_SIZE, seq_len, TALKER_HIDDEN_SIZE)
test_pos = torch.arange(seq_len).unsqueeze(0).expand(BATCH_SIZE, -1)
test_cache_pos = torch.arange(seq_len)

# Additive causal mask over the full static cache: row i may see slots 0..i.
visible = torch.arange(MAX_SEQ_LEN)[None, :] <= torch.arange(seq_len)[:, None]
causal_mask = torch.zeros(BATCH_SIZE, 1, seq_len, MAX_SEQ_LEN).masked_fill(
    ~visible, float('-inf')
)

# One (K, V) pair of zeroed static caches per layer, flattened.
kv_cache = [
    torch.zeros(BATCH_SIZE, CP_NUM_KV_HEADS, MAX_SEQ_LEN, CP_HEAD_DIM)
    for _ in range(2 * CP_NUM_LAYERS)
]

with torch.no_grad():
    outputs = export_cp(test_embeds, test_pos, test_cache_pos, causal_mask, *kv_cache)

hidden = outputs[0]
print(f" Hidden states shape: {list(hidden.shape)}")  # [1, 2, 1024]
assert hidden.shape == (BATCH_SIZE, seq_len, CP_HIDDEN_SIZE)

# Apply lm_head to get logits for the first prediction step
logits_0 = export_cp.lm_heads[0](hidden[:, -1:, :])
print(f" Logits[0] shape: {list(logits_0.shape)}")  # [1, 1, 2048]
assert logits_0.shape[-1] == CP_VOCAB_SIZE

# Decode step: a single new token written into cache slot `seq_len`,
# attending to slots 0..seq_len inclusive.
decode_embeds = torch.randn(BATCH_SIZE, 1, TALKER_HIDDEN_SIZE)
decode_pos = torch.tensor([[seq_len]])
decode_cache_pos = torch.tensor([seq_len])
decode_mask = torch.zeros(BATCH_SIZE, 1, 1, MAX_SEQ_LEN).masked_fill(
    torch.arange(MAX_SEQ_LEN) > seq_len, float('-inf')
)

updated_kv = list(outputs[1:])
with torch.no_grad():
    decode_out = export_cp(decode_embeds, decode_pos, decode_cache_pos, decode_mask, *updated_kv)
print(f" Decode hidden shape: {list(decode_out[0].shape)}")
print(" PASS — code predictor validated")

# ── 4. torch.export ──────────────────────────────────────────────────
print("\n[4/5] Running torch.export...")
t0 = time.time()

prefill_args = (test_embeds, test_pos, test_cache_pos, causal_mask, *kv_cache)
exported = None
try:
    exported = torch.export.export(export_cp, prefill_args, strict=False)
except Exception as e:
    print(f" torch.export FAILED: {e}")
else:
    print(f" torch.export succeeded in {time.time() - t0:.1f}s")
    print(f" Graph nodes: {len(exported.graph.nodes)}")

# ── 5. Lower to .pte ──────────────────────────────────────────────────
print("\n[5/5] Lowering to ExecuTorch .pte...")
t0 = time.time()

if exported is not None:
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        edge = to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()
        pte_path = os.path.join(OUTPUT_DIR, "code_predictor.pte")
        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)
        pte_size = os.path.getsize(pte_path) / 1e6
        print(f" .pte saved: {pte_path}")
        print(f" .pte size: {pte_size:.1f} MB")
        print(f" Lowered in {time.time() - t0:.1f}s")
    except Exception as e:
        # Lowering failed — keep the raw exported program as a fallback.
        print(f" ExecuTorch lowering failed: {e}")
        pt2_path = os.path.join(OUTPUT_DIR, "code_predictor.pt2")
        torch.export.save(exported, pt2_path)
        print(f" Saved: {pt2_path}")

# Also save the codec embeddings and lm_heads for the orchestration layer
torch.save({
    "codec_embeddings": [emb.state_dict() for emb in model.talker.code_predictor.model.codec_embedding],
    "lm_heads": [head.state_dict() for head in export_cp.lm_heads],
    "small_to_mtp_projection": export_cp.small_to_mtp_projection.state_dict(),
}, os.path.join(OUTPUT_DIR, "code_predictor_extras.pt"))
print(f" Saved codec embeddings + lm_heads: {OUTPUT_DIR}/code_predictor_extras.pt")

print("\n" + "=" * 70)
print("Phase 4 complete!")
print(f" Max seq len: {MAX_SEQ_LEN}")
print(f" Parameters: {param_count / 1e6:.1f}M")
print(f" Vocab: {CP_VOCAB_SIZE}, Code groups: {CP_NUM_CODE_GROUPS}")
print("=" * 70)