| |
| """ |
| Phase 5: Export Speech Tokenizer Decoder (Vocoder) to ExecuTorch .pte |
| ====================================================================== |
| The vocoder converts codec tokens β audio waveform. |
| |
| Architecture: |
| codes [B, 16, T] β VQ decode β [B, codebook_dim, T] |
| β Conv1d β Transformer (8 layers) β Conv1d |
| β Upsample (2x, 2x) via ConvTranspose1d + ConvNeXt |
| β Decoder (8x, 5x, 4x, 3x) via ConvTranspose1d + SnakeBeta + ResBlocks |
| β Conv1d β waveform [B, 1, T*1920] |
| |
| Total upsample: 2*2*8*5*4*3 = 3840x (but code downsample is 1920x, so net 1920x) |
| Wait β the decoder forward uses total_upsample which is upsample_rates * upsampling_ratios. |
| """ |
|
|
| import sys |
| import os |
| import copy |
| import time |
| import math |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
# --- Paths & constants -------------------------------------------------------
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Make the project venv and source tree importable even when this script is
# run with a different interpreter/environment.
if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# torch.export captures static shapes: export for a fixed number of codec
# frames per call (chunk length).
FIXED_CODE_LEN = 50
# Number of residual-VQ codebooks in the speech tokenizer (codes dim 1).
NUM_QUANTIZERS = 16

