# Provenance: scripts/quantize_all.py, uploaded by acul3 via huggingface_hub (commit 3fba683, verified).
#!/usr/bin/env python3
"""
Phase 6: INT8 Weight-Only Quantization for All Modules
=======================================================
Applies torchao int8_weight_only quantization to each module,
re-exports to torch.export, and lowers to ExecuTorch .pte.
int8_weight_only is INSTANT β€” no calibration data needed.
"""
import sys
import os
import copy
import time
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
# Local paths: base model checkpoint, the project venv's site-packages, the
# Qwen3-TTS source tree, and the directory that already holds the FP32 .pte
# files produced by the earlier export phase.
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")
# Make the project venv and source tree importable when running this script
# outside the venv's own interpreter.
if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Quantization + lowering toolchain; imported after the sys.path setup above
# so the venv copies are found.
from torchao.quantization import quantize_, int8_weight_only
from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
print("=" * 70)
print("PHASE 6: INT8 Weight-Only Quantization")
print("=" * 70)
def export_and_lower_int8(module, example_args, name, output_dir):
"""Quantize, export, and lower a module to INT8 .pte."""
# Apply INT8 weight-only quantization
print(f" Applying int8_weight_only quantization...")
t0 = time.time()
quantize_(module, int8_weight_only())
print(f" Quantized in {time.time() - t0:.1f}s")
# torch.export
print(f" Running torch.export...")
t0 = time.time()
exported = torch.export.export(module, example_args, strict=False)
print(f" Exported in {time.time() - t0:.1f}s ({len(exported.graph.nodes)} nodes)")
# Lower to .pte
print(f" Lowering to ExecuTorch .pte...")
t0 = time.time()
edge = to_edge_transform_and_lower(
exported,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
partitioner=[XnnpackPartitioner()],
)
et_program = edge.to_executorch()
pte_path = os.path.join(output_dir, f"{name}_int8.pte")
with open(pte_path, "wb") as f:
f.write(et_program.buffer)
pte_size = os.path.getsize(pte_path) / 1e6
print(f" Saved: {pte_path} ({pte_size:.1f} MB)")
print(f" Lowered in {time.time() - t0:.1f}s")
return pte_size
# ── Load base model ──────────────────────────────────────────────────
print("\n[0/4] Loading base model...")
# Project-local model classes (importable thanks to the sys.path setup above).
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
# Load in fp32 on CPU with SDPA attention; quantization is applied later,
# per wrapped sub-module, inside export_and_lower_int8.
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH, config=config, dtype=torch.float32,
    attn_implementation="sdpa", device_map="cpu",
)
model.eval()
print(" Model loaded.")
# Per-module size bookkeeping: {name: {"fp32": MB, "int8": MB or None, "error": ...}}
results = {}
# ═══════════════════════════════════════════════════════════════════
# 1. SPEAKER ENCODER
# ═══════════════════════════════════════════════════════════════════
print("\n[1/4] Speaker Encoder INT8")
# Inline the wrapper class
class _ExplicitPadConv1d(nn.Module):
def __init__(self, original_conv, pad_left, pad_right, pad_mode):
super().__init__()
self.conv = nn.Conv1d(
in_channels=original_conv.in_channels, out_channels=original_conv.out_channels,
kernel_size=original_conv.kernel_size[0], stride=original_conv.stride[0],
padding=0, dilation=original_conv.dilation[0], groups=original_conv.groups,
bias=original_conv.bias is not None)
self.conv.weight = original_conv.weight
if original_conv.bias is not None:
self.conv.bias = original_conv.bias
self.pad_left = pad_left
self.pad_right = pad_right
self.pad_mode = pad_mode
def forward(self, x):
if self.pad_left > 0 or self.pad_right > 0:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.pad_mode)
return self.conv(x)
class SpeakerEncoderForExport_Q(nn.Module):
    """Export-friendly wrapper around the speaker encoder.

    Deep-copies the encoder and rewrites every Conv1d that uses
    padding='same' into an _ExplicitPadConv1d, because torch.export cannot
    trace string-valued padding.
    """

    def __init__(self, original_encoder):
        super().__init__()
        self.encoder = copy.deepcopy(original_encoder)
        self._fix_conv_padding(self.encoder)

    def _fix_conv_padding(self, module):
        # Walk the module tree, replacing 'same'-padded convs in place.
        for child_name, child in module.named_children():
            is_same_conv = isinstance(child, nn.Conv1d) and child.padding == 'same'
            if not is_same_conv:
                self._fix_conv_padding(child)
                continue
            kernel = child.kernel_size[0]
            dilation = child.dilation[0]
            total = dilation * (kernel - 1)
            left = total // 2
            replacement = _ExplicitPadConv1d(child, left, total - left, child.padding_mode)
            setattr(module, child_name, replacement)

    def forward(self, mel_input):
        # mel_input: mel-spectrogram batch; delegated straight to the encoder.
        return self.encoder(mel_input)
