Qwen3-TTS-1.7B-Base-ExecuTorch / scripts /export_code_predictor.py
acul3's picture
Upload scripts/export_code_predictor.py with huggingface_hub
4005d54 verified
#!/usr/bin/env python3
"""
Phase 4: Export Code Predictor to ExecuTorch .pte
==================================================
The code predictor is a smaller 5-layer transformer (175M params) that
takes the talker's hidden state + first codebook token and autoregressively
generates the remaining 15 codebook tokens.
Architecture:
- hidden_size=1024, 5 layers, 16 heads, 8 kv_heads, head_dim=128
- small_to_mtp_projection: Linear(2048→1024) — projects talker hidden (2048) → predictor hidden (1024)
- 15 lm_heads: Linear(1024→2048) each (one per predicted code group)
- 15 codec_embeddings: Embedding(2048, 2048) each
During inference (called once per talker decode step):
Step 0 (prefill): concat(talker_hidden, codec_embed_0(first_token)) → 2 tokens, both passed through small_to_mtp_projection
Steps 1-14: predict next code group token → embed it → feed back
We export this as a static-KV-cache transformer similar to the talker.
"""
import copy
import os
import sys
import time
import traceback

import torch
import torch.nn as nn
import torch.nn.functional as F
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Make the local virtualenv's site-packages and the qwen_tts source tree
# importable. Each path is prepended only when missing. VENV_SITE is handled
# first, so QWEN_TTS_SRC ends up ahead of it at the front of sys.path.
for _extra_path in (VENV_SITE, QWEN_TTS_SRC):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
del _extra_path

# Destination for the .pte / .pt2 / extras artifacts written below.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ── Configuration ────────────────────────────────────────────────────
# Export-time shapes.
BATCH_SIZE = 1
# Static KV-cache length; the original sizing note reads "prefill=2, then 15
# decode steps". NOTE(review): the module docstring describes decode steps
# 1-14, which would need only 2 + 14 = 16 slots, so the extra slot looks like
# headroom. Changing it changes the exported cache shape; confirm first.
MAX_SEQ_LEN = 17

# Code predictor transformer geometry.
CP_NUM_LAYERS = 5
CP_HIDDEN_SIZE = 1024
CP_INTERMEDIATE_SIZE = 3072
CP_NUM_HEADS = 16
CP_NUM_KV_HEADS = 8   # grouped-query attention: 16 / 8 = 2 query heads per KV head
CP_HEAD_DIM = 128

# Codebook / vocabulary.
CP_VOCAB_SIZE = 2048
CP_NUM_CODE_GROUPS = 16  # total groups (predict 15, first comes from talker)

# Width of the talker-space inputs fed into small_to_mtp_projection.
TALKER_HIDDEN_SIZE = 2048
print("=" * 70)
# The arrow was previously mis-encoded UTF-8 ("β†'"), which printed garbage.
print("PHASE 4: Export Code Predictor → .pte")
print("=" * 70)

# ── 1. Load Model ───────────────────────────────────────────────────
print("\n[1/5] Loading model...")
# Project-local imports: qwen_tts is expected to be importable from
# QWEN_TTS_SRC, which was prepended to sys.path above.
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
# Load in float32 on CPU with SDPA attention; the wrapper below copies
# weights out of this model, so its dtype carries into the export.
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH, config=config, dtype=torch.float32,
    attn_implementation="sdpa", device_map="cpu",
)
model.eval()
print(" Model loaded.")

# ── 2. Build Export-Ready Code Predictor ─────────────────────────────
print("\n[2/5] Building export-ready code predictor wrapper...")
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-channel gain.

    Statistics are computed in float32 regardless of the input dtype, and the
    result is cast back to the input dtype.

    Args:
        dim: size of the last (normalised) dimension; length of ``weight``.
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        in_dtype = x.dtype
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = x32 * inv_rms
        return (self.weight * normed).to(in_dtype)
