#!/usr/bin/env python3
"""
Phase 5: Export Speech Tokenizer Decoder (Vocoder) to ExecuTorch .pte
======================================================================
The vocoder converts codec tokens β†’ audio waveform.
Architecture:
codes [B, 16, T] β†’ VQ decode β†’ [B, codebook_dim, T]
β†’ Conv1d β†’ Transformer (8 layers) β†’ Conv1d
β†’ Upsample (2x, 2x) via ConvTranspose1d + ConvNeXt
β†’ Decoder (8x, 5x, 4x, 3x) via ConvTranspose1d + SnakeBeta + ResBlocks
β†’ Conv1d β†’ waveform [B, 1, T*1920]
Total upsample: 2*2*8*5*4*3 = 3840x (but code downsample is 1920x, so net 1920x)
Wait β€” the decoder forward uses total_upsample which is upsample_rates * upsampling_ratios.
"""
import sys
import os
import copy
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# ── Paths & export constants ────────────────────────────────────────
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Make the project venv and source tree importable without installation.
# Insert order matters: QWEN_TTS_SRC is inserted last, so it ends up in
# front of VENV_SITE on sys.path (same order as the original code).
for _extra_path in (VENV_SITE, QWEN_TTS_SRC):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Fixed code length for export (50 frames ≈ 4 seconds of audio)
FIXED_CODE_LEN = 50
NUM_QUANTIZERS = 16

banner = "=" * 70
print(banner)
print("PHASE 5: Export Vocoder (Speech Tokenizer Decoder) β†’ .pte")
print(banner)
# ── 1. Load Model ───────────────────────────────────────────────────
print("\n[1/5] Loading model...")
# Project-local imports: these resolve only because QWEN_TTS_SRC / the venv
# site-packages were pushed onto sys.path above.
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

# Load the full TTS model on CPU in float32 — the torch.export / XNNPACK
# lowering below works on fp32 CPU tensors. "sdpa" selects PyTorch's
# scaled-dot-product-attention implementation for tracing.
config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH, config=config, dtype=torch.float32,
    attn_implementation="sdpa", device_map="cpu",
)
model.eval()  # inference mode: disables dropout etc. for deterministic export
print(" Model loaded.")

