| |
| """ |
| Phase 3: Export Talker (Main LM) to ExecuTorch .pte |
| ===================================================== |
| The talker is a 1.7B Qwen3 decoder with: |
| - 28 layers, hidden_size=2048, heads=16, kv_heads=8, head_dim=128 |
| - MROPE (3D rotary, sections [24,20,20], interleaved=True) |
| - QK-norm (RMSNorm on Q and K per head) |
| - Two embedding tables (text 151936 + codec 3072) |
| - codec_head Linear(2048 → 3072) |
| |
| We export the pure transformer backbone + codec_head as a standalone module. |
| The embedding interleaving and code_predictor calls happen in Python orchestration. |
| |
| Export format: prefill + decode share the same module with static KV cache. |
| """ |
|
|
| import sys |
| import os |
| import copy |
| import time |
| import math |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
# --- Filesystem locations ---------------------------------------------------
# Checkpoint directory, the virtualenv site-packages that holds the model's
# dependencies, the qwen_tts source checkout, and the export output folder.
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Put both locations at the front of sys.path (if not already present) so the
# `qwen_tts` imports further down resolve. Each insert goes to index 0, so
# QWEN_TTS_SRC ends up ahead of VENV_SITE.
for _extra_path in (VENV_SITE, QWEN_TTS_SRC):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Model / export dimensions ----------------------------------------------
MAX_SEQ_LEN = 2048          # length of the static KV cache
BATCH_SIZE = 1
NUM_LAYERS = 28
NUM_KV_HEADS = 8
HEAD_DIM = 128
NUM_HEADS = 16
HIDDEN_SIZE = 2048
INTERMEDIATE_SIZE = 6144
CODEC_VOCAB = 3072
MROPE_SECTIONS = [24, 20, 20]

print("=" * 70)
print("PHASE 3: Export Talker (Main LM) → .pte")
print("=" * 70)
|
|
| |
|
|
print("\n[1/6] Loading model...")

# These imports are deliberately placed here rather than at the top of the
# file: they only resolve once sys.path has been extended with the qwen_tts
# checkout and the virtualenv site-packages.
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)

# Load in fp32 on CPU with the SDPA attention implementation.
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    config=config,
    dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="cpu",
)
model.eval()
print(" Model loaded.")