def rotate_half(x):
    """Rotate the last dimension by halves, ``[a, b] -> [-b, a]`` (RoPE helper).

    The split point is ``size // 2``, so for an odd-sized last dimension the
    leading part is the shorter one.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat([back.neg(), front], dim=-1)
class CPAttentionForExport(nn.Module):
    """Code predictor attention layer with static KV cache.

    The q/k/v/o projections and the per-head q/k RMSNorm weights are
    deep-copied from ``original_attn``. The KV cache is passed in and returned
    as plain tensors, so the layer can be traced by torch.export without
    hidden state.
    """

    def __init__(self, original_attn, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = CP_HEAD_DIM
        self.num_heads = CP_NUM_HEADS
        self.num_kv_heads = CP_NUM_KV_HEADS
        # Grouped-query attention: each KV head serves this many query heads.
        self.num_kv_groups = CP_NUM_HEADS // CP_NUM_KV_HEADS
        self.scaling = CP_HEAD_DIM ** -0.5

        self.q_proj = copy.deepcopy(original_attn.q_proj)
        self.k_proj = copy.deepcopy(original_attn.k_proj)
        self.v_proj = copy.deepcopy(original_attn.v_proj)
        self.o_proj = copy.deepcopy(original_attn.o_proj)

        # Reuse the source norms' epsilon when they expose one. HF-style RMSNorm
        # stores it as ``variance_epsilon``; otherwise keep the 1e-6 default.
        self.q_norm = RMSNorm(
            CP_HEAD_DIM, eps=getattr(original_attn.q_norm, "variance_epsilon", 1e-6)
        )
        self.q_norm.weight = copy.deepcopy(original_attn.q_norm.weight)
        self.k_norm = RMSNorm(
            CP_HEAD_DIM, eps=getattr(original_attn.k_norm, "variance_epsilon", 1e-6)
        )
        self.k_norm.weight = copy.deepcopy(original_attn.k_norm.weight)

    def _repeat_kv(self, cache):
        """Expand a [B, num_kv_heads, L, head_dim] cache to [B, num_heads, L, head_dim].

        L is taken from the cache itself rather than the global MAX_SEQ_LEN,
        so the layer works for any static cache length.
        """
        bsz, _, cache_len, _ = cache.shape
        return cache.unsqueeze(2).repeat(
            1, 1, self.num_kv_groups, 1, 1
        ).reshape(bsz, self.num_heads, cache_len, self.head_dim)

    def forward(self, hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask):
        """Attend from this call's tokens over the updated KV cache.

        Args:
            hidden_states: [B, S, in_features] for the S tokens of this call.
            cos, sin: RoPE tables; they must broadcast against q/k of shape
                [B, heads, S, head_dim] (TODO confirm layout with the caller).
            cache_position: length-S index into the cache's sequence dim (dim 2);
                the slots that receive this call's keys/values.
            k_cache, v_cache: [B, num_kv_heads, cache_len, head_dim].
            attn_mask: passed straight to SDPA. It must broadcast to
                [B, num_heads, S, cache_len] and should mask slots that are
                not yet valid for each query.

        Returns:
            (attn_output [B, S, o_proj output width], new k_cache, new v_cache)
        """
        bsz, seq_len, _ = hidden_states.shape

        # Per-head RMSNorm on q and k is applied before the rotary embedding.
        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        q = self.q_norm(q).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim)
        k = self.k_norm(k).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        q = (q * cos) + (rotate_half(q) * sin)
        k = (k * cos) + (rotate_half(k) * sin)

        # Clone before the in-place slot write so the caller's cache tensors
        # are left untouched; the updated caches are returned as new outputs.
        k_cache = k_cache.clone()
        v_cache = v_cache.clone()
        k_cache[:, :, cache_position, :] = k
        v_cache[:, :, cache_position, :] = v

        attn_output = F.scaled_dot_product_attention(
            q, self._repeat_kv(k_cache), self._repeat_kv(v_cache),
            attn_mask=attn_mask,
            scale=self.scaling,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, k_cache, v_cache
class CPMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: ``down(silu(gate(x)) * up(x))``.

    The three projections are deep copies of the matching attributes on
    ``original_mlp``, so the wrapper shares no parameters with the source model.
    """

    def __init__(self, original_mlp):
        super().__init__()
        for proj_name in ("gate_proj", "up_proj", "down_proj"):
            setattr(self, proj_name, copy.deepcopy(getattr(original_mlp, proj_name)))

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class CPLayerForExport(nn.Module):
    """One pre-norm decoder layer of the code predictor (attention + gated MLP).

    Weights are copied from ``original_layer``. This layer's KV cache is
    threaded through ``forward`` as explicit tensors.
    """

    def __init__(self, original_layer, layer_idx):
        super().__init__()
        self.attn = CPAttentionForExport(original_layer.self_attn, layer_idx)
        self.mlp = CPMLP(original_layer.mlp)
        # Reuse the source norms' epsilon when exposed. HF-style RMSNorm keeps
        # it in ``variance_epsilon``; otherwise fall back to 1e-6.
        self.input_norm = RMSNorm(
            CP_HIDDEN_SIZE,
            eps=getattr(original_layer.input_layernorm, "variance_epsilon", 1e-6),
        )
        self.input_norm.weight = copy.deepcopy(original_layer.input_layernorm.weight)
        self.post_attn_norm = RMSNorm(
            CP_HIDDEN_SIZE,
            eps=getattr(original_layer.post_attention_layernorm, "variance_epsilon", 1e-6),
        )
        self.post_attn_norm.weight = copy.deepcopy(original_layer.post_attention_layernorm.weight)

    def forward(self, hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask):
        """Run attention then MLP, each wrapped in norm + residual.

        Returns:
            (hidden_states, updated k_cache, updated v_cache)
        """
        # Attention sub-block (pre-norm, residual add).
        residual = hidden_states
        hidden_states = self.input_norm(hidden_states)
        attn_out, k_cache, v_cache = self.attn(
            hidden_states, cos, sin, cache_position,
            k_cache, v_cache, attn_mask
        )
        hidden_states = residual + attn_out

        # MLP sub-block (pre-norm, residual add).
        residual = hidden_states
        hidden_states = self.post_attn_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, k_cache, v_cache
