#!/usr/bin/env python3
"""
Phase 3: Export Talker (Main LM) to ExecuTorch .pte
=====================================================

The talker is a 1.7B Qwen3 decoder with:
  - 28 layers, hidden_size=2048, heads=16, kv_heads=8, head_dim=128
  - MROPE (3D rotary, sections [24,20,20], interleaved=True)
  - QK-norm (RMSNorm on Q and K per head)
  - Two embedding tables (text 151936 + codec 3072)
  - codec_head Linear(2048 → 3072)

We export the pure transformer backbone + codec_head as a standalone module.
The embedding interleaving and code_predictor calls happen in Python
orchestration.

Export format: prefill + decode share the same module with a static KV cache.
torch.export is called without ``dynamic_shapes``, so every exported program is
specialized to the shapes of its example inputs (seq_len=10 for the prefill
example, seq_len=1 for the decode example). The resulting .pte only accepts
inputs of exactly those shapes.
"""

import sys
import os
import copy
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── paths ────────────────────────────────────────────────────────────
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ── Configuration ────────────────────────────────────────────────────
MAX_SEQ_LEN = 2048
BATCH_SIZE = 1
NUM_LAYERS = 28
NUM_KV_HEADS = 8
HEAD_DIM = 128
NUM_HEADS = 16
HIDDEN_SIZE = 2048
INTERMEDIATE_SIZE = 6144
CODEC_VOCAB = 3072
MROPE_SECTIONS = [24, 20, 20]  # must sum to head_dim/2 = 64

print("=" * 70)
print("PHASE 3: Export Talker (Main LM) → .pte")
print("=" * 70)

# ── 1. Load Model ───────────────────────────────────────────────────
print("\n[1/6] Loading model...")

from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    config=config,
    dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="cpu",
)
model.eval()
print("  Model loaded.")

# ── 2. Build Export-Ready Talker ─────────────────────────────────────
print("\n[2/6] Building export-ready talker wrapper...")


def _rms_eps(norm_module, default=1e-6):
    """Return the epsilon of an upstream RMSNorm module.

    Checks ``variance_epsilon`` first, then ``eps``. Upstream Qwen-style norms
    presumably use one of these names; verify against the model source.
    Falls back to *default*, the value this script used to hard-code, when
    neither attribute is a float.
    """
    for attr in ("variance_epsilon", "eps"):
        value = getattr(norm_module, attr, None)
        if isinstance(value, float):
            return value
    return default


class RMSNorm(nn.Module):
    """Plain RMSNorm without the kernel_forward_from_hub decorator.

    Normalizes over the last dimension in float32, scales by a learned
    per-channel weight, and casts the result back to the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    @classmethod
    def from_module(cls, norm_module):
        """Build an RMSNorm that copies *norm_module*'s weight and epsilon."""
        weight = norm_module.weight
        norm = cls(weight.shape[0], eps=_rms_eps(norm_module))
        norm.weight = copy.deepcopy(weight)
        return norm

    def forward(self, x):
        dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(dtype)


