#!/usr/bin/env python3
"""
Phase 5: Export Speech Tokenizer Decoder (Vocoder) to ExecuTorch .pte
======================================================================
The vocoder converts codec tokens β audio waveform.
Architecture:
codes [B, 16, T] β VQ decode β [B, codebook_dim, T]
β Conv1d β Transformer (8 layers) β Conv1d
β Upsample (2x, 2x) via ConvTranspose1d + ConvNeXt
β Decoder (8x, 5x, 4x, 3x) via ConvTranspose1d + SnakeBeta + ResBlocks
β Conv1d β waveform [B, 1, T*1920]
Total upsample: 2*2*8*5*4*3 = 3840x (but code downsample is 1920x, so net 1920x)
Wait β the decoder forward uses total_upsample which is upsample_rates * upsampling_ratios.
"""
import sys
import os
import copy
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# --- User-local paths (adjust for your checkout) ---
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Make the qwen_tts sources and their venv dependencies importable before
# the project imports further down this script.
if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Fixed code length for export (50 frames → 4 seconds of audio)
FIXED_CODE_LEN = 50
# Number of residual VQ codebooks fed to the decoder (codes dim 1).
NUM_QUANTIZERS = 16
print("=" * 70)
print("PHASE 5: Export Vocoder (Speech Tokenizer Decoder) β .pte")
print("=" * 70)
# ββ 1. Load Model βββββββββββββββββββββββββββββββββββββββββββββββββββ
print("\n[1/5] Loading model...")
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
model = Qwen3TTSForConditionalGeneration.from_pretrained(
MODEL_PATH, config=config, dtype=torch.float32,
attn_implementation="sdpa", device_map="cpu",
)
model.eval()
print(" Model loaded.")
# ββ 2. Create Vocoder Wrapper ββββββββββββββββββββββββββββββββββββββββ
print("\n[2/5] Creating vocoder wrapper...")
# The decoder has dynamic padding calculations that depend on input length.
# With a FIXED input length, these become constants. We wrap the original
# decoder directly and let torch.export trace through the fixed-size logic.
class VocoderForExport(nn.Module):
    """Export-friendly wrapper around the speech tokenizer decoder.

    chunked_decode is bypassed entirely: the wrapped decoder's forward()
    is invoked directly, so torch.export sees one fixed-size code path.

    Input:  codes [1, num_quantizers, code_len] -- int64 codec indices.
    Output: waveform [1, 1, code_len * decode_upsample_rate] -- floats
            in [-1, 1].
    """

    def __init__(self, original_decoder):
        super().__init__()
        # Deep copy keeps the export graph isolated from the live model.
        self.decoder = copy.deepcopy(original_decoder)

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        """Map a codec-index tensor to a float waveform tensor."""
        waveform = self.decoder(codes)
        return waveform
# Wrap a deep copy of the real decoder; all later steps use this copy.
vocoder = VocoderForExport(model.speech_tokenizer.model.decoder)
vocoder.eval()
param_count = sum(p.numel() for p in vocoder.parameters())
print(f" Vocoder parameters: {param_count / 1e6:.1f}M")

# ── 3. Validate ──────────────────────────────────────────────────────
print("\n[3/5] Validating vocoder wrapper...")
# Random codec indices in [0, 2048). NOTE(review): 2048 is assumed to be
# the codebook size — confirm against the tokenizer config.
test_codes = torch.randint(0, 2048, (1, NUM_QUANTIZERS, FIXED_CODE_LEN))
with torch.no_grad():
    # Run the original decoder and the wrapper on identical input.
    orig_wav = model.speech_tokenizer.model.decoder(test_codes)
    wrap_wav = vocoder(test_codes)
print(f" Input codes shape: {list(test_codes.shape)}")
print(f" Original output shape: {list(orig_wav.shape)}")
print(f" Wrapper output shape: {list(wrap_wav.shape)}")
# Cosine similarity over the flattened waveforms; should be ~1.0 since the
# wrapper only forwards to a deep copy of the same weights.
cos_sim = F.cosine_similarity(orig_wav.flatten().unsqueeze(0),
                              wrap_wav.flatten().unsqueeze(0)).item()
