#!/usr/bin/env python3
"""
Phase 5: Export Speech Tokenizer Decoder (Vocoder) to ExecuTorch .pte
======================================================================

The vocoder converts codec tokens → audio waveform.

Architecture:
    codes [B, 16, T]
      → VQ decode → [B, codebook_dim, T]
      → Conv1d → Transformer (8 layers) → Conv1d
      → Upsample (2x, 2x) via ConvTranspose1d + ConvNeXt
      → Decoder (8x, 5x, 4x, 3x) via ConvTranspose1d + SnakeBeta + ResBlocks
      → Conv1d → waveform [B, 1, T*1920]

Total upsample: 2 * 2 * 8 * 5 * 4 * 3 = 1920x, which matches the
[B, 1, T*1920] output shape above (the decoder forward uses
total_upsample = upsample_rates * upsampling_ratios).
"""
import sys
import os
import copy
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Make the project venv and sources importable before touching qwen_tts.
if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Fixed code length for export (50 frames ≈ 4 seconds of audio)
FIXED_CODE_LEN = 50
NUM_QUANTIZERS = 16
# Size of each residual-VQ codebook; validation codes are sampled in
# [0, CODEBOOK_SIZE). NOTE(review): assumed from the original hard-coded
# 2048 — confirm against the speech tokenizer config.
CODEBOOK_SIZE = 2048

print("=" * 70)
print("PHASE 5: Export Vocoder (Speech Tokenizer Decoder) → .pte")
print("=" * 70)

# ── 1. Load Model ─────────────────────────────────────────────────────
print("\n[1/5] Loading model...")

# Imported here (after the sys.path inserts above) so the project sources
# are resolvable.
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    config=config,
    dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="cpu",
)
model.eval()
print(" Model loaded.")

# ── 2. Create Vocoder Wrapper ─────────────────────────────────────────
print("\n[2/5] Creating vocoder wrapper...")

# The decoder has dynamic padding calculations that depend on input length.
# With a FIXED input length, these become constants. We wrap the original
# decoder directly and let torch.export trace through the fixed-size logic.