def rotate_half(x):
    """Return cat(-x2, x1), where x1/x2 are the two halves of the last dim."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_mrope_interleaved(q, k, cos, sin, mrope_section):
    """Apply MROPE rotary embeddings to q and k (reference helper).

    The export path does not call this function; TalkerAttentionForExport
    applies the same rotation inline with a pre-collapsed cos/sin.

    Args:
        q, k: query/key tensors. Presumably [B, heads, seq_len, head_dim],
            so that cos[0].unsqueeze(1) broadcasts over the heads axis.
        cos, sin: [3, B, seq_len, head_dim], one slice per MROPE modality.
        mrope_section: half-head-dim split per modality, e.g. [24, 20, 20].
            Unused here; see the note below.

    Returns:
        (q_embed, k_embed) with the rotation applied.

    NOTE: Interleaved MROPE builds the effective cos/sin by taking frequency
    indices from different modalities. For TTS the three position-id rows
    are identical, so cos[0] == cos[1] == cos[2] (and likewise for sin). The
    interleave then collapses to using modality 0 alone. This helper is NOT
    correct if the three position rows differ.
    """
    cos_combined = cos[0].unsqueeze(1)  # [B, 1, seq_len, head_dim]
    sin_combined = sin[0].unsqueeze(1)  # [B, 1, seq_len, head_dim]
    q_embed = (q * cos_combined) + (rotate_half(q) * sin_combined)
    k_embed = (k * cos_combined) + (rotate_half(k) * sin_combined)
    return q_embed, k_embed


class TalkerAttentionForExport(nn.Module):
    """Single GQA attention layer with QK-norm, rotary and a static KV cache."""

    def __init__(self, original_attn, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = HEAD_DIM
        self.num_heads = NUM_HEADS
        self.num_kv_heads = NUM_KV_HEADS
        self.num_kv_groups = NUM_HEADS // NUM_KV_HEADS
        self.scaling = HEAD_DIM ** -0.5

        # Copy weight matrices
        self.q_proj = copy.deepcopy(original_attn.q_proj)
        self.k_proj = copy.deepcopy(original_attn.k_proj)
        self.v_proj = copy.deepcopy(original_attn.v_proj)
        self.o_proj = copy.deepcopy(original_attn.o_proj)

        # QK-norm (weight and epsilon taken from the upstream modules)
        self.q_norm = RMSNorm.from_module(original_attn.q_norm)
        self.k_norm = RMSNorm.from_module(original_attn.k_norm)

    def forward(self, hidden_states, cos, sin, cache_position, k_cache, v_cache, attn_mask):
        """
        Args:
            hidden_states: [B, seq_len, hidden_size]
            cos, sin: [B, 1, seq_len, head_dim] (already collapsed for MROPE)
            cache_position: [seq_len] — indices to write into the KV cache
            k_cache: [B, num_kv_heads, max_seq_len, head_dim]
            v_cache: [B, num_kv_heads, max_seq_len, head_dim]
            attn_mask: [B, 1, seq_len, max_seq_len] additive mask

        Returns:
            attn_output: [B, seq_len, hidden_size]
            k_cache, v_cache: updated copies of the caches (inputs untouched)
        """
        bsz, seq_len, _ = hidden_states.shape

        # Project Q, K, V; QK-norm is applied per head before the transpose.
        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        q = self.q_norm(q).transpose(1, 2)  # [B, heads, seq, hd]

        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim)
        k = self.k_norm(k).transpose(1, 2)  # [B, kv_heads, seq, hd]

        v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        q = (q * cos) + (rotate_half(q) * sin)
        k = (k * cos) + (rotate_half(k) * sin)

        # Update the static KV cache functionally: clone so the caller's
        # input tensors are never mutated (required for a clean export graph).
        k_cache = k_cache.clone()
        v_cache = v_cache.clone()
        k_cache[:, :, cache_position, :] = k
        v_cache[:, :, cache_position, :] = v

        # Expand KV for GQA (use repeat instead of expand to avoid stride 0)
        k_expanded = k_cache.unsqueeze(2).repeat(
            1, 1, self.num_kv_groups, 1, 1
        ).reshape(bsz, self.num_heads, MAX_SEQ_LEN, self.head_dim)
        v_expanded = v_cache.unsqueeze(2).repeat(
            1, 1, self.num_kv_groups, 1, 1
        ).reshape(bsz, self.num_heads, MAX_SEQ_LEN, self.head_dim)

        # Scaled dot product attention over the whole cache; attn_mask hides
        # unwritten / future slots.
        attn_output = F.scaled_dot_product_attention(
            q, k_expanded, v_expanded,
            attn_mask=attn_mask,
            scale=self.scaling,
        )

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, k_cache, v_cache


class TalkerMLP(nn.Module):
    """SwiGLU feed-forward: down(silu(gate(x)) * up(x))."""

    def __init__(self, original_mlp):
        super().__init__()
        self.gate_proj = copy.deepcopy(original_mlp.gate_proj)
        self.up_proj = copy.deepcopy(original_mlp.up_proj)
        self.down_proj = copy.deepcopy(original_mlp.down_proj)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class TalkerLayerForExport(nn.Module):
    """Pre-norm decoder layer: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, original_layer, layer_idx):
        super().__init__()
        self.attn = TalkerAttentionForExport(original_layer.self_attn, layer_idx)
        self.mlp = TalkerMLP(original_layer.mlp)
        self.input_norm = RMSNorm.from_module(original_layer.input_layernorm)
        self.post_attn_norm = RMSNorm.from_module(original_layer.post_attention_layernorm)

    def forward(self, hidden_states, cos, sin, cache_position, k_cache, v_cache, attn_mask):
        # Self attention
        residual = hidden_states
        hidden_states = self.input_norm(hidden_states)
        attn_out, k_cache, v_cache = self.attn(
            hidden_states, cos, sin, cache_position, k_cache, v_cache, attn_mask
        )
        hidden_states = residual + attn_out

        # MLP
        residual = hidden_states
        hidden_states = self.post_attn_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, k_cache, v_cache