max_diff = (orig_wav - wrap_wav).abs().max().item()
print(f" Cosine similarity: {cos_sim:.6f}")
print(f" Max abs difference: {max_diff:.2e}")
assert cos_sim > 0.999, f"Mismatch! cos_sim={cos_sim}"
# Fixed mojibake in this status line ("β" was a garbled dash).
print(" PASS — vocoder validated")
# Samples produced per input code frame (used by the final summary).
upsample_rate = wrap_wav.shape[-1] // FIXED_CODE_LEN
print(f" Upsample rate: {upsample_rate}x")
print(f" Output duration: {wrap_wav.shape[-1] / 24000:.1f}s at 24kHz")
# ── 4. torch.export ──────────────────────────────────────────────────
print("\n[4/5] Running torch.export...")
t0 = time.time()
# Single fixed-shape example; no dynamic_shapes are passed, so the graph
# is specialized to [1, NUM_QUANTIZERS, FIXED_CODE_LEN].
example_input = (test_codes,)
try:
    # strict=False uses non-strict tracing, which tolerates the decoder's
    # data-dependent Python (e.g. its padding arithmetic).
    exported = torch.export.export(
        vocoder,
        example_input,
        strict=False,
    )
    print(f" torch.export succeeded in {time.time() - t0:.1f}s")
    print(f" Graph nodes: {len(exported.graph.nodes)}")
except Exception as e:
    # Export failure is non-fatal: step 5 checks `exported is not None`
    # and falls back to saving the raw state dict.
    print(f" torch.export FAILED: {e}")
    exported = None
# ── 5. Lower to .pte ─────────────────────────────────────────────────
print("\n[5/5] Lowering to ExecuTorch .pte...")
t0 = time.time()
if exported is not None:
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
        # Delegate supported subgraphs to the XNNPACK backend.
        # _check_ir_validity=False skips the edge IR validity check.
        edge = to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()
        pte_path = os.path.join(OUTPUT_DIR, "vocoder.pte")
        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)
        pte_size = os.path.getsize(pte_path) / 1e6
        print(f" .pte saved: {pte_path}")
        print(f" .pte size: {pte_size:.1f} MB")
        print(f" Lowered in {time.time() - t0:.1f}s")
    except Exception as e:
        print(f" ExecuTorch lowering failed: {e}")
        # Fallback: save the exported program so the export isn't lost.
        # (No re-check of `exported` — it is known non-None in this branch;
        # the original's redundant `if exported is not None` was removed.)
        pt2_path = os.path.join(OUTPUT_DIR, "vocoder.pt2")
        torch.export.save(exported, pt2_path)
        print(f" Saved exported program: {pt2_path}")

    # Validate .pte by executing it and comparing against eager PyTorch.
    if os.path.exists(os.path.join(OUTPUT_DIR, "vocoder.pte")):
        print("\n Validating .pte execution...")
        try:
            from executorch.runtime import Runtime
            runtime = Runtime.get()
            # Use a context manager so the file handle is closed promptly
            # (the original leaked it via a bare open(...).read()).
            with open(os.path.join(OUTPUT_DIR, "vocoder.pte"), "rb") as f:
                program = runtime.load_program(f.read())
            method = program.load_method("forward")
            pte_out = method.execute([test_codes])
            # execute() may return a list/tuple of outputs; take the first.
            if isinstance(pte_out, (list, tuple)):
                pte_out = pte_out[0]
            with torch.no_grad():
                ref_out = vocoder(test_codes)
            cos_pte = F.cosine_similarity(
                ref_out.flatten().unsqueeze(0),
                pte_out.flatten().unsqueeze(0)
            ).item()
            print(f" .pte vs PyTorch cosine sim: {cos_pte:.6f}")
        except Exception as e:
            # Best-effort validation: report and continue.
            print(f" .pte validation: {e}")
else:
    print(" No exported program to lower.")
    # Save state dict as fallback so the weights can still be reloaded.
    torch.save(vocoder.state_dict(), os.path.join(OUTPUT_DIR, "vocoder_state_dict.pt"))
    print(f" Saved state dict: {OUTPUT_DIR}/vocoder_state_dict.pt")
print("\n" + "=" * 70)
print("Phase 5 complete!")
print(f" Fixed code length: {FIXED_CODE_LEN} frames")
print(f" Output: {FIXED_CODE_LEN * upsample_rate} samples ({FIXED_CODE_LEN * upsample_rate / 24000:.1f}s)")
print("=" * 70)