# Fixed mel-frame count for the static-shape export (matches the FP32 phase).
FIXED_MEL_FRAMES = 469
se = SpeakerEncoderForExport_Q(model.speaker_encoder)
se.eval()
# Example input: (batch=1, frames, 128 mel bins).
se_args = (torch.randn(1, FIXED_MEL_FRAMES, 128),)
# Baseline: size of the FP32 .pte written by the earlier export phase.
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "speaker_encoder.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(se, se_args, "speaker_encoder", OUTPUT_DIR)
    results["speaker_encoder"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    # Record the failure but continue; later modules may still succeed.
    print(f" FAILED: {e}")
    results["speaker_encoder"] = {"fp32": fp32_size, "int8": None, "error": str(e)}
del se; gc.collect()
# ═══════════════════════════════════════════════════════════════════
# 2. TALKER
# ═══════════════════════════════════════════════════════════════════
print("\n[2/4] Talker INT8")
# The original TalkerForExport wrapper lives in export_talker.py, but that
# script runs its full export pipeline on import, so it cannot be reused
# directly. Instead, a minimal functionally-equivalent wrapper (RMSNorm /
# attention / decoder layer / top module) is reconstructed inline below.
# (A dead importlib.util.spec_from_file_location call that was never exec'd
# has been removed.)
# Talker architecture constants: 2048-slot static KV cache, 28 decoder
# layers, GQA with 16 query heads over 8 KV heads, head dim 128.
MAX_SEQ_LEN = 2048; NUM_LAYERS = 28; NUM_KV_HEADS = 8; HEAD_DIM = 128
NUM_HEADS = 16; HIDDEN_SIZE = 2048; CODEC_VOCAB = 3072
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Normalize in fp32 for numerical stability, then restore the
        # caller's dtype.
        input_dtype = x.dtype
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (self.weight * (x32 * inv_rms)).to(input_dtype)
def rotate_half(x):
    """Rotate the halves of the last dim: (a, b) -> (-b, a), for RoPE."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
class TalkerAttnQ(nn.Module):
    """Talker self-attention rebuilt for static-KV-cache export.

    Grouped-query attention (NUM_HEADS=16 query heads over NUM_KV_HEADS=8 KV
    heads) with per-head QK-RMSNorm and RoPE. The KV cache is passed in and
    returned explicitly as fixed-size (MAX_SEQ_LEN) tensors so torch.export
    sees static shapes.
    """
    def __init__(self, orig, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx; self.head_dim = HEAD_DIM
        self.num_heads = NUM_HEADS; self.num_kv_heads = NUM_KV_HEADS
        self.num_kv_groups = NUM_HEADS // NUM_KV_HEADS; self.scaling = HEAD_DIM**-0.5
        # Deep-copy projections and norm weights so quantizing this wrapper
        # leaves the source model untouched.
        self.q_proj = copy.deepcopy(orig.q_proj); self.k_proj = copy.deepcopy(orig.k_proj)
        self.v_proj = copy.deepcopy(orig.v_proj); self.o_proj = copy.deepcopy(orig.o_proj)
        self.q_norm = RMSNorm(HEAD_DIM); self.q_norm.weight = copy.deepcopy(orig.q_norm.weight)
        self.k_norm = RMSNorm(HEAD_DIM); self.k_norm.weight = copy.deepcopy(orig.k_norm.weight)
    def forward(self, h, cos, sin, cp, kc, vc, am):
        # h: (B, S, hidden); cos/sin: RoPE tables; cp: cache positions to
        # write this step's K/V at; kc/vc: KV caches; am: additive attn mask.
        B, S, _ = h.shape
        q = self.q_norm(self.q_proj(h).view(B,S,self.num_heads,HEAD_DIM)).transpose(1,2)
        k = self.k_norm(self.k_proj(h).view(B,S,self.num_kv_heads,HEAD_DIM)).transpose(1,2)
        v = self.v_proj(h).view(B,S,self.num_kv_heads,HEAD_DIM).transpose(1,2)
        # Rotary position embedding applied to queries and keys.
        q = q*cos + rotate_half(q)*sin; k = k*cos + rotate_half(k)*sin
        # Functional cache update: clone first so the exported graph contains
        # no in-place mutation of its inputs, then scatter at positions cp.
        kc = kc.clone(); vc = vc.clone()
        kc[:,:,cp,:] = k; vc[:,:,cp,:] = v
        # Replicate KV heads num_kv_groups times to match the query heads (GQA).
        ke = kc.unsqueeze(2).repeat(1,1,self.num_kv_groups,1,1).reshape(B,self.num_heads,MAX_SEQ_LEN,HEAD_DIM)
        ve = vc.unsqueeze(2).repeat(1,1,self.num_kv_groups,1,1).reshape(B,self.num_heads,MAX_SEQ_LEN,HEAD_DIM)
        o = F.scaled_dot_product_attention(q, ke, ve, attn_mask=am, scale=self.scaling)
        # Merge heads and project back; return the updated caches alongside.
        return self.o_proj(o.transpose(1,2).reshape(B,S,-1)), kc, vc
class TalkerLayerQ(nn.Module):
    """One talker decoder block: pre-norm attention plus pre-norm SwiGLU MLP,
    each wrapped in a residual connection."""

    def __init__(self, orig, i):
        super().__init__()
        self.attn = TalkerAttnQ(orig.self_attn, i)
        # MLP projections copied out of the original layer.
        self.gate_proj = copy.deepcopy(orig.mlp.gate_proj)
        self.up_proj = copy.deepcopy(orig.mlp.up_proj)
        self.down_proj = copy.deepcopy(orig.mlp.down_proj)
        self.n1 = RMSNorm(HIDDEN_SIZE)
        self.n1.weight = copy.deepcopy(orig.input_layernorm.weight)
        self.n2 = RMSNorm(HIDDEN_SIZE)
        self.n2.weight = copy.deepcopy(orig.post_attention_layernorm.weight)

    def forward(self, h, cos, sin, cp, kc, vc, am):
        # Attention sub-block with residual; caches flow through.
        attn_out, kc, vc = self.attn(self.n1(h), cos, sin, cp, kc, vc, am)
        h = h + attn_out
        # SwiGLU feed-forward sub-block with residual.
        normed = self.n2(h)
        h = h + self.down_proj(F.silu(self.gate_proj(normed)) * self.up_proj(normed))
        return h, kc, vc
class TalkerQ(nn.Module):
    """Talker decoder stack with explicit KV-cache inputs/outputs.

    forward(ie, pid, cp, am, *kv) returns (codec logits, *updated_kv), where
    kv interleaves each layer's key cache and value cache.
    """

    def __init__(self, orig):
        super().__init__()
        self.layers = nn.ModuleList(
            [TalkerLayerQ(layer, idx) for idx, layer in enumerate(orig.model.layers)]
        )
        self.norm = RMSNorm(HIDDEN_SIZE)
        self.norm.weight = copy.deepcopy(orig.model.norm.weight)
        self.codec_head = copy.deepcopy(orig.codec_head)
        self.register_buffer("inv_freq", orig.model.rotary_emb.inv_freq.clone())
        self.rope_scaling = getattr(orig.model.rotary_emb, 'attention_scaling', 1.0)

    def forward(self, ie, pid, cp, am, *kv):
        # Build RoPE cos/sin tables; only the first row of position ids is used.
        pos = pid[0].float()
        freqs = pos.unsqueeze(-1) * self.inv_freq.float().unsqueeze(0).unsqueeze(0)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = (emb.cos() * self.rope_scaling).to(ie.dtype).unsqueeze(1)
        sin = (emb.sin() * self.rope_scaling).to(ie.dtype).unsqueeze(1)
        # Thread hidden states and per-layer caches through the stack.
        h = ie
        updated_caches = []
        for idx, layer in enumerate(self.layers):
            h, new_k, new_v = layer(h, cos, sin, cp, kv[2 * idx], kv[2 * idx + 1], am)
            updated_caches.extend((new_k, new_v))
        return (self.codec_head(self.norm(h)), *updated_caches)
t_mod = TalkerQ(model.talker); t_mod.eval()
# Example prefill inputs: a 10-token sequence with a causal additive mask
# over the full MAX_SEQ_LEN cache (0.0 = attend, -inf = masked).
sl = 10
cm = torch.full((1,1,sl,MAX_SEQ_LEN), float('-inf'))
for i in range(sl): cm[:,:,i,:i+1] = 0.0
t_args = (
    torch.randn(1,sl,HIDDEN_SIZE),  # input embeddings
    # 3-row position ids; TalkerQ.forward reads only row 0.
    torch.arange(sl).unsqueeze(0).unsqueeze(0).repeat(3,1,1),
    torch.arange(sl), cm,  # cache positions + attention mask
    # Interleaved empty K/V caches, one pair per decoder layer.
    *[torch.zeros(1,NUM_KV_HEADS,MAX_SEQ_LEN,HEAD_DIM) for _ in range(NUM_LAYERS*2)]
)
# Baseline: size of the FP32 prefill .pte from the earlier export phase.
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "talker_prefill.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(t_mod, t_args, "talker", OUTPUT_DIR)
    results["talker"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    # Record the failure but continue with the remaining modules.
    print(f" FAILED: {e}")
    results["talker"] = {"fp32": fp32_size, "int8": None, "error": str(e)}
del t_mod; gc.collect()
# ═══════════════════════════════════════════════════════════════════
# 3. CODE PREDICTOR
# ═══════════════════════════════════════════════════════════════════
print("\n[3/4] Code Predictor INT8")
# Code-predictor architecture constants: 17-slot static KV cache, 5 layers,
# 8 KV heads, head dim 128, 16 query heads, hidden size 1024; THD is the
# talker hidden size (2048) that CPQ projects down from.
CP_MAX = 17; CPL = 5; CPKV = 8; CPHD = 128; CPH = 16; CPHS = 1024; THD = 2048
class CPAttnQ(nn.Module):
    """Code-predictor self-attention rebuilt for static-KV-cache export.

    Grouped-query attention (CPH=16 query heads over CPKV=8 KV heads) with
    per-head QK-RMSNorm and RoPE; the cache is fixed at CP_MAX positions.
    """
    def __init__(self, orig, i):
        super().__init__()
        # Deep-copy projections and norm weights so quantizing this wrapper
        # leaves the source model untouched.
        self.q_proj = copy.deepcopy(orig.q_proj); self.k_proj = copy.deepcopy(orig.k_proj)
        self.v_proj = copy.deepcopy(orig.v_proj); self.o_proj = copy.deepcopy(orig.o_proj)
        self.q_norm = RMSNorm(CPHD); self.q_norm.weight = copy.deepcopy(orig.q_norm.weight)
        self.k_norm = RMSNorm(CPHD); self.k_norm.weight = copy.deepcopy(orig.k_norm.weight)
        self.g = CPH // CPKV  # KV-head replication factor for GQA
    def forward(self, h, cos, sin, cp, kc, vc, am):
        # h: (B, S, hidden); cp: cache positions to write; kc/vc: KV caches;
        # am: additive attention mask.
        B,S,_ = h.shape
        q = self.q_norm(self.q_proj(h).view(B,S,CPH,CPHD)).transpose(1,2)
        k = self.k_norm(self.k_proj(h).view(B,S,CPKV,CPHD)).transpose(1,2)
        v = self.v_proj(h).view(B,S,CPKV,CPHD).transpose(1,2)
        # Rotary position embedding on queries and keys.
        q = q*cos + rotate_half(q)*sin; k = k*cos + rotate_half(k)*sin
        # Functional (clone-then-scatter) cache update keeps the exported
        # graph free of in-place mutation of its inputs.
        kc = kc.clone(); vc = vc.clone(); kc[:,:,cp,:] = k; vc[:,:,cp,:] = v
        # Replicate KV heads to match the query-head count (GQA).
        ke = kc.unsqueeze(2).repeat(1,1,self.g,1,1).reshape(B,CPH,CP_MAX,CPHD)
        ve = vc.unsqueeze(2).repeat(1,1,self.g,1,1).reshape(B,CPH,CP_MAX,CPHD)
        o = F.scaled_dot_product_attention(q,ke,ve,attn_mask=am,scale=CPHD**-0.5)
        # Merge heads and project back; return the updated caches alongside.
        return self.o_proj(o.transpose(1,2).reshape(B,S,-1)), kc, vc
class CPLayerQ(nn.Module):
    """One code-predictor decoder block: pre-norm attention plus pre-norm
    SwiGLU MLP, each with a residual connection."""

    def __init__(self, orig, i):
        super().__init__()
        self.attn = CPAttnQ(orig.self_attn, i)
        # MLP projections copied out of the original layer.
        self.gp = copy.deepcopy(orig.mlp.gate_proj)
        self.up = copy.deepcopy(orig.mlp.up_proj)
        self.dp = copy.deepcopy(orig.mlp.down_proj)
        self.n1 = RMSNorm(CPHS)
        self.n1.weight = copy.deepcopy(orig.input_layernorm.weight)
        self.n2 = RMSNorm(CPHS)
        self.n2.weight = copy.deepcopy(orig.post_attention_layernorm.weight)

    def forward(self, h, cos, sin, cp, kc, vc, am):
        # Attention sub-block with residual; caches flow through.
        attn_out, kc, vc = self.attn(self.n1(h), cos, sin, cp, kc, vc, am)
        h = h + attn_out
        # SwiGLU feed-forward sub-block with residual.
        normed = self.n2(h)
        h = h + self.dp(F.silu(self.gp(normed)) * self.up(normed))
        return h, kc, vc
class CPQ(nn.Module):
    """Code-predictor decoder stack with explicit KV-cache inputs/outputs.

    forward(ie, pid, cp, am, *kv) returns (final normed hidden states,
    *updated_kv) — unlike TalkerQ, no output head is applied here. The input
    embeddings are first projected from talker width (THD) to CPHS.
    """
    def __init__(self, orig):
        super().__init__()
        self.layers = nn.ModuleList([CPLayerQ(l,i) for i,l in enumerate(orig.model.layers)])
        self.norm = RMSNorm(CPHS); self.norm.weight = copy.deepcopy(orig.model.norm.weight)
        # Projection from the talker hidden size into the code-predictor width.
        self.proj = copy.deepcopy(orig.small_to_mtp_projection)
        self.register_buffer("inv_freq", orig.model.rotary_emb.inv_freq.clone())
        self.rs = getattr(orig.model.rotary_emb, 'attention_scaling', 1.0)
    def forward(self, ie, pid, cp, am, *kv):
        h = self.proj(ie)
        # Build RoPE cos/sin tables from the position ids.
        pos = pid.float()
        freqs = pos.unsqueeze(-1)*self.inv_freq.float().unsqueeze(0).unsqueeze(0)
        emb = torch.cat([freqs,freqs],dim=-1)
        cos = (emb.cos()*self.rs).to(h.dtype).unsqueeze(1)
        sin = (emb.sin()*self.rs).to(h.dtype).unsqueeze(1)
        # Thread hidden states and interleaved per-layer K/V caches through
        # the stack, collecting the updated caches for the caller.
        ukv = []
        for i, l in enumerate(self.layers):
            h,nk,nv = l(h,cos,sin,cp,kv[i*2],kv[i*2+1],am); ukv.append(nk); ukv.append(nv)
        return (self.norm(h), *ukv)
cp_mod = CPQ(model.talker.code_predictor); cp_mod.eval()
# Example inputs: a 2-token sequence with a causal additive mask over the
# CP_MAX cache slots (0.0 = attend, -inf = masked).
csl = 2
ccm = torch.full((1,1,csl,CP_MAX), float('-inf'))
for i in range(csl): ccm[:,:,i,:i+1] = 0.0
cp_args = (
    # Talker-width embeddings, 2-D position ids, cache positions, mask,
    # then interleaved empty K/V caches (one pair per layer).
    torch.randn(1,csl,THD), torch.arange(csl).unsqueeze(0), torch.arange(csl), ccm,
    *[torch.zeros(1,CPKV,CP_MAX,CPHD) for _ in range(CPL*2)]
)
# Baseline: size of the FP32 .pte from the earlier export phase.
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "code_predictor.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(cp_mod, cp_args, "code_predictor", OUTPUT_DIR)
    results["code_predictor"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    # Record the failure but continue with the remaining modules.
    print(f" FAILED: {e}")
    results["code_predictor"] = {"fp32": fp32_size, "int8": None, "error": str(e)}
del cp_mod; gc.collect()
# ═══════════════════════════════════════════════════════════════════
# 4. VOCODER
# ═══════════════════════════════════════════════════════════════════
print("\n[4/4] Vocoder INT8")
class VocQ(nn.Module):
    """Thin export wrapper around the speech-tokenizer decoder (vocoder)."""

    def __init__(self, dec):
        super().__init__()
        # Deep copy so in-place quantization does not mutate the loaded model.
        self.decoder = copy.deepcopy(dec)

    def forward(self, codes):
        # Delegate straight to the wrapped decoder.
        return self.decoder(codes)
v_mod = VocQ(model.speech_tokenizer.model.decoder); v_mod.eval()
# Example input: integer codec codes of shape (batch=1, 16 codebooks, 50 frames).
v_args = (torch.randint(0, 2048, (1, 16, 50)),)
# Baseline: size of the FP32 .pte from the earlier export phase.
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "vocoder.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(v_mod, v_args, "vocoder", OUTPUT_DIR)
    results["vocoder"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    # Record the failure but still print the summary below.
    print(f" FAILED: {e}")
    results["vocoder"] = {"fp32": fp32_size, "int8": None, "error": str(e)}
del v_mod; gc.collect()
# ── Summary ──────────────────────────────────────────────────────────
# Print a per-module FP32 vs INT8 size table plus totals. Row formats are
# aligned to the header widths (25/12/12/10); the original used 24/10/10
# and an unpadded reduction column, so the table did not line up.
print("\n" + "=" * 70)
print("QUANTIZATION SUMMARY")
print("=" * 70)
print(f"\n{'Module':25s} {'FP32 (MB)':>12s} {'INT8 (MB)':>12s} {'Reduction':>10s}")
print("-" * 60)
total_fp32 = 0; total_int8 = 0
for name, r in results.items():
    fp32 = r.get("fp32", 0) or 0
    int8 = r.get("int8")
    total_fp32 += fp32
    if int8 is not None:
        total_int8 += int8
        # Guard against a zero-size .pte to avoid ZeroDivisionError.
        red = f"{fp32/int8:.1f}x" if int8 > 0 else "∞"
    else:
        # Failed modules show a truncated error instead of a ratio.
        red = f"FAILED: {r.get('error','')[:40]}"
        int8 = 0
    print(f" {name:24s} {fp32:12.1f} {int8:12.1f} {red:>10s}")
print("-" * 60)
ovr = f"{total_fp32/total_int8:.1f}x" if total_int8 > 0 else "N/A"
print(f" {'TOTAL':24s} {total_fp32:12.1f} {total_int8:12.1f} {ovr:>10s}")
print("\nPhase 6 complete!")