class CodePredictorForExport(nn.Module):
    """
    Export-ready code predictor backbone.

    Input: ``inputs_embeds`` in the talker's hidden space (TALKER_HIDDEN_SIZE
    wide, *not yet* projected). ``forward`` applies small_to_mtp_projection
    itself before the transformer layers.
    Output: final-normed hidden states (CP_HIDDEN_SIZE wide); the caller
    applies the appropriate lm_head externally.

    For the full 16-codebook prediction:
      1. Python builds inputs_embeds from talker hidden + codec embeddings
      2. This module projects them and runs the transformer
      3. Python takes the hidden state at the LAST position of each call and
         applies lm_heads[step_idx]. After the 2-token prefill that is
         lm_heads[0] applied to hidden[:, -1, :].
    """

    def __init__(self, original_cp):
        super().__init__()
        # Transformer layers
        self.layers = nn.ModuleList(
            [CPLayerForExport(layer, i) for i, layer in enumerate(original_cp.model.layers)]
        )
        # Final norm; reuse the source norm's epsilon when it exposes one.
        self.norm = RMSNorm(
            CP_HIDDEN_SIZE, eps=getattr(original_cp.model.norm, "variance_epsilon", 1e-6)
        )
        self.norm.weight = copy.deepcopy(original_cp.model.norm.weight)
        # Projection from talker hidden to code predictor hidden
        self.small_to_mtp_projection = copy.deepcopy(original_cp.small_to_mtp_projection)
        # LM heads, one per predicted code group (groups 1..15)
        self.lm_heads = nn.ModuleList([copy.deepcopy(head) for head in original_cp.lm_head])
        # Rotary embedding: only the inverse frequencies and the attention
        # scaling factor are kept; cos/sin are recomputed inside the graph.
        orig_rope = original_cp.model.rotary_emb
        self.register_buffer("inv_freq", orig_rope.inv_freq.clone())
        self.rope_scaling = getattr(orig_rope, 'attention_scaling', 1.0)

    def _compute_rope(self, position_ids, device, dtype):
        """Return (cos, sin), each [B, 1, seq_len, 2 * len(inv_freq)], scaled
        by ``rope_scaling`` and cast to ``dtype``."""
        pos = position_ids.float()  # [B, seq_len]
        inv_freq = self.inv_freq.float().to(device)
        freqs = pos.unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = (emb.cos() * self.rope_scaling).to(dtype)
        sin = (emb.sin() * self.rope_scaling).to(dtype)
        # Insert a head axis so the tables broadcast over attention heads.
        return cos.unsqueeze(1), sin.unsqueeze(1)

    def forward(self, inputs_embeds, position_ids, cache_position, attn_mask,
                *kv_cache_flat):
        """
        Args:
            inputs_embeds: [B, seq_len, talker_hidden_size] — NOT YET projected
            position_ids: [B, seq_len]
            cache_position: [seq_len]
            attn_mask: [B, 1, seq_len, MAX_SEQ_LEN]
            *kv_cache_flat: (k, v) per layer, flattened (10 tensors for 5 layers),
                each [B, kv_heads, MAX_SEQ_LEN, head_dim]
        Returns:
            hidden_states: [B, seq_len, CP_HIDDEN_SIZE] — apply lm_head externally
            *updated_kv_cache: same order and shapes as ``kv_cache_flat``
        Raises:
            ValueError: if the number of cache tensors is not 2 * num_layers.
        """
        expected = 2 * len(self.layers)
        if len(kv_cache_flat) != expected:
            raise ValueError(
                f"expected {expected} KV-cache tensors (k and v per layer), "
                f"got {len(kv_cache_flat)}"
            )

        # Project from talker hidden → code predictor hidden
        hidden_states = self.small_to_mtp_projection(inputs_embeds)
        cos, sin = self._compute_rope(position_ids, hidden_states.device, hidden_states.dtype)

        updated_kv = []
        for i, layer in enumerate(self.layers):
            k_cache = kv_cache_flat[i * 2]
            v_cache = kv_cache_flat[i * 2 + 1]
            hidden_states, new_k, new_v = layer(
                hidden_states, cos, sin, cache_position,
                k_cache, v_cache, attn_mask
            )
            updated_kv.append(new_k)
            updated_kv.append(new_v)

        hidden_states = self.norm(hidden_states)
        return (hidden_states, *updated_kv)