print("\n[2/6] Building export-ready talker wrapper...")
|
|
|
|
class RMSNorm(nn.Module):
    """Plain RMS normalisation layer (no kernel_forward_from_hub decorator).

    Normalises the last dimension by its root-mean-square and multiplies by a
    learned per-feature ``weight``. The computation runs in float32 and the
    result is cast back to the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        input_dtype = x.dtype
        x32 = x.float()
        # Mean of squares over the feature axis; eps keeps rsqrt finite.
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        normalised = x32 * torch.rsqrt(mean_sq + self.eps)
        scaled = self.weight * normalised
        return scaled.to(input_dtype)
|
|
|
|
def rotate_half(x):
    """Return ``x`` with the halves of its last dimension swapped and the
    (originally) second half negated: ``[a, b] -> [-b, a]``."""
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
def apply_mrope_interleaved(q, k, cos, sin, mrope_section):
    """Apply rotary position embedding to ``q`` and ``k`` using MROPE axis 0.

    Despite the name, no per-section interleaving is performed: only the
    first of the three MROPE axes (``cos[0]`` / ``sin[0]``) is applied to the
    whole head dimension.
    NOTE(review): this equals full interleaved MROPE only if all three
    position axes carry the same positions — confirm for every caller.

    Args:
        q, k: query / key tensors; presumably [B, heads, seq_len, head_dim]
            so that they broadcast against the unsqueezed cos/sin — confirm.
        cos, sin: stacked per-axis tables, documented as
            [3, B, seq_len, head_dim]; only index 0 is read.
        mrope_section: accepted for interface compatibility; not read.

    Returns:
        (q_embed, k_embed): rotated tensors with the shapes of ``q`` and ``k``.
    """
    # unsqueeze(1) inserts a singleton axis (the head axis under the shapes
    # assumed above) so one table broadcasts across all heads.
    cos_combined = cos[0].unsqueeze(1)
    sin_combined = sin[0].unsqueeze(1)

    q_embed = (q * cos_combined) + (rotate_half(q) * sin_combined)
    k_embed = (k * cos_combined) + (rotate_half(k) * sin_combined)
    return q_embed, k_embed
|
|
|
|
class TalkerAttentionForExport(nn.Module):
    """Single attention layer with a static KV cache and pre-computed rotary tables.

    Projection layers and QK-norm weights are deep-copied from the original
    attention module. The cache is handled functionally: ``forward`` clones
    the incoming caches, writes the new keys/values at ``cache_position`` and
    returns the updated tensors.
    """

    def __init__(self, original_attn, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = HEAD_DIM
        self.num_heads = NUM_HEADS
        self.num_kv_heads = NUM_KV_HEADS
        # Number of query heads that share one KV head (grouped-query attention).
        self.num_kv_groups = NUM_HEADS // NUM_KV_HEADS
        self.scaling = HEAD_DIM ** -0.5

        # Own copies of the linear projections.
        self.q_proj = copy.deepcopy(original_attn.q_proj)
        self.k_proj = copy.deepcopy(original_attn.k_proj)
        self.v_proj = copy.deepcopy(original_attn.v_proj)
        self.o_proj = copy.deepcopy(original_attn.o_proj)

        # Per-head RMSNorm on Q and K, re-created with the local RMSNorm and
        # the original weights.
        self.q_norm = RMSNorm(HEAD_DIM, eps=1e-6)
        self.q_norm.weight = copy.deepcopy(original_attn.q_norm.weight)
        self.k_norm = RMSNorm(HEAD_DIM, eps=1e-6)
        self.k_norm.weight = copy.deepcopy(original_attn.k_norm.weight)

    def _expand_kv(self, cache):
        """Repeat every KV head ``num_kv_groups`` times to match the query heads.

        The sequence length is read from ``cache`` itself, so the layer works
        with whatever cache length the caller allocated.
        """
        batch, _, cache_len, _ = cache.shape
        return cache.unsqueeze(2).repeat(
            1, 1, self.num_kv_groups, 1, 1
        ).reshape(batch, self.num_heads, cache_len, self.head_dim)

    def forward(self, hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask):
        """
        Args:
            hidden_states: [B, seq_len, hidden_size]
            cos, sin: [B, 1, seq_len, head_dim] (already processed for MROPE)
            cache_position: [seq_len] — indices to write into KV cache
            k_cache: [B, num_kv_heads, cache_len, head_dim]
            v_cache: [B, num_kv_heads, cache_len, head_dim]
            attn_mask: [B, 1, seq_len, cache_len]

        Returns:
            attn_output: [B, seq_len, hidden_size]
            k_cache, v_cache: updated caches (new tensors; inputs untouched)
        """
        bsz, seq_len, _ = hidden_states.shape

        # Project, split into heads, normalise per head, move heads to dim 1.
        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        q = self.q_norm(q).transpose(1, 2)

        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim)
        k = self.k_norm(k).transpose(1, 2)

        v = self.v_proj(hidden_states).view(
            bsz, seq_len, self.num_kv_heads, self.head_dim
        ).transpose(1, 2)

        # Rotary embedding; cos/sin broadcast over the head axis.
        q = (q * cos) + (rotate_half(q) * sin)
        k = (k * cos) + (rotate_half(k) * sin)

        # Clone before the indexed write so the caller's tensors are not mutated.
        k_cache = k_cache.clone()
        v_cache = v_cache.clone()
        k_cache[:, :, cache_position, :] = k
        v_cache[:, :, cache_position, :] = v

        k_expanded = self._expand_kv(k_cache)
        v_expanded = self._expand_kv(v_cache)

        # Attention runs over the full cache; attn_mask selects which of the
        # cache slots each query position may attend to.
        attn_output = F.scaled_dot_product_attention(
            q, k_expanded, v_expanded,
            attn_mask=attn_mask,
            scale=self.scaling,
        )

        # Merge heads back and apply the output projection.
        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, k_cache, v_cache
|
|
|
|
class TalkerMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    The three projections are deep-copied from ``original_mlp`` so this
    module holds its own weights.
    """

    def __init__(self, original_mlp):
        super().__init__()
        self.gate_proj = copy.deepcopy(original_mlp.gate_proj)
        self.up_proj = copy.deepcopy(original_mlp.up_proj)
        self.down_proj = copy.deepcopy(original_mlp.down_proj)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