class VocoderForExport(nn.Module):
    """
    Wraps the speech tokenizer decoder for export.
    Bypasses chunked_decode and calls forward() directly.

    Input: codes [1, num_quantizers, code_len] — all int64
    Output: waveform [1, 1, code_len * decode_upsample_rate]
    """

    def __init__(self, original_decoder):
        super().__init__()
        # Deep-copy so export-time transformations cannot mutate the live model.
        self.decoder = copy.deepcopy(original_decoder)

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Args:
            codes: [1, 16, FIXED_CODE_LEN] — LongTensor of codec indices
        Returns:
            waveform: [1, 1, FIXED_CODE_LEN * upsample] — float waveform in [-1, 1]
        """
        return self.decoder(codes)


vocoder = VocoderForExport(model.speech_tokenizer.model.decoder)
vocoder.eval()
param_count = sum(p.numel() for p in vocoder.parameters())
print(f" Vocoder parameters: {param_count / 1e6:.1f}M")

# ── 3. Validate ───────────────────────────────────────────────────────
# Sanity-check that the wrapper produces the same waveform as the decoder
# it wraps before spending time on export/lowering.
print("\n[3/5] Validating vocoder wrapper...")

test_codes = torch.randint(0, CODEBOOK_SIZE, (1, NUM_QUANTIZERS, FIXED_CODE_LEN))
with torch.no_grad():
    # Test original decoder
    orig_wav = model.speech_tokenizer.model.decoder(test_codes)
    # Test wrapper
    wrap_wav = vocoder(test_codes)

print(f" Input codes shape: {list(test_codes.shape)}")
print(f" Original output shape: {list(orig_wav.shape)}")
print(f" Wrapper output shape: {list(wrap_wav.shape)}")

cos_sim = F.cosine_similarity(
    orig_wav.flatten().unsqueeze(0), wrap_wav.flatten().unsqueeze(0)
).item()
max_diff = (orig_wav - wrap_wav).abs().max().item()
print(f" Cosine similarity: {cos_sim:.6f}")
print(f" Max abs difference: {max_diff:.2e}")
assert cos_sim > 0.999, f"Mismatch! cos_sim={cos_sim}"
print(" PASS — vocoder validated")

# Derive the effective samples-per-frame rate from the observed output.
upsample_rate = wrap_wav.shape[-1] // FIXED_CODE_LEN
print(f" Upsample rate: {upsample_rate}x")
print(f" Output duration: {wrap_wav.shape[-1] / 24000:.1f}s at 24kHz")

# ── 4. torch.export ───────────────────────────────────────────────────
print("\n[4/5] Running torch.export...")
t0 = time.time()
example_input = (test_codes,)
try:
    # strict=False: non-strict tracing is sufficient for a fixed input size
    # and tolerates Python-level logic strict mode would reject.
    exported = torch.export.export(
        vocoder,
        example_input,
        strict=False,
    )
    print(f" torch.export succeeded in {time.time() - t0:.1f}s")
    print(f" Graph nodes: {len(exported.graph.nodes)}")
except Exception as e:
    print(f" torch.export FAILED: {e}")
    exported = None

# ── 5. Lower to .pte ──────────────────────────────────────────────────
print("\n[5/5] Lowering to ExecuTorch .pte...")
t0 = time.time()
if exported is not None:
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        edge = to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()
        pte_path = os.path.join(OUTPUT_DIR, "vocoder.pte")
        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)
        pte_size = os.path.getsize(pte_path) / 1e6
        print(f" .pte saved: {pte_path}")
        print(f" .pte size: {pte_size:.1f} MB")
        print(f" Lowered in {time.time() - t0:.1f}s")
    except Exception as e:
        print(f" ExecuTorch lowering failed: {e}")
        # Fallback: persist the traced program so lowering can be retried
        # later without re-tracing. (exported is known non-None here; the
        # original re-checked it redundantly.)
        pt2_path = os.path.join(OUTPUT_DIR, "vocoder.pt2")
        torch.export.save(exported, pt2_path)
        print(f" Saved exported program: {pt2_path}")

    # Validate .pte
    if os.path.exists(os.path.join(OUTPUT_DIR, "vocoder.pte")):
        print("\n Validating .pte execution...")
        try:
            from executorch.runtime import Runtime

            runtime = Runtime.get()
            # Context manager closes the file deterministically (the original
            # leaked the handle via open(...).read()).
            with open(os.path.join(OUTPUT_DIR, "vocoder.pte"), "rb") as f:
                program = runtime.load_program(f.read())
            method = program.load_method("forward")
            pte_out = method.execute([test_codes])
            if isinstance(pte_out, (list, tuple)):
                pte_out = pte_out[0]
            with torch.no_grad():
                ref_out = vocoder(test_codes)
            cos_pte = F.cosine_similarity(
                ref_out.flatten().unsqueeze(0), pte_out.flatten().unsqueeze(0)
            ).item()
            print(f" .pte vs PyTorch cosine sim: {cos_pte:.6f}")
        except Exception as e:
            print(f" .pte validation: {e}")
else:
    print(" No exported program to lower.")
    # Save state dict as fallback
    torch.save(vocoder.state_dict(), os.path.join(OUTPUT_DIR, "vocoder_state_dict.pt"))
    print(f" Saved state dict: {OUTPUT_DIR}/vocoder_state_dict.pt")

print("\n" + "=" * 70)
print("Phase 5 complete!")
print(f" Fixed code length: {FIXED_CODE_LEN} frames")
print(f" Output: {FIXED_CODE_LEN * upsample_rate} samples ({FIXED_CODE_LEN * upsample_rate / 24000:.1f}s)")
print("=" * 70)