#!/usr/bin/env python3
"""
Phase 1: Deep Architecture Analysis of Qwen3-TTS for ExecuTorch Export
======================================================================
Loads the model, maps all modules with parameter counts, traces a real
voice-clone inference to capture shapes, and identifies export blockers.
"""

import sys
import os
import time
import json
import numpy as np
import torch
import torch.nn as nn

# ── paths ────────────────────────────────────────────────────────────
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")

# Ensure the venv's site-packages is on the path so qwen_tts can be imported
if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)

# ── helpers ──────────────────────────────────────────────────────────

def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

def fmt(n: int) -> str:
    if n >= 1e9:
        return f"{n / 1e9:.1f}B"
    if n >= 1e6:
        return f"{n / 1e6:.1f}M"
    if n >= 1e3:
        return f"{n / 1e3:.1f}K"
    return str(n)

def param_table(module: nn.Module, prefix: str = "", depth: int = 0, max_depth: int = 3):
    """Print a hierarchical parameter table."""
    total = count_params(module)
    indent = "  " * depth
    name = prefix or module.__class__.__name__
    print(f"{indent}{name}: {fmt(total)} params")
    if depth < max_depth:
        for child_name, child in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            param_table(child, child_prefix, depth + 1, max_depth)


# ── 1. Load Model ───────────────────────────────────────────────────

print("=" * 70)
print("PHASE 1: Deep Architecture Analysis β€” Qwen3-TTS 1.7B-Base")
print("=" * 70)

print("\n[1/5] Loading model from", MODEL_PATH)
t0 = time.time()

from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import (
    Qwen3TTSForConditionalGeneration,
    mel_spectrogram,
)

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
# Force SDPA attention for exportability
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    config=config,
    torch_dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="cpu",
)
model.eval()
print(f"  Loaded in {time.time() - t0:.1f}s")

# ── 2. Parameter Map ────────────────────────────────────────────────

print("\n[2/5] Parameter Map (hierarchical)")
print("-" * 60)

param_table(model, "Qwen3TTSForConditionalGeneration", max_depth=4)

print("\n--- Top-level component sizes ---")
components = {
    "speaker_encoder": model.speaker_encoder,
    "talker": model.talker,
    "talker.model": model.talker.model,
    "talker.text_projection": model.talker.text_projection,
    "talker.codec_head": model.talker.codec_head,
    "talker.code_predictor": model.talker.code_predictor,
}
for name, mod in components.items():
    print(f"  {name:40s}: {fmt(count_params(mod)):>8s} params")

if model.speech_tokenizer is not None and hasattr(model.speech_tokenizer, 'model'):
    st = model.speech_tokenizer.model  # Qwen3TTSTokenizerV2Model (nn.Module)
    print(f"  {'speech_tokenizer.model':40s}: {fmt(count_params(st)):>8s} params")
    if hasattr(st, 'encoder'):
        print(f"  {'speech_tokenizer.model.encoder':40s}: {fmt(count_params(st.encoder)):>8s} params")
    if hasattr(st, 'decoder'):
        print(f"  {'speech_tokenizer.model.decoder':40s}: {fmt(count_params(st.decoder)):>8s} params")

# ── 3. Config Summary ───────────────────────────────────────────────

print("\n[3/5] Key Config Values")
print("-" * 60)

tc = config.talker_config
cpc = tc.code_predictor_config
sec = config.speaker_encoder_config

info = {
    "Speaker Encoder": {
        "mel_dim": sec.mel_dim,
        "enc_dim (output)": sec.enc_dim,
        "enc_channels": sec.enc_channels,
        "sample_rate": sec.sample_rate,
    },
    "Talker (Main LM)": {
        "hidden_size": tc.hidden_size,
        "num_hidden_layers": tc.num_hidden_layers,
        "num_attention_heads": tc.num_attention_heads,
        "num_key_value_heads": tc.num_key_value_heads,
        "head_dim": tc.head_dim,
        "intermediate_size": tc.intermediate_size,
        "text_vocab_size": tc.text_vocab_size,
        "codec_vocab_size": tc.vocab_size,
        "num_code_groups": tc.num_code_groups,
        "max_position_embeddings": tc.max_position_embeddings,
        "rope_scaling": tc.rope_scaling,
    },
    "Code Predictor": {
        "hidden_size": cpc.hidden_size,
        "num_hidden_layers": cpc.num_hidden_layers,
        "num_attention_heads": cpc.num_attention_heads,
        "num_key_value_heads": cpc.num_key_value_heads,
        "num_code_groups": cpc.num_code_groups,
        "vocab_size": cpc.vocab_size,
    },
}

for section, kvs in info.items():
    print(f"\n  {section}:")
    for k, v in kvs.items():
        print(f"    {k:35s}: {v}")

# ── 4. Trace Real Inference ─────────────────────────────────────────

print("\n[4/5] Tracing Real Voice-Clone Inference")
print("-" * 60)

# Create synthetic reference audio: 3 seconds of white noise at 24kHz
ref_sr = 24000
ref_duration = 3.0
ref_audio = np.random.randn(int(ref_sr * ref_duration)).astype(np.float32) * 0.1

# --- 4a. Speaker Encoder ---
print("\n  === Speaker Encoder ===")
mels = mel_spectrogram(
    torch.from_numpy(ref_audio).unsqueeze(0),
    n_fft=1024,
    num_mels=128,
    sampling_rate=24000,
    hop_size=256,
    win_size=1024,
    fmin=0,
    fmax=12000,
).transpose(1, 2)
print(f"  Mel input shape:        {list(mels.shape)}")  # [1, T, 128]

with torch.no_grad():
    spk_embed = model.speaker_encoder(mels)
print(f"  Speaker embedding shape: {list(spk_embed.shape)}")  # [1, enc_dim]
x_vector = spk_embed[0]
print(f"  X-vector (per sample):   {list(x_vector.shape)}")  # [enc_dim]

# --- 4b. Speech Tokenizer Encode (ref audio -> codes) ---
print("\n  === Speech Tokenizer Encode ===")
if model.speech_tokenizer is not None:
    st_model = model.speech_tokenizer.model
    ref_wav_tensor = torch.from_numpy(ref_audio).unsqueeze(0).float()  # [1, samples]
    padding_mask = torch.ones_like(ref_wav_tensor, dtype=torch.long)
    with torch.no_grad():
        enc_out = st_model.encode(ref_wav_tensor, padding_mask=padding_mask, return_dict=True)
    ref_codes = enc_out.audio_codes
    print(f"  Ref audio samples:      {ref_wav_tensor.shape[1]}")
    print(f"  Number of code tensors: {len(ref_codes)}")
    for i, c in enumerate(ref_codes):
        print(f"  ref_codes[{i}] shape:     {list(c.shape)}")  # [T, num_quantizers]
else:
    print("  Speech tokenizer not loaded (will skip encode)")
    ref_codes = None

# --- 4c. Talker Prefill Input Construction ---
print("\n  === Talker Input Construction ===")

# Simulate tokenized text: "<|im_start|>assistant\nHello world<|im_end|>\n<|im_start|>assistant\n"
# Using config token IDs
from transformers import AutoTokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    text = "Hello world."
    chat_text = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
    input_ids = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).input_ids
    print(f"  Text input_ids shape:   {list(input_ids.shape)}")
    print(f"  Text input_ids:         {input_ids[0].tolist()[:20]}...")
except Exception as e:
    print(f"  Tokenizer load failed: {e}")
    # Fallback: synthetic token IDs
    input_ids = torch.tensor([[config.im_start_token_id, 77091, 198, 9707, 1879, 13,
                               config.im_end_token_id, 198,
                               config.im_start_token_id, 77091, 198]])
    print(f"  Fallback input_ids shape: {list(input_ids.shape)}")

# --- 4d. Talker Key Shapes ---
print("\n  === Talker Architecture Key Shapes ===")

talker = model.talker

# Text embedding
text_emb = talker.get_text_embeddings()
print(f"  text_embedding:         {text_emb.weight.shape}")  # [text_vocab, hidden]

# Codec embedding
codec_emb = talker.get_input_embeddings()
print(f"  codec_embedding:        {codec_emb.weight.shape}")  # [codec_vocab, hidden]

# text_projection (ResizeMLP)
print(f"  text_projection type:   {type(talker.text_projection).__name__}")
with torch.no_grad():
    sample_text_hidden = text_emb(torch.tensor([[0]]))
    proj_out = talker.text_projection(sample_text_hidden)
print(f"  text_projection in/out: {list(sample_text_hidden.shape)} -> {list(proj_out.shape)}")

# codec_head
print(f"  codec_head:             Linear({talker.codec_head.in_features} -> {talker.codec_head.out_features})")

# KV cache dimensions
num_layers = tc.num_hidden_layers
num_kv_heads = tc.num_key_value_heads
head_dim = tc.head_dim
print(f"\n  Static KV cache per layer: 2 x [B, {num_kv_heads}, max_seq_len, {head_dim}]")
print(f"  Total KV layers:        {num_layers}")
print(f"  Total KV cache (fp32, B=1, seq=2048): "
      f"{2 * num_layers * num_kv_heads * 2048 * head_dim * 4 / 1e6:.1f} MB")

# --- 4e. Code Predictor Key Shapes ---
print("\n  === Code Predictor Key Shapes ===")
cp = talker.code_predictor

print(f"  small_to_mtp_projection: {type(cp.small_to_mtp_projection).__name__}")
if hasattr(cp.small_to_mtp_projection, 'weight'):
    print(f"    weight shape:         {list(cp.small_to_mtp_projection.weight.shape)}")

print(f"  lm_heads:               {len(cp.lm_head)} heads")
for i, head in enumerate(cp.lm_head):
    print(f"    lm_head[{i}]:           Linear({head.in_features} -> {head.out_features})")

print(f"  codec_embeddings:       {len(cp.model.codec_embedding)} embeddings")
for i, emb in enumerate(cp.model.codec_embedding):
    print(f"    codec_embedding[{i}]:   {emb.weight.shape}")

cp_layers = cpc.num_hidden_layers
cp_kv_heads = cpc.num_key_value_heads
cp_head_dim = cpc.head_dim
print(f"\n  Static KV cache per layer: 2 x [B, {cp_kv_heads}, max_seq_len, {cp_head_dim}]")
print(f"  Total KV layers:        {cp_layers}")

# --- 4f. Speech Tokenizer Decoder Key Shapes ---
print("\n  === Speech Tokenizer Decoder Key Shapes ===")
if model.speech_tokenizer is not None:
    st_dec = model.speech_tokenizer.model.decoder
    print(f"  Decoder type:           {type(st_dec).__name__}")
    print(f"  Total params:           {fmt(count_params(st_dec))}")

    # Test decode with synthetic codes
    # codes shape: [batch, num_quantizers, seq_len]
    test_codes = torch.randint(0, 2048, (1, 16, 10))
    with torch.no_grad():
        test_wav = st_dec(test_codes)
    print(f"  Test input codes:       {list(test_codes.shape)}")
    print(f"  Test output wav:        {list(test_wav.shape)}")
    upsample_factor = test_wav.shape[-1] // test_codes.shape[-1]
    print(f"  Upsample factor:        {upsample_factor}x")

# ── 5. Export Blocker Analysis ───────────────────────────────────────

print("\n[5/5] Export Blocker Analysis")
print("-" * 60)

blockers = []

# Check speaker encoder
print("\n  === Speaker Encoder Export Blockers ===")
se_issues = []
# Conv1d with padding="same" and padding_mode="reflect"
for name, mod in model.speaker_encoder.named_modules():
    if isinstance(mod, nn.Conv1d):
        if hasattr(mod, 'padding') and mod.padding == 'same':
            se_issues.append(f"Conv1d '{name}' uses padding='same' (dynamic pad calc)")
        if hasattr(mod, 'padding_mode') and mod.padding_mode == 'reflect':
            se_issues.append(f"Conv1d '{name}' uses padding_mode='reflect'")

# AttentiveStatisticsPooling dynamic masking
se_issues.append("AttentiveStatisticsPooling: dynamic _length_to_mask(), .repeat(), masked_fill_")
se_issues.append("Res2NetBlock: torch.chunk + for loop (but fixed scale=8, should be OK)")

for issue in se_issues:
    print(f"  [!] {issue}")
blockers.extend([("speaker_encoder", i) for i in se_issues])

# Check talker
print("\n  === Talker Export Blockers ===")
t_issues = []
t_issues.append("MROPE: 3D rotary embedding with sections [24,20,20] β€” need custom handling")
t_issues.append("DynamicCache: must replace with static KV cache tensors")
t_issues.append("create_causal_mask/create_sliding_window_causal_mask from transformers")
t_issues.append("Two embedding tables (text + codec) with interleaving logic")
t_issues.append("code_predictor.generate() called inside forward() β€” autoregressive sub-loop")
t_issues.append("trailing_text_hidden conditional addition in decode step")
t_issues.append("@can_return_tuple decorator")
t_issues.append("@use_kernel_forward_from_hub on RMSNorm")

for issue in t_issues:
    print(f"  [!] {issue}")
blockers.extend([("talker", i) for i in t_issues])

# Check code predictor
print("\n  === Code Predictor Export Blockers ===")
cp_issues = []
cp_issues.append("Uses GenerationMixin.generate() β€” full autoregressive loop")
cp_issues.append("generation_steps counter used to index into lm_head ModuleList")
cp_issues.append("DynamicCache")
cp_issues.append("get_input_embeddings() returns ModuleList (indexed by generation step)")

for issue in cp_issues:
    print(f"  [!] {issue}")
blockers.extend([("code_predictor", i) for i in cp_issues])

# Check speech tokenizer
print("\n  === Speech Tokenizer Export Blockers ===")
st_issues = []
if model.speech_tokenizer is not None:
    st_issues.append("chunked_decode: while loop with dynamic chunk boundaries")
    st_issues.append("ConvTranspose1d with dynamic slicing (right_pad removal)")
    st_issues.append("CausalConv1d: dynamic padding calculation")
    st_issues.append("SnakeBeta: custom activation (should be OK)")
    st_issues.append("SplitResidualVectorQuantizer: F.embedding based (OK)")
    st_issues.append("Transformer decoder with @dynamic_rope_update and torch.autocast")
    st_issues.append("Sliding window attention (window=72)")

for issue in st_issues:
    print(f"  [!] {issue}")
blockers.extend([("speech_tokenizer", i) for i in st_issues])
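
# The most load-bearing fix above is replacing DynamicCache with static
# buffers. A minimal numpy sketch of the shape ExecuTorch expects (the real
# export would use torch tensors with in-place index writes; all names here
# are illustrative, not qwen_tts API):
def make_static_kv(layers, kv_heads, max_seq, head_dim, batch=1):
    # one fixed-shape (K, V) buffer pair per layer, zero-filled up front
    shape = (batch, kv_heads, max_seq, head_dim)
    return [(np.zeros(shape, np.float32), np.zeros(shape, np.float32))
            for _ in range(layers)]

def write_kv_step(cache, layer, pos, k_step, v_step):
    # k_step / v_step: [batch, kv_heads, 1, head_dim] for one decode token;
    # writing in place keeps every tensor shape constant across steps,
    # which is what makes the graph exportable
    k_buf, v_buf = cache[layer]
    k_buf[:, :, pos:pos + 1, :] = k_step
    v_buf[:, :, pos:pos + 1, :] = v_step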

# ── Summary ──────────────────────────────────────────────────────────

print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)

print(f"""
Model: Qwen3TTSForConditionalGeneration (1.7B-Base)
Total params: {fmt(count_params(model))}

Export Targets (4 modules):
  1. Speaker Encoder        ({fmt(count_params(model.speaker_encoder))} params) — ECAPA-TDNN
  2. Talker (Main LM)       ({fmt(count_params(model.talker.model))} + heads) — Qwen3 28L
  3. Code Predictor         ({fmt(count_params(model.talker.code_predictor))} params) — 5L transformer
  4. Speech Tokenizer Dec   ({fmt(count_params(model.speech_tokenizer.model.decoder)) if model.speech_tokenizer else 'N/A'} params) — Transformer + ConvTranspose

Voice Clone Pipeline:
  ref_audio (24kHz)
    -> mel_spectrogram -> [B, T, 128]
    -> speaker_encoder -> x_vector [B, {sec.enc_dim}]

  ref_audio -> speech_tokenizer.encode -> ref_codes [T, 16]

  text -> tokenizer -> input_ids

  [x_vector, ref_codes, input_ids]
    -> talker.generate() -> codec_tokens [T', 16]
    (internally calls code_predictor.generate() per step)

  codec_tokens -> speech_tokenizer.decode -> PCM waveform

Key Dimensions:
  Talker: hidden=2048, layers=28, heads=16, kv_heads=8, head_dim=128
  Code Predictor: hidden=1024, layers=5, heads=16, kv_heads=8
  Codec: vocab=3072 (talker), 2048 (code_predictor), 16 code groups
  Speaker: enc_dim={sec.enc_dim}

Export Strategy:
  Phase 2: Speaker encoder — fixed mel length, handle Conv1d padding
  Phase 3: Talker — static KV cache, unrolled MROPE, separate prefill/decode
  Phase 4: Code predictor — static KV, unroll 15-step generation
  Phase 5: Vocoder (decoder only) — fixed code length, handle ConvTranspose1d
  Phase 6: INT8 via torchao int8_weight_only (instant, no calibration)

Total export blockers found: {len(blockers)}
""")

print("Phase 1 analysis complete!")