print(" Constructing CodePredictorForExport...")
t0 = time.time()
# nn.Module.eval() returns the module itself, so construction and eval mode
# are chained before the elapsed time is taken.
export_cp = CodePredictorForExport(model.talker.code_predictor).eval()
build_secs = time.time() - t0
print(f" Done in {build_secs:.1f}s")

# Parameter count of the whole wrapper (layers, norm, projection, lm_heads).
param_count = sum(param.numel() for param in export_cp.parameters())
print(f" Parameters: {param_count / 1e6:.1f}M")
# ── 3. Validate ─────────────────────────────────────────────────────
print("\n[3/5] Validating wrapper...")


def _causal_mask(query_len, past_len=0):
    """Build an additive mask of shape [BATCH_SIZE, 1, query_len, MAX_SEQ_LEN].

    Query row i (absolute position past_len + i) may attend to cache slots
    0 .. past_len + i. Every other slot, including unfilled ones, gets -inf.
    """
    mask = torch.full((BATCH_SIZE, 1, query_len, MAX_SEQ_LEN), float('-inf'))
    for row in range(query_len):
        mask[:, :, row, :past_len + row + 1] = 0.0
    return mask


# Prefill: 2 tokens (talker hidden + first codec embed), both in talker space;
# the wrapper applies small_to_mtp_projection itself.
seq_len = 2
test_embeds = torch.randn(BATCH_SIZE, seq_len, TALKER_HIDDEN_SIZE)
test_pos = torch.arange(seq_len).unsqueeze(0).expand(BATCH_SIZE, -1)
test_cache_pos = torch.arange(seq_len)
causal_mask = _causal_mask(seq_len)
# (k, v) per layer, all zeros; the wrapper never mutates these in place.
kv_cache = [
    torch.zeros(BATCH_SIZE, CP_NUM_KV_HEADS, MAX_SEQ_LEN, CP_HEAD_DIM)
    for _ in range(2 * CP_NUM_LAYERS)
]

with torch.no_grad():
    outputs = export_cp(test_embeds, test_pos, test_cache_pos, causal_mask, *kv_cache)
hidden = outputs[0]
print(f" Hidden states shape: {list(hidden.shape)}")  # [1, 2, 1024]
assert hidden.shape == (BATCH_SIZE, seq_len, CP_HIDDEN_SIZE)

# Apply lm_head to get logits for the first prediction step
logits_0 = export_cp.lm_heads[0](hidden[:, -1:, :])
print(f" Logits[0] shape: {list(logits_0.shape)}")  # [1, 1, 2048]
assert logits_0.shape[-1] == CP_VOCAB_SIZE

# Decode step: one new token at position seq_len, reusing the prefill cache.
decode_embeds = torch.randn(BATCH_SIZE, 1, TALKER_HIDDEN_SIZE)
decode_pos = torch.tensor([[seq_len]]).expand(BATCH_SIZE, -1)
decode_cache_pos = torch.tensor([seq_len])
decode_mask = _causal_mask(1, past_len=seq_len)
updated_kv = list(outputs[1:])
with torch.no_grad():
    decode_out = export_cp(decode_embeds, decode_pos, decode_cache_pos, decode_mask, *updated_kv)
print(f" Decode hidden shape: {list(decode_out[0].shape)}")
assert decode_out[0].shape == (BATCH_SIZE, 1, CP_HIDDEN_SIZE)