|
|
|
|
class TalkerLayerForExport(nn.Module):
    """One decoder layer: pre-norm attention and pre-norm MLP, each added
    back to its input through a residual connection.

    Sub-modules are rebuilt from ``original_layer`` using the export-friendly
    attention, MLP and RMSNorm classes defined in this file.
    """

    def __init__(self, original_layer, layer_idx):
        super().__init__()
        self.attn = TalkerAttentionForExport(original_layer.self_attn, layer_idx)
        self.mlp = TalkerMLP(original_layer.mlp)

        self.input_norm = RMSNorm(HIDDEN_SIZE, eps=1e-6)
        self.input_norm.weight = copy.deepcopy(original_layer.input_layernorm.weight)

        self.post_attn_norm = RMSNorm(HIDDEN_SIZE, eps=1e-6)
        self.post_attn_norm.weight = copy.deepcopy(
            original_layer.post_attention_layernorm.weight
        )

    def forward(self, hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask):
        """Run the layer; returns ``(hidden_states, k_cache, v_cache)`` with
        the caches as returned by the attention sub-module."""
        # Attention sub-block.
        normed_input = self.input_norm(hidden_states)
        attn_out, k_cache, v_cache = self.attn(
            normed_input, cos, sin, cache_position,
            k_cache, v_cache, attn_mask
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.post_attn_norm(after_attn))
        layer_out = after_attn + mlp_out

        return layer_out, k_cache, v_cache
|
|
|
|
class TalkerForExport(nn.Module):
    """
    Standalone talker backbone for ExecuTorch export.

    Takes pre-computed inputs_embeds (from the orchestration layer) and
    returns codec logits + updated KV cache.

    The MROPE, DynamicCache, and transformers utilities are replaced with
    simple, export-friendly equivalents.
    """

    def __init__(self, original_talker):
        super().__init__()

        # Decoder stack, rebuilt layer by layer from the original model.
        self.layers = nn.ModuleList()
        for i, layer in enumerate(original_talker.model.layers):
            self.layers.append(TalkerLayerForExport(layer, i))

        # Final norm, using the local RMSNorm with the original weight.
        self.norm = RMSNorm(HIDDEN_SIZE, eps=1e-6)
        self.norm.weight = copy.deepcopy(original_talker.model.norm.weight)

        # Output head producing codec logits.
        self.codec_head = copy.deepcopy(original_talker.codec_head)

        # Rotary parameters: inverse frequencies are kept as a buffer, and the
        # cos/sin scale falls back to 1.0 if the original has no
        # `attention_scaling` attribute.
        orig_rope = original_talker.model.rotary_emb
        self.register_buffer("inv_freq", orig_rope.inv_freq.clone())
        self.rope_scaling = getattr(orig_rope, 'attention_scaling', 1.0)

    def _compute_rope(self, position_ids, device, dtype):
        """
        Compute MROPE cos/sin for the given position_ids.

        For TTS, position_ids shape is [3, B, seq_len] but all 3 dims are
        identical, so we just use dim 0.

        Returns:
            cos, sin: [B, 1, seq_len, head_dim] — ready for broadcasting
        """
        # Only the first MROPE axis is read (see docstring).
        pos = position_ids[0].float()

        inv_freq = self.inv_freq.float().to(device)

        # Outer product of positions and inverse frequencies:
        # [B, seq_len, 1] * [1, 1, n_freq] -> [B, seq_len, n_freq]
        freqs = pos.unsqueeze(-1) * inv_freq[None, None, :]

        # Duplicate so the table spans both halves used by rotate_half.
        emb = torch.cat([freqs, freqs], dim=-1)

        cos = (emb.cos() * self.rope_scaling).to(dtype)
        sin = (emb.sin() * self.rope_scaling).to(dtype)

        # Insert the head axis for broadcasting.
        return cos.unsqueeze(1), sin.unsqueeze(1)

    def forward(self, inputs_embeds, position_ids, cache_position, attn_mask,
                *kv_cache_flat):
        """
        Args:
            inputs_embeds: [B, seq_len, 2048]
            position_ids: [3, B, seq_len] — MROPE positions (all 3 identical for TTS)
            cache_position: [seq_len] — indices into KV cache
            attn_mask: [B, 1, seq_len, MAX_SEQ_LEN] — causal attention mask
            *kv_cache_flat: 28 * 2 tensors, each [B, kv_heads, MAX_SEQ_LEN, head_dim]
                Ordered as: k0, v0, k1, v1, ..., k27, v27

        Returns:
            logits: [B, seq_len, 3072]
            *updated_kv_cache: same layout as input

        Raises:
            ValueError: if the number of cache tensors is not two per layer.
        """
        # Fail early with a clear message instead of an IndexError (too few
        # tensors) or silently ignoring extras (too many).
        expected_kv = 2 * len(self.layers)
        if len(kv_cache_flat) != expected_kv:
            raise ValueError(
                f"Expected {expected_kv} KV cache tensors (k, v per layer), "
                f"got {len(kv_cache_flat)}"
            )

        cos, sin = self._compute_rope(position_ids, inputs_embeds.device, inputs_embeds.dtype)

        hidden_states = inputs_embeds
        updated_kv = []

        # Even slots hold keys, odd slots hold values.
        for layer, k_cache, v_cache in zip(
            self.layers, kv_cache_flat[0::2], kv_cache_flat[1::2]
        ):
            hidden_states, new_k, new_v = layer(
                hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask
            )
            updated_kv.append(new_k)
            updated_kv.append(new_v)

        hidden_states = self.norm(hidden_states)
        logits = self.codec_head(hidden_states)

        return (logits, *updated_kv)