class TalkerForExport(nn.Module):
    """
    Standalone talker backbone for ExecuTorch export.

    Takes pre-computed inputs_embeds (from the orchestration layer) and
    returns codec logits + updated KV cache. The MROPE, DynamicCache, and
    transformers utilities are replaced with simple, export-friendly
    equivalents.
    """

    def __init__(self, original_talker):
        super().__init__()

        # Transformer layers
        self.layers = nn.ModuleList(
            TalkerLayerForExport(layer, i)
            for i, layer in enumerate(original_talker.model.layers)
        )

        # Final norm
        self.norm = RMSNorm.from_module(original_talker.model.norm)

        # Codec head
        self.codec_head = copy.deepcopy(original_talker.codec_head)

        # Rotary embedding: reuse the upstream inv_freq buffer and scaling.
        orig_rope = original_talker.model.rotary_emb
        self.register_buffer("inv_freq", orig_rope.inv_freq.clone())
        self.rope_scaling = getattr(orig_rope, 'attention_scaling', 1.0)

    def _compute_rope(self, position_ids, device, dtype):
        """
        Compute MROPE cos/sin for the given position_ids.

        position_ids has shape [3, B, seq_len]. Only row 0 is used, which is
        correct only because the three rows are identical for TTS (see
        apply_mrope_interleaved).

        Returns:
            cos, sin: [B, 1, seq_len, head_dim] — ready for broadcasting
        """
        pos = position_ids[0].float()  # [B, seq_len]

        # inv_freq: [head_dim // 2]
        inv_freq = self.inv_freq.float().to(device)

        # [B, seq_len, 1] * [1, 1, head_dim//2] → [B, seq_len, head_dim//2]
        freqs = pos.unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)

        # emb: [B, seq_len, head_dim]
        emb = torch.cat([freqs, freqs], dim=-1)

        cos = (emb.cos() * self.rope_scaling).to(dtype)
        sin = (emb.sin() * self.rope_scaling).to(dtype)

        # Add head dimension for broadcasting: [B, 1, seq_len, head_dim]
        return cos.unsqueeze(1), sin.unsqueeze(1)

    def forward(self, inputs_embeds, position_ids, cache_position, attn_mask, *kv_cache_flat):
        """
        Args:
            inputs_embeds: [B, seq_len, 2048]
            position_ids: [3, B, seq_len] — MROPE positions (all 3 identical for TTS)
            cache_position: [seq_len] — indices into KV cache
            attn_mask: [B, 1, seq_len, MAX_SEQ_LEN] — additive causal mask
            *kv_cache_flat: 28 * 2 tensors, each [B, kv_heads, MAX_SEQ_LEN, head_dim]
                Ordered as: k0, v0, k1, v1, ..., k27, v27

        Returns:
            logits: [B, seq_len, 3072]
            *updated_kv_cache: same layout as input
        """
        cos, sin = self._compute_rope(position_ids, inputs_embeds.device, inputs_embeds.dtype)

        hidden_states = inputs_embeds
        updated_kv = []

        for i, layer in enumerate(self.layers):
            k_cache = kv_cache_flat[i * 2]
            v_cache = kv_cache_flat[i * 2 + 1]
            hidden_states, new_k, new_v = layer(
                hidden_states, cos, sin, cache_position, k_cache, v_cache, attn_mask
            )
            updated_kv.append(new_k)
            updated_kv.append(new_v)

        hidden_states = self.norm(hidden_states)
        logits = self.codec_head(hidden_states)

        return (logits, *updated_kv)


# Build the wrapper
print("  Constructing TalkerForExport...")
t0 = time.time()
export_talker = TalkerForExport(model.talker)
export_talker.eval()
print(f"  Done in {time.time() - t0:.1f}s")

param_count = sum(p.numel() for p in export_talker.parameters())
print(f"  Parameters: {param_count / 1e9:.2f}B")