print("=" * 70)
# NOTE: the arrow here was mojibake ("β") in the original; restored to "→".
print("PHASE 5: Export Vocoder (Speech Tokenizer Decoder) → .pte")
print("=" * 70)
|
|
| |
|
|
| print("\n[1/5] Loading model...") |
| from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig |
| from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration |
|
|
| config = Qwen3TTSConfig.from_pretrained(MODEL_PATH) |
| model = Qwen3TTSForConditionalGeneration.from_pretrained( |
| MODEL_PATH, config=config, dtype=torch.float32, |
| attn_implementation="sdpa", device_map="cpu", |
| ) |
| model.eval() |
| print(" Model loaded.") |
|
|
| |
|
|
| print("\n[2/5] Creating vocoder wrapper...") |
|
|
| |
| |
| |
|
|
class VocoderForExport(nn.Module):
    """Export-friendly wrapper around the speech tokenizer decoder.

    Bypasses chunked_decode and invokes the decoder's forward() directly,
    which keeps the traced graph free of Python-side chunking logic.

    Input:  codes [1, num_quantizers, code_len] — int64 codec indices
    Output: waveform [1, 1, code_len * decode_upsample_rate]
    """

    def __init__(self, original_decoder):
        super().__init__()
        # Deep-copy so export-time transformations can never mutate the
        # decoder inside the live model.
        self.decoder = copy.deepcopy(original_decoder)

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode codec indices [1, 16, T] into a float waveform in [-1, 1]."""
        waveform = self.decoder(codes)
        return waveform
|
|
|
|
# Wrap the live decoder (deep-copied inside the wrapper) for export.
vocoder = VocoderForExport(model.speech_tokenizer.model.decoder)
vocoder.eval()

param_count = sum(p.numel() for p in vocoder.parameters())
print(f" Vocoder parameters: {param_count / 1e6:.1f}M")

print("\n[3/5] Validating vocoder wrapper...")

# Random codec indices in [0, 2048) — assumes codebook size 2048; TODO confirm
# against the tokenizer config.
test_codes = torch.randint(0, 2048, (1, NUM_QUANTIZERS, FIXED_CODE_LEN))

with torch.no_grad():
    # Reference: the original decoder inside the loaded model.
    orig_wav = model.speech_tokenizer.model.decoder(test_codes)
    # Candidate: the deep-copied wrapper that will actually be exported.
    wrap_wav = vocoder(test_codes)

print(f" Input codes shape: {list(test_codes.shape)}")
print(f" Original output shape: {list(orig_wav.shape)}")
print(f" Wrapper output shape: {list(wrap_wav.shape)}")

cos_sim = F.cosine_similarity(orig_wav.flatten().unsqueeze(0),
                              wrap_wav.flatten().unsqueeze(0)).item()
max_diff = (orig_wav - wrap_wav).abs().max().item()
print(f" Cosine similarity: {cos_sim:.6f}")
print(f" Max abs difference: {max_diff:.2e}")
# `assert` is stripped under `python -O`; raise explicitly so the validation
# can never be silently skipped.
if not cos_sim > 0.999:
    raise RuntimeError(f"Mismatch! cos_sim={cos_sim}")
# NOTE: the dash here was mojibake ("β") in the original; restored to "—".
print(" PASS — vocoder validated")

# Samples emitted per codec frame (integer ratio by construction).
upsample_rate = wrap_wav.shape[-1] // FIXED_CODE_LEN
print(f" Upsample rate: {upsample_rate}x")
print(f" Output duration: {wrap_wav.shape[-1] / 24000:.1f}s at 24kHz")
|
|
| |
|
|
| print("\n[4/5] Running torch.export...") |
| t0 = time.time() |
|
|
| example_input = (test_codes,) |
|
|
| try: |
| exported = torch.export.export( |
| vocoder, |
| example_input, |
| strict=False, |
| ) |
| print(f" torch.export succeeded in {time.time() - t0:.1f}s") |
| print(f" Graph nodes: {len(exported.graph.nodes)}") |
| except Exception as e: |
| print(f" torch.export FAILED: {e}") |
| exported = None |
|
|
| |
|
|
| print("\n[5/5] Lowering to ExecuTorch .pte...") |
| t0 = time.time() |
|
|
| if exported is not None: |
| try: |
| from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig |
| from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner |
|
|
| edge = to_edge_transform_and_lower( |
| exported, |
| compile_config=EdgeCompileConfig(_check_ir_validity=False), |
| partitioner=[XnnpackPartitioner()], |
| ) |
| et_program = edge.to_executorch() |
|
|
| pte_path = os.path.join(OUTPUT_DIR, "vocoder.pte") |
| with open(pte_path, "wb") as f: |
| f.write(et_program.buffer) |
|
|
| pte_size = os.path.getsize(pte_path) / 1e6 |
| print(f" .pte saved: {pte_path}") |
| print(f" .pte size: {pte_size:.1f} MB") |
| print(f" Lowered in {time.time() - t0:.1f}s") |
|
|
| except Exception as e: |
| print(f" ExecuTorch lowering failed: {e}") |
| if exported is not None: |
| pt2_path = os.path.join(OUTPUT_DIR, "vocoder.pt2") |
| torch.export.save(exported, pt2_path) |
| print(f" Saved exported program: {pt2_path}") |
|
|
| |
| if os.path.exists(os.path.join(OUTPUT_DIR, "vocoder.pte")): |
| print("\n Validating .pte execution...") |
| try: |
| from executorch.runtime import Runtime |
|
|
| runtime = Runtime.get() |
| program = runtime.load_program( |
| open(os.path.join(OUTPUT_DIR, "vocoder.pte"), "rb").read() |
| ) |
| method = program.load_method("forward") |
| pte_out = method.execute([test_codes]) |
| if isinstance(pte_out, (list, tuple)): |
| pte_out = pte_out[0] |
| with torch.no_grad(): |
| ref_out = vocoder(test_codes) |
| cos_pte = F.cosine_similarity( |
| ref_out.flatten().unsqueeze(0), |
| pte_out.flatten().unsqueeze(0) |
| ).item() |
| print(f" .pte vs PyTorch cosine sim: {cos_pte:.6f}") |
| except Exception as e: |
| print(f" .pte validation: {e}") |
| else: |
| print(" No exported program to lower.") |
| |
| torch.save(vocoder.state_dict(), os.path.join(OUTPUT_DIR, "vocoder_state_dict.pt")) |
| print(f" Saved state dict: {OUTPUT_DIR}/vocoder_state_dict.pt") |
|
|
| print("\n" + "=" * 70) |
| print("Phase 5 complete!") |
| print(f" Fixed code length: {FIXED_CODE_LEN} frames") |
| print(f" Output: {FIXED_CODE_LEN * upsample_rate} samples ({FIXED_CODE_LEN * upsample_rate / 24000:.1f}s)") |
| print("=" * 70) |
|
|