|
|
|
|
| |
# Build the export wrapper from the loaded model's talker and time the
# construction (which deep-copies the weights).
print(" Constructing TalkerForExport...")
t0 = time.time()
export_talker = TalkerForExport(model.talker).eval()
print(f" Done in {time.time() - t0:.1f}s")

# Total number of parameter elements held by the wrapper.
param_count = sum(map(torch.numel, export_talker.parameters()))
print(f" Parameters: {param_count / 1e9:.2f}B")
|
|
| |
|
|
print("\n[3/6] Validating wrapper vs original (single forward pass)...")

# NOTE: this stage is a shape/smoke test of the wrapper only; no output of
# the original model is computed or compared here.

# --- Prefill: 10 random embedding vectors at cache positions 0..9 -----------
seq_len = 10
test_embeds = torch.randn(BATCH_SIZE, seq_len, HIDDEN_SIZE)
test_position_ids = torch.arange(seq_len).unsqueeze(0).unsqueeze(0).repeat(3, BATCH_SIZE, 1)
test_cache_position = torch.arange(seq_len)

# Additive causal mask over the full static cache: row i may see columns
# 0..i (value 0.0); everything else, including unused cache slots, is -inf.
causal_mask = torch.full((BATCH_SIZE, 1, seq_len, MAX_SEQ_LEN), float('-inf'))
causal_mask[:, :, :, :seq_len] = torch.triu(
    torch.full((seq_len, seq_len), float('-inf')), diagonal=1
)

# Zero-initialised caches, ordered k0, v0, k1, v1, ...
kv_cache = [
    torch.zeros(BATCH_SIZE, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM)
    for _ in range(NUM_LAYERS * 2)
]

with torch.no_grad():
    outputs = export_talker(test_embeds, test_position_ids, test_cache_position,
                            causal_mask, *kv_cache)

logits = outputs[0]
print(f" Logits shape: {list(logits.shape)}")
assert logits.shape == (BATCH_SIZE, seq_len, CODEC_VOCAB), f"Unexpected shape: {logits.shape}"
print(f" KV cache outputs: {len(outputs) - 1} tensors")
assert len(outputs) - 1 == NUM_LAYERS * 2, f"Expected {NUM_LAYERS * 2} KV tensors"
print(" PASS — shapes correct")

# --- Decode: one new token at position `seq_len`, reusing the prefill cache --
decode_embeds = torch.randn(BATCH_SIZE, 1, HIDDEN_SIZE)
decode_pos = torch.tensor([[[seq_len]]]).repeat(3, BATCH_SIZE, 1)
decode_cache_pos = torch.tensor([seq_len])

# The single query may attend to every filled slot, including its own.
decode_mask = torch.full((BATCH_SIZE, 1, 1, MAX_SEQ_LEN), float('-inf'))
decode_mask[:, :, :, :seq_len + 1] = 0.0

updated_kv = list(outputs[1:])
with torch.no_grad():
    decode_out = export_talker(decode_embeds, decode_pos, decode_cache_pos,
                               decode_mask, *updated_kv)

decode_logits = decode_out[0]
print(f" Decode logits shape: {list(decode_logits.shape)}")
assert decode_logits.shape == (BATCH_SIZE, 1, CODEC_VOCAB)
print(" PASS — decode step works")
|
|
| |
|
|
print("\n[4/6] Running torch.export (prefill, seq_len=10)...")
t0 = time.time()

# Both results are always bound, so later stages can test them directly
# instead of probing whether the name exists.
exported_prefill = None
exported_decode = None

# Example inputs that fix the traced shapes for the prefill program.
prefill_args = (
    test_embeds,
    test_position_ids,
    test_cache_position,
    causal_mask,
    *kv_cache,
)