# ── 2. Create Vocoder Wrapper ────────────────────────────────────────
print("\n[2/5] Creating vocoder wrapper...")
# The decoder has dynamic padding calculations that depend on input length.
# With a FIXED input length, these become constants. We wrap the original
# decoder directly and let torch.export trace through the fixed-size logic.
class VocoderForExport(nn.Module):
    """Export-friendly shell around the speech-tokenizer decoder.

    Holds a deep copy of the decoder and feeds codec indices straight to
    its ``forward`` (no chunked_decode path), so torch.export observes a
    single fixed-shape call.

    Input:  codes [1, num_quantizers, code_len] — int64 codec indices
    Output: waveform [1, 1, code_len * decode_upsample_rate]
    """

    def __init__(self, original_decoder):
        super().__init__()
        # Deep copy so export-time tracing cannot perturb the live model.
        self.decoder = copy.deepcopy(original_decoder)

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode ``codes`` [1, 16, code_len] into a float waveform in [-1, 1]."""
        waveform = self.decoder(codes)
        return waveform
# Instantiate the wrapper (deep-copies the decoder) and freeze it.
vocoder = VocoderForExport(model.speech_tokenizer.model.decoder)
vocoder.eval()
param_count = sum(p.numel() for p in vocoder.parameters())
print(f" Vocoder parameters: {param_count / 1e6:.1f}M")

# ── 3. Validate ─────────────────────────────────────────────────────
print("\n[3/5] Validating vocoder wrapper...")
# Random codec indices with the fixed export shape; dtype is int64 by default.
test_codes = torch.randint(0, 2048, (1, NUM_QUANTIZERS, FIXED_CODE_LEN))
with torch.no_grad():
    reference_wav = model.speech_tokenizer.model.decoder(test_codes)  # ground truth
    wrapper_wav = vocoder(test_codes)                                 # wrapper path
print(f" Input codes shape: {list(test_codes.shape)}")
print(f" Original output shape: {list(reference_wav.shape)}")
print(f" Wrapper output shape: {list(wrapper_wav.shape)}")
# Compare the two waveforms as flat vectors.
similarity = F.cosine_similarity(
    reference_wav.reshape(1, -1), wrapper_wav.reshape(1, -1)
).item()
peak_error = (reference_wav - wrapper_wav).abs().max().item()
print(f" Cosine similarity: {similarity:.6f}")
print(f" Max abs difference: {peak_error:.2e}")
assert similarity > 0.999, f"Mismatch! cos_sim={similarity}"
print(" PASS β€” vocoder validated")
# Samples produced per code frame (used again in the final summary).
upsample_rate = wrapper_wav.shape[-1] // FIXED_CODE_LEN
print(f" Upsample rate: {upsample_rate}x")
print(f" Output duration: {wrapper_wav.shape[-1] / 24000:.1f}s at 24kHz")
# ── 4. torch.export ─────────────────────────────────────────────────
print("\n[4/5] Running torch.export...")
t0 = time.time()
sample_args = (test_codes,)
try:
    # strict=False: non-strict export traces with real tensors, which is
    # more permissive about Python-side control flow in the decoder.
    exported = torch.export.export(vocoder, sample_args, strict=False)
    print(f" torch.export succeeded in {time.time() - t0:.1f}s")
    print(f" Graph nodes: {len(exported.graph.nodes)}")
except Exception as e:
    # Best-effort: record the failure and let the script fall through to
    # the state-dict fallback below.
    print(f" torch.export FAILED: {e}")
    exported = None
# ── 5. Lower to .pte ────────────────────────────────────────────────
print("\n[5/5] Lowering to ExecuTorch .pte...")
t0 = time.time()
if exported is not None:
    try:
        # ExecuTorch imports are local so a missing install only breaks
        # this phase, not the whole script.
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        # Lower the exported program to the Edge dialect and partition
        # supported subgraphs onto the XNNPACK CPU backend.
        edge = to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()
        pte_path = os.path.join(OUTPUT_DIR, "vocoder.pte")
        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)
        pte_size = os.path.getsize(pte_path) / 1e6
        print(f" .pte saved: {pte_path}")
        print(f" .pte size: {pte_size:.1f} MB")
        print(f" Lowered in {time.time() - t0:.1f}s")
    except Exception as e:
        print(f" ExecuTorch lowering failed: {e}")
    # Also persist the raw torch.export program so the export work is
    # recoverable even if lowering failed. (Guard is redundant inside
    # this branch but kept for safety.)
    if exported is not None:
        pt2_path = os.path.join(OUTPUT_DIR, "vocoder.pt2")
        torch.export.save(exported, pt2_path)
        print(f" Saved exported program: {pt2_path}")
    # Validate .pte by running it and comparing against eager PyTorch.
    if os.path.exists(os.path.join(OUTPUT_DIR, "vocoder.pte")):
        print("\n Validating .pte execution...")
        try:
            from executorch.runtime import Runtime
            runtime = Runtime.get()
            # FIX: read the program bytes via a context manager so the
            # file handle is closed deterministically (previously the
            # open() result was never closed).
            pte_file = os.path.join(OUTPUT_DIR, "vocoder.pte")
            with open(pte_file, "rb") as f:
                program_bytes = f.read()
            program = runtime.load_program(program_bytes)
            method = program.load_method("forward")
            # NOTE(review): assumes execute() returns a torch tensor or a
            # list/tuple of them — confirm against the ExecuTorch runtime.
            pte_out = method.execute([test_codes])
            if isinstance(pte_out, (list, tuple)):
                pte_out = pte_out[0]
            with torch.no_grad():
                ref_out = vocoder(test_codes)
            cos_pte = F.cosine_similarity(
                ref_out.flatten().unsqueeze(0),
                pte_out.flatten().unsqueeze(0)
            ).item()
            print(f" .pte vs PyTorch cosine sim: {cos_pte:.6f}")
        except Exception as e:
            print(f" .pte validation: {e}")
else:
    print(" No exported program to lower.")
    # Save state dict as fallback so the weights are still exportable later.
    torch.save(vocoder.state_dict(), os.path.join(OUTPUT_DIR, "vocoder_state_dict.pt"))
    print(f" Saved state dict: {OUTPUT_DIR}/vocoder_state_dict.pt")
print("\n" + "=" * 70)
print("Phase 5 complete!")
print(f" Fixed code length: {FIXED_CODE_LEN} frames")
print(f" Output: {FIXED_CODE_LEN * upsample_rate} samples ({FIXED_CODE_LEN * upsample_rate / 24000:.1f}s)")
print("=" * 70)