#!/usr/bin/env python3
"""
Phase 6: INT8 Weight-Only Quantization for All Modules
=======================================================
Applies torchao int8_weight_only quantization to each module, re-exports
to torch.export, and lowers to ExecuTorch .pte.

int8_weight_only is INSTANT — no calibration data needed.
"""
import sys
import os
import copy
import time
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F

# Filesystem layout: the Qwen3-TTS checkout supplies both the weights and the
# `qwen_tts` package; quantized exports land next to the Phase-5 FP32 .pte files.
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Imported after the sys.path tweaks above so the venv copies are found.
from torchao.quantization import quantize_, int8_weight_only
from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

print("=" * 70)
print("PHASE 6: INT8 Weight-Only Quantization")
print("=" * 70)


def export_and_lower_int8(module, example_args, name, output_dir):
    """Quantize, export, and lower a module to INT8 .pte.

    Args:
        module: eval-mode ``nn.Module``; quantized IN PLACE (weight-only INT8).
        example_args: tuple of example tensors used to trace ``torch.export``.
        name: output basename — the file is written as ``{name}_int8.pte``.
        output_dir: directory the ``.pte`` is written into.

    Returns:
        Size of the written ``.pte`` file in MB (float).
    """
    # Weight-only INT8 needs no calibration data, so this step is quick.
    print(" Applying int8_weight_only quantization...")
    t0 = time.time()
    quantize_(module, int8_weight_only())
    print(f" Quantized in {time.time() - t0:.1f}s")

    # strict=False: the export wrappers elsewhere in this script mutate cloned
    # KV-cache tensors, which strict export rejects.
    print(" Running torch.export...")
    t0 = time.time()
    exported = torch.export.export(module, example_args, strict=False)
    print(f" Exported in {time.time() - t0:.1f}s ({len(exported.graph.nodes)} nodes)")

    # Lower to ExecuTorch with the XNNPACK backend.  IR validity checking is
    # disabled here; quantized graphs can otherwise fail the validator.
    print(" Lowering to ExecuTorch .pte...")
    t0 = time.time()
    edge = to_edge_transform_and_lower(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    et_program = edge.to_executorch()
    pte_path = os.path.join(output_dir, f"{name}_int8.pte")
    with open(pte_path, "wb") as f:
        f.write(et_program.buffer)
    pte_size = os.path.getsize(pte_path) / 1e6
    print(f" Saved: {pte_path} ({pte_size:.1f} MB)")
    print(f" Lowered in {time.time() - t0:.1f}s")
    return pte_size


# ── Load base model ──────────────────────────────────────────────────
print("\n[0/4] Loading base model...")
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    config=config,
    dtype=torch.float32,           # fp32 master weights on CPU for quantization/export
    attn_implementation="sdpa",
    device_map="cpu",
)
model.eval()
print(" Model loaded.")

# Per-module outcome table: {"fp32": MB, "int8": MB or None, "error": str (on failure)}
results = {}

# ═══════════════════════════════════════════════════════════════════
# 1. SPEAKER ENCODER
# ═══════════════════════════════════════════════════════════════════
print("\n[1/4] Speaker Encoder INT8")


class _ExplicitPadConv1d(nn.Module):
    """Conv1d with padding='same' rewritten as explicit F.pad + padding=0.

    torch.export cannot trace string padding modes, so the original conv's
    parameter objects are shared (not copied) into a zero-padding Conv1d and
    the 'same' padding is applied manually in forward().
    """

    def __init__(self, original_conv, pad_left, pad_right, pad_mode):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=original_conv.in_channels,
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size[0],
            stride=original_conv.stride[0],
            padding=0,
            dilation=original_conv.dilation[0],
            groups=original_conv.groups,
            bias=original_conv.bias is not None)
        # Share the Parameter objects themselves — no weight copy.
        self.conv.weight = original_conv.weight
        if original_conv.bias is not None:
            self.conv.bias = original_conv.bias
        self.pad_left = pad_left
        self.pad_right = pad_right
        self.pad_mode = pad_mode

    def forward(self, x):
        if self.pad_left > 0 or self.pad_right > 0:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.pad_mode)
        return self.conv(x)


