File size: 8,016 Bytes
b3779d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
#!/usr/bin/env python3
"""
Phase 5: Export Speech Tokenizer Decoder (Vocoder) to ExecuTorch .pte
======================================================================
The vocoder converts codec tokens β†’ audio waveform.

Architecture:
  codes [B, 16, T] β†’ VQ decode β†’ [B, codebook_dim, T]
    β†’ Conv1d β†’ Transformer (8 layers) β†’ Conv1d
    β†’ Upsample (2x, 2x) via ConvTranspose1d + ConvNeXt
    β†’ Decoder (8x, 5x, 4x, 3x) via ConvTranspose1d + SnakeBeta + ResBlocks
    β†’ Conv1d β†’ waveform [B, 1, T*1920]

Total upsample: 2*2*8*5*4*3 = 1920x, so each code frame maps to 1920 output
samples (the decoder forward's total_upsample = upsample_rates * upsampling_ratios).
"""

import sys
import os
import copy
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Fixed code length for export (50 frames β‰ˆ 4 seconds of audio)
FIXED_CODE_LEN = 50
NUM_QUANTIZERS = 16

print("=" * 70)
print("PHASE 5: Export Vocoder (Speech Tokenizer Decoder) β†’ .pte")
print("=" * 70)

# ── 1. Load Model ───────────────────────────────────────────────────

print("\n[1/5] Loading model...")
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

# Load the whole TTS checkpoint on CPU in fp32; only the speech-tokenizer
# decoder is used downstream, but loading via from_pretrained keeps the
# original weight mapping intact.
tts_config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    config=tts_config,
    dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="cpu",
)
model.eval()
print("  Model loaded.")

# ── 2. Create Vocoder Wrapper ────────────────────────────────────────

print("\n[2/5] Creating vocoder wrapper...")

# The decoder has dynamic padding calculations that depend on input length.
# With a FIXED input length, these become constants. We wrap the original
# decoder directly and let torch.export trace through the fixed-size logic.