# ── 3. Validate Wrapper ─────────────────────────────────────────────
def make_position_ids(q_len, start_pos=0):
    """MROPE position ids [3, BATCH_SIZE, q_len]; all three rows identical (TTS)."""
    positions = torch.arange(start_pos, start_pos + q_len)
    return positions.view(1, 1, q_len).repeat(3, BATCH_SIZE, 1)


def make_causal_mask(q_len, start_pos=0):
    """Additive mask [BATCH_SIZE, 1, q_len, MAX_SEQ_LEN].

    Query i sits at absolute position start_pos + i and may attend to cache
    slots 0..start_pos + i (value 0.0). Every other slot is -inf.
    """
    key_pos = torch.arange(MAX_SEQ_LEN)
    query_pos = torch.arange(start_pos, start_pos + q_len)
    allowed = key_pos.unsqueeze(0) <= query_pos.unsqueeze(1)  # [q_len, MAX_SEQ_LEN]
    mask = torch.full((q_len, MAX_SEQ_LEN), float('-inf')).masked_fill(allowed, 0.0)
    return mask.expand(BATCH_SIZE, 1, q_len, MAX_SEQ_LEN).clone()


def make_empty_kv_cache():
    """Zero-initialized flat KV cache ordered k0, v0, ..., k{L-1}, v{L-1}."""
    shape = (BATCH_SIZE, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM)
    return [torch.zeros(shape) for _ in range(NUM_LAYERS * 2)]


print("\n[3/6] Validating wrapper (shapes + prefill/decode consistency)...")

torch.manual_seed(0)  # reproducible test inputs

# Prefill inputs
seq_len = 10
test_embeds = torch.randn(BATCH_SIZE, seq_len, HIDDEN_SIZE)
test_position_ids = make_position_ids(seq_len)
test_cache_position = torch.arange(seq_len)
causal_mask = make_causal_mask(seq_len)
kv_cache = make_empty_kv_cache()

with torch.no_grad():
    outputs = export_talker(test_embeds, test_position_ids, test_cache_position, causal_mask, *kv_cache)

logits = outputs[0]
print(f"  Logits shape: {list(logits.shape)}")  # [1, 10, 3072]
assert logits.shape == (BATCH_SIZE, seq_len, CODEC_VOCAB), f"Unexpected shape: {logits.shape}"
print(f"  KV cache outputs: {len(outputs) - 1} tensors")
assert len(outputs) - 1 == NUM_LAYERS * 2, f"Expected {NUM_LAYERS * 2} KV tensors"
print("  PASS — shapes correct")

# Decode step (seq_len=1) continuing from the prefill cache
decode_embeds = torch.randn(BATCH_SIZE, 1, HIDDEN_SIZE)
decode_pos = make_position_ids(1, start_pos=seq_len)
decode_cache_pos = torch.tensor([seq_len])
decode_mask = make_causal_mask(1, start_pos=seq_len)  # attends to 0..seq_len

updated_kv = list(outputs[1:])

with torch.no_grad():
    decode_out = export_talker(decode_embeds, decode_pos, decode_cache_pos, decode_mask, *updated_kv)

decode_logits = decode_out[0]
print(f"  Decode logits shape: {list(decode_logits.shape)}")  # [1, 1, 3072]
assert decode_logits.shape == (BATCH_SIZE, 1, CODEC_VOCAB)
print("  PASS — decode step works")

# Numeric self-consistency of the KV-cache path: one prefill over all
# seq_len + 1 tokens must give the same last-token logits as a prefill of
# seq_len tokens followed by one decode step. The wrapper never mutates its
# input caches (it clones them), so `kv_cache` is still all zeros here.
full_len = seq_len + 1
with torch.no_grad():
    full_out = export_talker(
        torch.cat([test_embeds, decode_embeds], dim=1),
        make_position_ids(full_len),
        torch.arange(full_len),
        make_causal_mask(full_len),
        *kv_cache,
    )
ref_last_logits = full_out[0][:, -1:]
max_diff = (ref_last_logits - decode_logits).abs().max().item()
print(f"  Prefill({full_len}) vs prefill({seq_len})+decode(1) max |diff|: {max_diff:.2e}")
if torch.allclose(ref_last_logits, decode_logits, rtol=1e-3, atol=1e-3):
    print("  PASS — KV cache path matches a single prefill")
else:
    print("  WARN — KV cache path diverges from a single prefill; check cache update / masking")
del full_out, ref_last_logits