try:
    exported_prefill = torch.export.export(
        export_talker,
        prefill_args,
        strict=False,
    )
except Exception as e:
    # Best-effort: keep the weights on disk and fall through to the decode
    # export below.
    print(f" torch.export (prefill) FAILED: {e}")
    print(" This is expected for large models. Saving state_dict instead.")
    torch.save(export_talker.state_dict(), os.path.join(OUTPUT_DIR, "talker_state_dict.pt"))
    print(f" State dict saved to {OUTPUT_DIR}/talker_state_dict.pt")
else:
    print(f" torch.export (prefill) succeeded in {time.time() - t0:.1f}s")
    print(f" Graph has {len(exported_prefill.graph.nodes)} nodes")

if exported_prefill is None:
    # Fallback: export the single-token decode shape instead.
    print("\n Trying torch.export with decode step (seq_len=1)...")
    decode_args = (
        decode_embeds,
        decode_pos,
        decode_cache_pos,
        decode_mask,
        *updated_kv,
    )
    # Restart the timer so the reported time covers only the decode export,
    # not the failed prefill attempt.
    t0 = time.time()
    try:
        exported_decode = torch.export.export(
            export_talker,
            decode_args,
            strict=False,
        )
    except Exception as e2:
        print(f" torch.export (decode) also FAILED: {e2}")
    else:
        print(f" torch.export (decode) succeeded in {time.time() - t0:.1f}s")
|
|
| |
|
|
print("\n[5/6] Lowering to ExecuTorch .pte...")
t0 = time.time()

# Prefer the prefill program and fall back to decode. `exported_decode` is
# only read when prefill is unavailable; on that path the export stage has
# always assigned it, so no name-existence probe is needed.
ep = exported_prefill
label = "prefill"

if ep is None:
    if exported_decode is not None:
        ep = exported_decode
        label = "decode"
    else:
        print(" No exported program available. Skipping .pte generation.")
        print(" State dict saved — can be loaded for on-device inference.")

if ep is not None:
    try:
        # Imported lazily so a missing ExecuTorch install is handled by the
        # fallback below rather than aborting the script.
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        edge = to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )

        et_program = edge.to_executorch()

        pte_path = os.path.join(OUTPUT_DIR, f"talker_{label}.pte")
        with open(pte_path, "wb") as f:
            # Stream the program to disk; `.buffer` would first build the
            # whole serialized model as one contiguous in-memory block.
            et_program.write_to_file(f)

        pte_size = os.path.getsize(pte_path) / 1e6
        print(f" .pte saved: {pte_path}")
        print(f" .pte size: {pte_size:.1f} MB")
        print(f" Lowered in {time.time() - t0:.1f}s")

    except Exception as e:
        # Best-effort fallback: keep the torch.export program so the
        # lowering can be retried later without re-exporting.
        print(f" ExecuTorch lowering failed: {e}")
        pt2_path = os.path.join(OUTPUT_DIR, f"talker_{label}.pt2")
        torch.export.save(ep, pt2_path)
        print(f" Saved exported program: {pt2_path}")
|
|
| |
|
|
print("\n[6/6] Also saving embedding tables and text_projection for orchestration...")

# These modules are not part of the exported backbone, so their weights are
# saved separately for the orchestration code to load.
torch.save({
    "text_embedding": model.talker.model.text_embedding.state_dict(),
    "codec_embedding": model.talker.model.codec_embedding.state_dict(),
    "text_projection": model.talker.text_projection.state_dict(),
}, os.path.join(OUTPUT_DIR, "talker_embeddings.pt"))
print(f" Saved: {OUTPUT_DIR}/talker_embeddings.pt")

# Final summary. The KV-cache figure is k + v for every layer at 4 bytes per
# fp32 element, and uses BATCH_SIZE so it stays correct if that constant changes.
print("\n" + "=" * 70)
print("Phase 3 complete!")
print(f" Max seq len: {MAX_SEQ_LEN}")
print(f" KV cache per layer: 2 × [{BATCH_SIZE}, {NUM_KV_HEADS}, {MAX_SEQ_LEN}, {HEAD_DIM}]")
print(f" Total KV cache (fp32): {2 * NUM_LAYERS * BATCH_SIZE * NUM_KV_HEADS * MAX_SEQ_LEN * HEAD_DIM * 4 / 1e6:.0f} MB")
print(f" Codec vocab: {CODEC_VOCAB}")
print("=" * 70)
|
|