class VocoderForExport(nn.Module):
    """
    Export-friendly wrapper around the speech tokenizer decoder.

    Bypasses chunked_decode and invokes the decoder's forward() directly,
    so a fixed-length code tensor traces to a static graph.

    Input:  codes [1, num_quantizers, code_len] — all int64
    Output: waveform [1, 1, code_len * decode_upsample_rate]
    """

    def __init__(self, original_decoder):
        super().__init__()
        # Deep-copy so export-time graph mutations cannot touch the
        # live model's decoder.
        self.decoder = copy.deepcopy(original_decoder)

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Args:
            codes: [1, 16, FIXED_CODE_LEN] — LongTensor of codec indices
        Returns:
            waveform: [1, 1, FIXED_CODE_LEN * upsample] — float waveform in [-1, 1]
        """
        waveform = self.decoder(codes)
        return waveform


vocoder = VocoderForExport(model.speech_tokenizer.model.decoder)
vocoder.eval()

param_count = sum(p.numel() for p in vocoder.parameters())
print(f"  Vocoder parameters: {param_count / 1e6:.1f}M")

# ── 3. Validate ─────────────────────────────────────────────────────

print("\n[3/5] Validating vocoder wrapper...")

# Random codec indices; 2048 is presumed to be the codebook size — TODO
# confirm against the speech tokenizer config.
test_codes = torch.randint(0, 2048, (1, NUM_QUANTIZERS, FIXED_CODE_LEN))

with torch.no_grad():
    # Run the live decoder and the deep-copied wrapper on the same codes;
    # outputs must match since the wrapper is a pure pass-through.
    orig_wav = model.speech_tokenizer.model.decoder(test_codes)
    wrap_wav = vocoder(test_codes)

print(f"  Input codes shape:    {list(test_codes.shape)}")
print(f"  Original output shape: {list(orig_wav.shape)}")
print(f"  Wrapper output shape:  {list(wrap_wav.shape)}")

cos_sim = F.cosine_similarity(orig_wav.flatten().unsqueeze(0),
                               wrap_wav.flatten().unsqueeze(0)).item()
max_diff = (orig_wav - wrap_wav).abs().max().item()
print(f"  Cosine similarity:    {cos_sim:.6f}")
print(f"  Max abs difference:   {max_diff:.2e}")
# `assert` statements are stripped under `python -O`; raise explicitly so
# this validation can never be silently skipped.
if not cos_sim > 0.999:
    raise RuntimeError(f"Mismatch! cos_sim={cos_sim}")
print("  PASS — vocoder validated")

# Samples emitted per code frame; reused by the final summary below.
upsample_rate = wrap_wav.shape[-1] // FIXED_CODE_LEN
print(f"  Upsample rate: {upsample_rate}x")
print(f"  Output duration: {wrap_wav.shape[-1] / 24000:.1f}s at 24kHz")

# ── 4. torch.export ─────────────────────────────────────────────────

print("\n[4/5] Running torch.export...")
t0 = time.time()

example_input = (test_codes,)

# Non-strict mode traces through the decoder's padding arithmetic, which
# is constant at this fixed input size.
try:
    exported = torch.export.export(vocoder, example_input, strict=False)
    elapsed = time.time() - t0
    print(f"  torch.export succeeded in {elapsed:.1f}s")
    print(f"  Graph nodes: {len(exported.graph.nodes)}")
except Exception as e:
    # Keep going — the lowering stage below handles the None case.
    print(f"  torch.export FAILED: {e}")
    exported = None

# ── 5. Lower to .pte ────────────────────────────────────────────────

print("\n[5/5] Lowering to ExecuTorch .pte...")
t0 = time.time()

if exported is not None:
    pte_path = os.path.join(OUTPUT_DIR, "vocoder.pte")
    lowering_ok = False
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        # Partition ops onto the XNNPACK backend; IR validity checks are
        # disabled as in the original flow.
        edge = to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()

        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)
        lowering_ok = True

        pte_size = os.path.getsize(pte_path) / 1e6
        print(f"  .pte saved: {pte_path}")
        print(f"  .pte size:  {pte_size:.1f} MB")
        print(f"  Lowered in {time.time() - t0:.1f}s")

    except Exception as e:
        print(f"  ExecuTorch lowering failed: {e}")
        # Fallback: persist the torch.export program so the trace isn't
        # lost. (`exported` is known non-None in this branch; the original
        # re-checked it redundantly.)
        pt2_path = os.path.join(OUTPUT_DIR, "vocoder.pt2")
        torch.export.save(exported, pt2_path)
        print(f"  Saved exported program: {pt2_path}")

    # Validate only a .pte produced by THIS run; a stale file left by an
    # earlier run would otherwise yield a misleading pass.
    if lowering_ok and os.path.exists(pte_path):
        print("\n  Validating .pte execution...")
        try:
            from executorch.runtime import Runtime

            runtime = Runtime.get()
            # Context manager closes the file handle deterministically
            # (the original leaked it via a bare open().read()).
            with open(pte_path, "rb") as f:
                program = runtime.load_program(f.read())
            method = program.load_method("forward")
            pte_out = method.execute([test_codes])
            if isinstance(pte_out, (list, tuple)):
                pte_out = pte_out[0]
            with torch.no_grad():
                ref_out = vocoder(test_codes)
            cos_pte = F.cosine_similarity(
                ref_out.flatten().unsqueeze(0),
                pte_out.flatten().unsqueeze(0)
            ).item()
            print(f"  .pte vs PyTorch cosine sim: {cos_pte:.6f}")
        except Exception as e:
            print(f"  .pte validation: {e}")
else:
    print("  No exported program to lower.")
    # Save state dict as fallback so the wrapped weights survive a failed
    # export and can be retried without reloading the checkpoint.
    torch.save(vocoder.state_dict(), os.path.join(OUTPUT_DIR, "vocoder_state_dict.pt"))
    print(f"  Saved state dict: {OUTPUT_DIR}/vocoder_state_dict.pt")

print("\n" + "=" * 70)
print("Phase 5 complete!")
print(f"  Fixed code length: {FIXED_CODE_LEN} frames")
print(f"  Output: {FIXED_CODE_LEN * upsample_rate} samples ({FIXED_CODE_LEN * upsample_rate / 24000:.1f}s)")
print("=" * 70)