# ── 4. torch.export (prefill) ───────────────────────────────────────
print(f"\n[4/6] Running torch.export (prefill, seq_len={seq_len})...")
# No dynamic_shapes are passed, so the exported program is specialized to
# the example shapes below (seq_len tokens).

exported_prefill = None
exported_decode = None

prefill_args = (
    test_embeds,
    test_position_ids,
    test_cache_position,
    causal_mask,
    *kv_cache,
)

t0 = time.time()
try:
    exported_prefill = torch.export.export(
        export_talker,
        prefill_args,
        strict=False,
    )
    print(f"  torch.export (prefill) succeeded in {time.time() - t0:.1f}s")
    print(f"  Graph has {len(exported_prefill.graph.nodes)} nodes")
    print(f"  NOTE: program is specialized to seq_len={seq_len} (static shapes)")
except Exception as e:
    print(f"  torch.export (prefill) FAILED: {e}")
    print("  This is expected for large models. Saving state_dict instead.")
    torch.save(export_talker.state_dict(), os.path.join(OUTPUT_DIR, "talker_state_dict.pt"))
    print(f"  State dict saved to {OUTPUT_DIR}/talker_state_dict.pt")

    # Fallback: export the single-token decode step.
    print("\n  Trying torch.export with decode step (seq_len=1)...")
    decode_args = (
        decode_embeds,
        decode_pos,
        decode_cache_pos,
        decode_mask,
        *updated_kv,
    )
    t0 = time.time()  # time the decode attempt on its own
    try:
        exported_decode = torch.export.export(
            export_talker,
            decode_args,
            strict=False,
        )
        print(f"  torch.export (decode) succeeded in {time.time() - t0:.1f}s")
    except Exception as e2:
        print(f"  torch.export (decode) also FAILED: {e2}")

# ── 5. Lower to ExecuTorch .pte ─────────────────────────────────────
print("\n[5/6] Lowering to ExecuTorch .pte...")
t0 = time.time()

# Lower whichever export succeeded (prefill preferred).
if exported_prefill is not None:
    ep, label = exported_prefill, "prefill"
elif exported_decode is not None:
    ep, label = exported_decode, "decode"
else:
    ep, label = None, None
    print("  No exported program available. Skipping .pte generation.")
    print("  State dict saved — can be loaded for on-device inference.")

if ep is not None:
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        edge = to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()

        pte_path = os.path.join(OUTPUT_DIR, f"talker_{label}.pte")
        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)

        pte_size = os.path.getsize(pte_path) / 1e6
        print(f"  .pte saved: {pte_path}")
        print(f"  .pte size: {pte_size:.1f} MB")
        print(f"  Lowered in {time.time() - t0:.1f}s")
    except Exception as e:
        print(f"  ExecuTorch lowering failed: {e}")
        pt2_path = os.path.join(OUTPUT_DIR, f"talker_{label}.pt2")
        # Best effort: keep going to step 6 even if this save fails too.
        try:
            torch.export.save(ep, pt2_path)
            print(f"  Saved exported program: {pt2_path}")
        except Exception as save_err:
            print(f"  Saving exported program also failed: {save_err}")

# ── 6. Save embeddings + summary ────────────────────────────────────
print("\n[6/6] Also saving embedding tables and text_projection for orchestration...")

# Save embedding tables separately — these are needed by the Python
# orchestration layer that constructs inputs_embeds.
torch.save({
    "text_embedding": model.talker.model.text_embedding.state_dict(),
    "codec_embedding": model.talker.model.codec_embedding.state_dict(),
    "text_projection": model.talker.text_projection.state_dict(),
}, os.path.join(OUTPUT_DIR, "talker_embeddings.pt"))
print(f"  Saved: {OUTPUT_DIR}/talker_embeddings.pt")

kv_bytes_fp32 = 2 * NUM_LAYERS * BATCH_SIZE * NUM_KV_HEADS * MAX_SEQ_LEN * HEAD_DIM * 4

print("\n" + "=" * 70)
print("Phase 3 complete!")
print(f"  Max seq len: {MAX_SEQ_LEN}")
print(f"  KV cache per layer: 2 × [{BATCH_SIZE}, {NUM_KV_HEADS}, {MAX_SEQ_LEN}, {HEAD_DIM}]")
print(f"  Total KV cache (fp32): {kv_bytes_fp32 / 1e6:.0f} MB")
print(f"  Codec vocab: {CODEC_VOCAB}")
print("=" * 70)