# KV-cache consistency: prefill + decode through the cache must reproduce
# what a single causal pass over all three tokens computes at the last
# position. This is the check that actually exercises the cache write/read path.
full_embeds = torch.cat([test_embeds, decode_embeds], dim=1)
full_len = full_embeds.shape[1]
with torch.no_grad():
    full_out = export_cp(
        full_embeds,
        torch.arange(full_len).unsqueeze(0).expand(BATCH_SIZE, -1),
        torch.arange(full_len),
        _causal_mask(full_len),
        *kv_cache,
    )
max_diff = (full_out[0][:, -1, :] - decode_out[0][:, 0, :]).abs().max().item()
print(f" Prefill+decode vs. full-pass max abs diff: {max_diff:.2e}")
assert max_diff < 1e-3, f"KV-cache path diverges from full causal pass (max diff {max_diff})"
print(" PASS — code predictor validated")
# ── 4. torch.export ─────────────────────────────────────────────────
print("\n[4/5] Running torch.export...")
t0 = time.time()
prefill_args = (test_embeds, test_pos, test_cache_pos, causal_mask, *kv_cache)

# One program must serve both the 2-token prefill and the 1-token decode
# steps. The token dimension is therefore exported as dynamic, bounded by the
# static KV-cache length. With example inputs alone, torch.export specializes
# every shape, and the lowered program would accept only seq_len == 2.
seq_dim = torch.export.Dim("seq_len", min=1, max=MAX_SEQ_LEN)
dynamic_shapes = (
    {1: seq_dim},              # inputs_embeds  [B, seq_len, TALKER_HIDDEN_SIZE]
    {1: seq_dim},              # position_ids   [B, seq_len]
    {0: seq_dim},              # cache_position [seq_len]
    {2: seq_dim},              # attn_mask      [B, 1, seq_len, MAX_SEQ_LEN]
    (None,) * len(kv_cache),   # *kv_cache_flat (bound as one tuple): all static
)

exported = None
try:
    exported = torch.export.export(
        export_cp, prefill_args, dynamic_shapes=dynamic_shapes, strict=False,
    )
except Exception as e:
    print(f" Dynamic-shape torch.export failed ({e}); retrying with static shapes.")
    print(" WARNING: a static export only accepts the 2-token prefill shape.")
    traceback.print_exc()
    try:
        exported = torch.export.export(export_cp, prefill_args, strict=False)
    except Exception as e:
        print(f" torch.export FAILED: {e}")
        traceback.print_exc()
        exported = None

if exported is not None:
    print(f" torch.export succeeded in {time.time() - t0:.1f}s")
    print(f" Graph nodes: {len(exported.graph.nodes)}")
# ── 5. Lower to .pte ────────────────────────────────────────────────
print("\n[5/5] Lowering to ExecuTorch .pte...")
t0 = time.time()
if exported is None:
    print(" Skipping lowering: torch.export did not produce a program.")
else:
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        edge = to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()
        pte_path = os.path.join(OUTPUT_DIR, "code_predictor.pte")
        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)
        pte_size = os.path.getsize(pte_path) / 1e6
        print(f" .pte saved: {pte_path}")
        print(f" .pte size: {pte_size:.1f} MB")
        print(f" Lowered in {time.time() - t0:.1f}s")
    except Exception as e:
        print(f" ExecuTorch lowering failed: {e}")
        traceback.print_exc()
        # Keep the exported graph so lowering can be retried without re-tracing.
        pt2_path = os.path.join(OUTPUT_DIR, "code_predictor.pt2")
        torch.export.save(exported, pt2_path)
        print(f" Saved: {pt2_path}")

# Also save the codec embeddings and lm_heads for the orchestration layer
extras_path = os.path.join(OUTPUT_DIR, "code_predictor_extras.pt")
torch.save({
    "codec_embeddings": [emb.state_dict() for emb in model.talker.code_predictor.model.codec_embedding],
    "lm_heads": [head.state_dict() for head in export_cp.lm_heads],
    "small_to_mtp_projection": export_cp.small_to_mtp_projection.state_dict(),
}, extras_path)
print(f" Saved codec embeddings + lm_heads: {OUTPUT_DIR}/code_predictor_extras.pt")
# Final run summary; one print() per entry, in order.
_summary_lines = (
    "\n" + "=" * 70,
    "Phase 4 complete!",
    f" Max seq len: {MAX_SEQ_LEN}",
    f" Parameters: {param_count / 1e6:.1f}M",
    f" Vocab: {CP_VOCAB_SIZE}, Code groups: {CP_NUM_CODE_GROUPS}",
    "=" * 70,
)
for _line in _summary_lines:
    print(_line)