class SpeakerEncoderForExport_Q(nn.Module):
    """Deep-copied speaker encoder with every padding='same' Conv1d replaced
    by an export-friendly _ExplicitPadConv1d (weights shared with the copy)."""

    def __init__(self, original_encoder):
        super().__init__()
        self.encoder = copy.deepcopy(original_encoder)
        self._fix_conv_padding(self.encoder)

    def _fix_conv_padding(self, module):
        # Recursively replace 'same'-padded Conv1d children in place.  For an
        # odd total pad, the extra element goes on the right.
        # NOTE(review): assumed to match PyTorch's own 'same' split — confirm.
        for name, child in module.named_children():
            if isinstance(child, nn.Conv1d) and child.padding == 'same':
                k = child.kernel_size[0]
                d = child.dilation[0]
                pad_total = d * (k - 1)
                new_conv = _ExplicitPadConv1d(
                    child, pad_total // 2, pad_total - pad_total // 2,
                    child.padding_mode)
                setattr(module, name, new_conv)
            else:
                self._fix_conv_padding(child)

    def forward(self, mel_input):
        return self.encoder(mel_input)


FIXED_MEL_FRAMES = 469  # mel length baked into the export; presumably matches the FP32 export
se = SpeakerEncoderForExport_Q(model.speaker_encoder)
se.eval()
se_args = (torch.randn(1, FIXED_MEL_FRAMES, 128),)
# Size of the Phase-5 FP32 export, for the comparison table.
# NOTE(review): raises if the FP32 .pte is missing — assumes Phase 5 already ran.
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "speaker_encoder.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(se, se_args, "speaker_encoder", OUTPUT_DIR)
    results["speaker_encoder"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    # Best-effort: record the failure and continue with the other modules.
    print(f" FAILED: {e}")
    results["speaker_encoder"] = {"fp32": fp32_size, "int8": None, "error": str(e)}
del se
gc.collect()

# ═══════════════════════════════════════════════════════════════════
# 2. TALKER
# ═══════════════════════════════════════════════════════════════════
print("\n[2/4] Talker INT8")
# The TalkerForExport wrapper lives in export_talker.py, but importing that
# file would execute the whole export script (no __main__ guard).  Rather than
# load it via importlib, a minimal copy of the wrapper classes is
# reconstructed below.  (An earlier unused importlib spec was removed here.)
# Talker hyperparameters, hard-coded for the 1.7B-Base checkpoint.
# NOTE(review): assumed to mirror the model config — confirm against
# Qwen3TTSConfig.  CODEC_VOCAB is not referenced below.
MAX_SEQ_LEN = 2048; NUM_LAYERS = 28; NUM_KV_HEADS = 8; HEAD_DIM = 128
NUM_HEADS = 16; HIDDEN_SIZE = 2048; CODEC_VOCAB = 3072


class RMSNorm(nn.Module):
    """RMS layer norm computed in fp32, result cast back to the input dtype."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        dtype = x.dtype; x = x.float()
        # Mean of squares over the last (feature) dimension.
        v = x.pow(2).mean(-1, keepdim=True)
        return (self.weight * (x * torch.rsqrt(v + self.eps))).to(dtype)


def rotate_half(x):
    """Rotate the last dimension by half: (x1, x2) -> (-x2, x1), for RoPE."""
    x1 = x[..., :x.shape[-1]//2]; x2 = x[..., x.shape[-1]//2:]
    return torch.cat((-x2, x1), dim=-1)


class TalkerAttnQ(nn.Module):
    """Export-friendly talker self-attention with an explicit KV cache.

    The caches kc/vc are full-length (MAX_SEQ_LEN) tensors passed in and
    returned functionally: they are cloned, the new K/V rows are scattered in
    at positions `cp`, and the updated caches are returned alongside the
    attention output.
    """
    def __init__(self, orig, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx; self.head_dim = HEAD_DIM
        self.num_heads = NUM_HEADS; self.num_kv_heads = NUM_KV_HEADS
        self.num_kv_groups = NUM_HEADS // NUM_KV_HEADS; self.scaling = HEAD_DIM**-0.5
        # Deep-copy the original projections so quantize_ never touches the
        # source model.
        self.q_proj = copy.deepcopy(orig.q_proj); self.k_proj = copy.deepcopy(orig.k_proj)
        self.v_proj = copy.deepcopy(orig.v_proj); self.o_proj = copy.deepcopy(orig.o_proj)
        self.q_norm = RMSNorm(HEAD_DIM); self.q_norm.weight = copy.deepcopy(orig.q_norm.weight)
        self.k_norm = RMSNorm(HEAD_DIM); self.k_norm.weight = copy.deepcopy(orig.k_norm.weight)

    def forward(self, h, cos, sin, cp, kc, vc, am):
        # h: (B, S, HIDDEN_SIZE); cp: (S,) cache positions;
        # kc/vc: (1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM); am: additive mask.
        B, S, _ = h.shape
        q = self.q_norm(self.q_proj(h).view(B,S,self.num_heads,HEAD_DIM)).transpose(1,2)
        k = self.k_norm(self.k_proj(h).view(B,S,self.num_kv_heads,HEAD_DIM)).transpose(1,2)
        v = self.v_proj(h).view(B,S,self.num_kv_heads,HEAD_DIM).transpose(1,2)
        # Apply rotary position embedding to queries and keys.
        q = q*cos + rotate_half(q)*sin; k = k*cos + rotate_half(k)*sin
        # Functional cache update (clone keeps export side-effect free).
        kc = kc.clone(); vc = vc.clone()
        kc[:,:,cp,:] = k; vc[:,:,cp,:] = v
        # Expand the KV heads to the full head count (grouped-query attention).
        ke = kc.unsqueeze(2).repeat(1,1,self.num_kv_groups,1,1).reshape(B,self.num_heads,MAX_SEQ_LEN,HEAD_DIM)
        ve = vc.unsqueeze(2).repeat(1,1,self.num_kv_groups,1,1).reshape(B,self.num_heads,MAX_SEQ_LEN,HEAD_DIM)
        o = F.scaled_dot_product_attention(q, ke, ve, attn_mask=am, scale=self.scaling)
        return self.o_proj(o.transpose(1,2).reshape(B,S,-1)), kc, vc


class TalkerLayerQ(nn.Module):
    """One talker transformer layer: pre-norm attention + SwiGLU MLP,
    both with residual connections; KV caches threaded through."""
    def __init__(self, orig, i):
        super().__init__()
        self.attn = TalkerAttnQ(orig.self_attn, i)
        self.gate_proj = copy.deepcopy(orig.mlp.gate_proj)
        self.up_proj = copy.deepcopy(orig.mlp.up_proj)
        self.down_proj = copy.deepcopy(orig.mlp.down_proj)
        self.n1 = RMSNorm(HIDDEN_SIZE); self.n1.weight = copy.deepcopy(orig.input_layernorm.weight)
        self.n2 = RMSNorm(HIDDEN_SIZE); self.n2.weight = copy.deepcopy(orig.post_attention_layernorm.weight)

    def forward(self, h, cos, sin, cp, kc, vc, am):
        # Residual around attention, then residual around the SwiGLU MLP.
        r = h; a, kc, vc = self.attn(self.n1(h), cos, sin, cp, kc, vc, am); h = r + a
        r = h; x = self.n2(h); h = r + self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
        return h, kc, vc


class TalkerQ(nn.Module):
    """Full talker stack for export: computes RoPE cos/sin once, runs all
    layers, and returns codec-head logits plus every updated KV cache."""
    def __init__(self, orig):
        super().__init__()
        self.layers = nn.ModuleList([TalkerLayerQ(l, i) for i, l in enumerate(orig.model.layers)])
        self.norm = RMSNorm(HIDDEN_SIZE); self.norm.weight = copy.deepcopy(orig.model.norm.weight)
        self.codec_head = copy.deepcopy(orig.codec_head)
        self.register_buffer("inv_freq", orig.model.rotary_emb.inv_freq.clone())
        self.rope_scaling = getattr(orig.model.rotary_emb, 'attention_scaling', 1.0)

    def forward(self, ie, pid, cp, am, *kv):
        # ie: input embeddings (B, S, HIDDEN_SIZE); pid: position ids — only
        # pid[0] is used (example args pass a (3, 1, S) tensor); cp: cache
        # positions; am: attention mask; kv: NUM_LAYERS interleaved (k, v) caches.
        pos = pid[0].float()
        freqs = pos.unsqueeze(-1) * self.inv_freq.float().unsqueeze(0).unsqueeze(0)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = (emb.cos() * self.rope_scaling).to(ie.dtype).unsqueeze(1)
        sin = (emb.sin() * self.rope_scaling).to(ie.dtype).unsqueeze(1)
        h = ie; ukv = []
        for i, layer in enumerate(self.layers):
            h, nk, nv = layer(h, cos, sin, cp, kv[i*2], kv[i*2+1], am)
            ukv.append(nk); ukv.append(nv)
        # Flat output: logits first, then all updated caches in layer order.
        return (self.codec_head(self.norm(h)), *ukv)


t_mod = TalkerQ(model.talker); t_mod.eval()
# Example inputs: a 10-token prefill with a causal mask over MAX_SEQ_LEN
# columns (0 = attend, -inf = masked) and empty (zero) KV caches.
sl = 10
cm = torch.full((1,1,sl,MAX_SEQ_LEN), float('-inf'))
for i in range(sl):
    cm[:,:,i,:i+1] = 0.0
t_args = (
    torch.randn(1,sl,HIDDEN_SIZE),
    torch.arange(sl).unsqueeze(0).unsqueeze(0).repeat(3,1,1),
    torch.arange(sl),
    cm,
    *[torch.zeros(1,NUM_KV_HEADS,MAX_SEQ_LEN,HEAD_DIM) for _ in range(NUM_LAYERS*2)]
)
# Compare against the Phase-5 FP32 prefill export (must already exist).
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "talker_prefill.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(t_mod, t_args, "talker", OUTPUT_DIR)
    results["talker"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    print(f" FAILED: {e}")
    results["talker"] = {"fp32": fp32_size, "int8": None, "error": str(e)}
del t_mod; gc.collect()

# ═══════════════════════════════════════════════════════════════════
# 3. CODE PREDICTOR
# ═══════════════════════════════════════════════════════════════════
print("\n[3/4] Code Predictor INT8")

# Code-predictor hyperparameters: max cache length, layers, KV heads, head
# dim, heads, hidden size, and the talker hidden dim fed into the projection.
# NOTE(review): assumed to mirror the checkpoint config — confirm.
CP_MAX = 17; CPL = 5; CPKV = 8; CPHD = 128; CPH = 16; CPHS = 1024; THD = 2048


class CPAttnQ(nn.Module):
    """Code-predictor self-attention; same functional KV-cache scheme as
    TalkerAttnQ but with the smaller CP_* dimensions."""
    def __init__(self, orig, i):
        super().__init__()
        self.q_proj = copy.deepcopy(orig.q_proj); self.k_proj = copy.deepcopy(orig.k_proj)
        self.v_proj = copy.deepcopy(orig.v_proj); self.o_proj = copy.deepcopy(orig.o_proj)
        self.q_norm = RMSNorm(CPHD); self.q_norm.weight = copy.deepcopy(orig.q_norm.weight)
        self.k_norm = RMSNorm(CPHD); self.k_norm.weight = copy.deepcopy(orig.k_norm.weight)
        self.g = CPH // CPKV  # GQA group size

    def forward(self, h, cos, sin, cp, kc, vc, am):
        B,S,_ = h.shape
        q = self.q_norm(self.q_proj(h).view(B,S,CPH,CPHD)).transpose(1,2)
        k = self.k_norm(self.k_proj(h).view(B,S,CPKV,CPHD)).transpose(1,2)
        v = self.v_proj(h).view(B,S,CPKV,CPHD).transpose(1,2)
        q = q*cos + rotate_half(q)*sin; k = k*cos + rotate_half(k)*sin
        # Clone-then-scatter cache update at positions cp.
        kc = kc.clone(); vc = vc.clone(); kc[:,:,cp,:] = k; vc[:,:,cp,:] = v
        ke = kc.unsqueeze(2).repeat(1,1,self.g,1,1).reshape(B,CPH,CP_MAX,CPHD)
        ve = vc.unsqueeze(2).repeat(1,1,self.g,1,1).reshape(B,CPH,CP_MAX,CPHD)
        o = F.scaled_dot_product_attention(q,ke,ve,attn_mask=am,scale=CPHD**-0.5)
        return self.o_proj(o.transpose(1,2).reshape(B,S,-1)), kc, vc


class CPLayerQ(nn.Module):
    """One code-predictor layer: pre-norm attention + SwiGLU MLP with residuals."""
    def __init__(self, orig, i):
        super().__init__()
        self.attn = CPAttnQ(orig.self_attn, i)
        self.gp = copy.deepcopy(orig.mlp.gate_proj)
        self.up = copy.deepcopy(orig.mlp.up_proj)
        self.dp = copy.deepcopy(orig.mlp.down_proj)
        self.n1 = RMSNorm(CPHS); self.n1.weight = copy.deepcopy(orig.input_layernorm.weight)
        self.n2 = RMSNorm(CPHS); self.n2.weight = copy.deepcopy(orig.post_attention_layernorm.weight)

    def forward(self, h, cos, sin, cp, kc, vc, am):
        r=h; a,kc,vc = self.attn(self.n1(h),cos,sin,cp,kc,vc,am); h=r+a
        r=h; x=self.n2(h); h=r+self.dp(F.silu(self.gp(x))*self.up(x))
        return h, kc, vc


class CPQ(nn.Module):
    """Code-predictor stack: projects talker hidden states (THD -> CPHS) via
    small_to_mtp_projection, runs the layers, and returns the final normed
    hidden states plus all updated KV caches."""
    def __init__(self, orig):
        super().__init__()
        self.layers = nn.ModuleList([CPLayerQ(l,i) for i,l in enumerate(orig.model.layers)])
        self.norm = RMSNorm(CPHS); self.norm.weight = copy.deepcopy(orig.model.norm.weight)
        self.proj = copy.deepcopy(orig.small_to_mtp_projection)
        self.register_buffer("inv_freq", orig.model.rotary_emb.inv_freq.clone())
        self.rs = getattr(orig.model.rotary_emb, 'attention_scaling', 1.0)

    def forward(self, ie, pid, cp, am, *kv):
        # ie: (B, S, THD) talker hidden states; pid: (1, S) position ids
        # (used directly, unlike TalkerQ which indexes pid[0]).
        h = self.proj(ie)
        pos = pid.float()
        freqs = pos.unsqueeze(-1)*self.inv_freq.float().unsqueeze(0).unsqueeze(0)
        emb = torch.cat([freqs,freqs],dim=-1)
        cos = (emb.cos()*self.rs).to(h.dtype).unsqueeze(1)
        sin = (emb.sin()*self.rs).to(h.dtype).unsqueeze(1)
        ukv = []
        for i, l in enumerate(self.layers):
            h,nk,nv = l(h,cos,sin,cp,kv[i*2],kv[i*2+1],am); ukv.append(nk); ukv.append(nv)
        # Note: returns normed hidden states (no head) plus the caches.
        return (self.norm(h), *ukv)


cp_mod = CPQ(model.talker.code_predictor); cp_mod.eval()
# Example inputs: 2-step sequence, causal mask over CP_MAX columns, zero caches.
csl = 2
ccm = torch.full((1,1,csl,CP_MAX), float('-inf'))
for i in range(csl):
    ccm[:,:,i,:i+1] = 0.0
cp_args = (
    torch.randn(1,csl,THD),
    torch.arange(csl).unsqueeze(0),
    torch.arange(csl),
    ccm,
    *[torch.zeros(1,CPKV,CP_MAX,CPHD) for _ in range(CPL*2)]
)
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "code_predictor.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(cp_mod, cp_args, "code_predictor", OUTPUT_DIR)
    results["code_predictor"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    print(f" FAILED: {e}")
    results["code_predictor"] = {"fp32": fp32_size, "int8": None, "error": str(e)}
del cp_mod; gc.collect()

# ═══════════════════════════════════════════════════════════════════
# 4. VOCODER
# ═══════════════════════════════════════════════════════════════════
print("\n[4/4] Vocoder INT8")


class VocQ(nn.Module):
    """Thin export wrapper around a deep copy of the speech-tokenizer decoder."""
    def __init__(self, dec):
        super().__init__()
        self.decoder = copy.deepcopy(dec)

    def forward(self, codes):
        return self.decoder(codes)


v_mod = VocQ(model.speech_tokenizer.model.decoder); v_mod.eval()
# Example input: (batch=1, 16 codebooks, 50 frames) of codec indices.
# NOTE(review): codebook count/vocab (16, 2048) assumed from this call — confirm.
v_args = (torch.randint(0, 2048, (1, 16, 50)),)
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "vocoder.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(v_mod, v_args, "vocoder", OUTPUT_DIR)
    results["vocoder"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    print(f" FAILED: {e}")
    results["vocoder"] = {"fp32": fp32_size, "int8": None, "error": str(e)}
del v_mod; gc.collect()

# ── Summary ──────────────────────────────────────────────────────────
# Per-module FP32 vs INT8 size table.  Note the TOTAL ratio counts the FP32
# size of failed modules but not their (missing) INT8 size.
print("\n" + "=" * 70)
print("QUANTIZATION SUMMARY")
print("=" * 70)
print(f"\n{'Module':25s} {'FP32 (MB)':>12s} {'INT8 (MB)':>12s} {'Reduction':>10s}")
print("-" * 60)
total_fp32 = 0; total_int8 = 0
for name, r in results.items():
    fp32 = r.get("fp32", 0) or 0
    int8 = r.get("int8")
    total_fp32 += fp32
    if int8 is not None:
        total_int8 += int8
        red = f"{fp32/int8:.1f}x" if int8 > 0 else "∞"
    else:
        # Failed module: show a truncated error instead of a ratio.
        red = f"FAILED: {r.get('error','')[:40]}"
        int8 = 0
    print(f" {name:23s} {fp32:10.1f} {int8:10.1f} {red}")
print("-" * 60)
ovr = f"{total_fp32/total_int8:.1f}x" if total_int8 > 0 else "N/A"
print(f" {'TOTAL':23s} {total_fp32:10.1f} {total_int8:10.1f} {ovr}")
print("\nPhase 6 complete!")