acul3 commited on
Commit
40ba644
Β·
verified Β·
1 Parent(s): daf4ae0

Upload scripts/analyze_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/analyze_model.py +408 -0
scripts/analyze_model.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Phase 1: Deep Architecture Analysis of Qwen3-TTS for ExecuTorch Export
4
+ ======================================================================
5
+ Loads the model, maps all modules with parameter counts, traces a real
6
+ voice-clone inference to capture shapes, and identifies export blockers.
7
+ """
8
+
9
+ import sys
10
+ import os
11
+ import time
12
+ import json
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ # ── paths ────────────────────────────────────────────────────────────
18
+ MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
19
+ VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
20
+ QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
21
+
22
+ # Ensure the venv's site-packages is on the path so qwen_tts can be imported
23
+ if VENV_SITE not in sys.path:
24
+ sys.path.insert(0, VENV_SITE)
25
+ if QWEN_TTS_SRC not in sys.path:
26
+ sys.path.insert(0, QWEN_TTS_SRC)
27
+
28
+ # ── helpers ──────────────────────────────────────────────────────────
29
+
30
+ def count_params(module: nn.Module) -> int:
31
+ return sum(p.numel() for p in module.parameters())
32
+
33
+ def fmt(n: int) -> str:
34
+ if n >= 1e9:
35
+ return f"{n / 1e9:.1f}B"
36
+ if n >= 1e6:
37
+ return f"{n / 1e6:.1f}M"
38
+ if n >= 1e3:
39
+ return f"{n / 1e3:.1f}K"
40
+ return str(n)
41
+
42
+ def param_table(module: nn.Module, prefix: str = "", depth: int = 0, max_depth: int = 3):
43
+ """Print a hierarchical parameter table."""
44
+ total = count_params(module)
45
+ indent = " " * depth
46
+ name = prefix or module.__class__.__name__
47
+ print(f"{indent}{name}: {fmt(total)} params")
48
+ if depth < max_depth:
49
+ for child_name, child in module.named_children():
50
+ child_prefix = f"{prefix}.{child_name}" if prefix else child_name
51
+ param_table(child, child_prefix, depth + 1, max_depth)
52
+
53
+
54
+ # ── 1. Load Model ───────────────────────────────────────────────────
55
+
56
+ print("=" * 70)
57
+ print("PHASE 1: Deep Architecture Analysis β€” Qwen3-TTS 1.7B-Base")
58
+ print("=" * 70)
59
+
60
+ print("\n[1/5] Loading model from", MODEL_PATH)
61
+ t0 = time.time()
62
+
63
+ from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
64
+ from qwen_tts.core.models.modeling_qwen3_tts import (
65
+ Qwen3TTSForConditionalGeneration,
66
+ mel_spectrogram,
67
+ )
68
+
69
+ config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
70
+ # Force SDPA attention for exportability
71
+ model = Qwen3TTSForConditionalGeneration.from_pretrained(
72
+ MODEL_PATH,
73
+ config=config,
74
+ torch_dtype=torch.float32,
75
+ attn_implementation="sdpa",
76
+ device_map="cpu",
77
+ )
78
+ model.eval()
79
+ print(f" Loaded in {time.time() - t0:.1f}s")
80
+
81
+ # ── 2. Parameter Map ────────────────────────────────────────────────
82
+
83
+ print("\n[2/5] Parameter Map (hierarchical)")
84
+ print("-" * 60)
85
+
86
+ param_table(model, "Qwen3TTSForConditionalGeneration", max_depth=4)
87
+
88
+ print("\n--- Top-level component sizes ---")
89
+ components = {
90
+ "speaker_encoder": model.speaker_encoder,
91
+ "talker": model.talker,
92
+ "talker.model": model.talker.model,
93
+ "talker.text_projection": model.talker.text_projection,
94
+ "talker.codec_head": model.talker.codec_head,
95
+ "talker.code_predictor": model.talker.code_predictor,
96
+ }
97
+ for name, mod in components.items():
98
+ print(f" {name:40s}: {fmt(count_params(mod)):>8s} params")
99
+
100
+ if model.speech_tokenizer is not None and hasattr(model.speech_tokenizer, 'model'):
101
+ st = model.speech_tokenizer.model # Qwen3TTSTokenizerV2Model (nn.Module)
102
+ print(f" {'speech_tokenizer.model':40s}: {fmt(count_params(st)):>8s} params")
103
+ if hasattr(st, 'encoder'):
104
+ print(f" {'speech_tokenizer.model.encoder':40s}: {fmt(count_params(st.encoder)):>8s} params")
105
+ if hasattr(st, 'decoder'):
106
+ print(f" {'speech_tokenizer.model.decoder':40s}: {fmt(count_params(st.decoder)):>8s} params")
107
+
108
+ # ── 3. Config Summary ───────────────────────────────────────────────
109
+
110
+ print("\n[3/5] Key Config Values")
111
+ print("-" * 60)
112
+
113
+ tc = config.talker_config
114
+ cpc = tc.code_predictor_config
115
+ sec = config.speaker_encoder_config
116
+
117
+ info = {
118
+ "Speaker Encoder": {
119
+ "mel_dim": sec.mel_dim,
120
+ "enc_dim (output)": sec.enc_dim,
121
+ "enc_channels": sec.enc_channels,
122
+ "sample_rate": sec.sample_rate,
123
+ },
124
+ "Talker (Main LM)": {
125
+ "hidden_size": tc.hidden_size,
126
+ "num_hidden_layers": tc.num_hidden_layers,
127
+ "num_attention_heads": tc.num_attention_heads,
128
+ "num_key_value_heads": tc.num_key_value_heads,
129
+ "head_dim": tc.head_dim,
130
+ "intermediate_size": tc.intermediate_size,
131
+ "text_vocab_size": tc.text_vocab_size,
132
+ "codec_vocab_size": tc.vocab_size,
133
+ "num_code_groups": tc.num_code_groups,
134
+ "max_position_embeddings": tc.max_position_embeddings,
135
+ "rope_scaling": tc.rope_scaling,
136
+ },
137
+ "Code Predictor": {
138
+ "hidden_size": cpc.hidden_size,
139
+ "num_hidden_layers": cpc.num_hidden_layers,
140
+ "num_attention_heads": cpc.num_attention_heads,
141
+ "num_key_value_heads": cpc.num_key_value_heads,
142
+ "num_code_groups": cpc.num_code_groups,
143
+ "vocab_size": cpc.vocab_size,
144
+ },
145
+ }
146
+
147
+ for section, kvs in info.items():
148
+ print(f"\n {section}:")
149
+ for k, v in kvs.items():
150
+ print(f" {k:35s}: {v}")
151
+
152
+ # ── 4. Trace Real Inference ─────────────────────────────────────────
153
+
154
+ print("\n[4/5] Tracing Real Voice-Clone Inference")
155
+ print("-" * 60)
156
+
157
+ # Create synthetic reference audio: 3 seconds of white noise at 24kHz
158
+ ref_sr = 24000
159
+ ref_duration = 3.0
160
+ ref_audio = np.random.randn(int(ref_sr * ref_duration)).astype(np.float32) * 0.1
161
+
162
+ # --- 4a. Speaker Encoder ---
163
+ print("\n === Speaker Encoder ===")
164
+ mels = mel_spectrogram(
165
+ torch.from_numpy(ref_audio).unsqueeze(0),
166
+ n_fft=1024,
167
+ num_mels=128,
168
+ sampling_rate=24000,
169
+ hop_size=256,
170
+ win_size=1024,
171
+ fmin=0,
172
+ fmax=12000,
173
+ ).transpose(1, 2)
174
+ print(f" Mel input shape: {list(mels.shape)}") # [1, T, 128]
175
+
176
+ with torch.no_grad():
177
+ spk_embed = model.speaker_encoder(mels)
178
+ print(f" Speaker embedding shape: {list(spk_embed.shape)}") # [1, enc_dim]
179
+ x_vector = spk_embed[0]
180
+ print(f" X-vector (per sample): {list(x_vector.shape)}") # [enc_dim]
181
+
182
+ # --- 4b. Speech Tokenizer Encode (ref audio -> codes) ---
183
+ print("\n === Speech Tokenizer Encode ===")
184
+ if model.speech_tokenizer is not None:
185
+ st_model = model.speech_tokenizer.model
186
+ ref_wav_tensor = torch.from_numpy(ref_audio).unsqueeze(0).float() # [1, samples]
187
+ padding_mask = torch.ones_like(ref_wav_tensor, dtype=torch.long)
188
+ with torch.no_grad():
189
+ enc_out = st_model.encode(ref_wav_tensor, padding_mask=padding_mask, return_dict=True)
190
+ ref_codes = enc_out.audio_codes
191
+ print(f" Ref audio samples: {ref_wav_tensor.shape[1]}")
192
+ print(f" Number of code tensors: {len(ref_codes)}")
193
+ for i, c in enumerate(ref_codes):
194
+ print(f" ref_codes[{i}] shape: {list(c.shape)}") # [T, num_quantizers]
195
+ else:
196
+ print(" Speech tokenizer not loaded (will skip encode)")
197
+ ref_codes = None
198
+
199
+ # --- 4c. Talker Prefill Input Construction ---
200
+ print("\n === Talker Input Construction ===")
201
+
202
+ # Simulate tokenized text: "<|im_start|>assistant\nHello world<|im_end|>\n<|im_start|>assistant\n"
203
+ # Using config token IDs
204
+ from transformers import AutoTokenizer
205
+ try:
206
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
207
+ text = "Hello world."
208
+ chat_text = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
209
+ input_ids = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).input_ids
210
+ print(f" Text input_ids shape: {list(input_ids.shape)}")
211
+ print(f" Text input_ids: {input_ids[0].tolist()[:20]}...")
212
+ except Exception as e:
213
+ print(f" Tokenizer load failed: {e}")
214
+ # Fallback: synthetic token IDs
215
+ input_ids = torch.tensor([[config.im_start_token_id, 77091, 198, 9707, 1879, 13,
216
+ config.im_end_token_id, 198,
217
+ config.im_start_token_id, 77091, 198]])
218
+ print(f" Fallback input_ids shape: {list(input_ids.shape)}")
219
+
220
+ # --- 4d. Talker Key Shapes ---
221
+ print("\n === Talker Architecture Key Shapes ===")
222
+
223
+ talker = model.talker
224
+
225
+ # Text embedding
226
+ text_emb = talker.get_text_embeddings()
227
+ print(f" text_embedding: {text_emb.weight.shape}") # [text_vocab, hidden]
228
+
229
+ # Codec embedding
230
+ codec_emb = talker.get_input_embeddings()
231
+ print(f" codec_embedding: {codec_emb.weight.shape}") # [codec_vocab, hidden]
232
+
233
+ # text_projection (ResizeMLP)
234
+ print(f" text_projection type: {type(talker.text_projection).__name__}")
235
+ with torch.no_grad():
236
+ sample_text_hidden = text_emb(torch.tensor([[0]]))
237
+ proj_out = talker.text_projection(sample_text_hidden)
238
+ print(f" text_projection in/out: {list(sample_text_hidden.shape)} -> {list(proj_out.shape)}")
239
+
240
+ # codec_head
241
+ print(f" codec_head: Linear({talker.codec_head.in_features} -> {talker.codec_head.out_features})")
242
+
243
+ # KV cache dimensions
244
+ num_layers = tc.num_hidden_layers
245
+ num_kv_heads = tc.num_key_value_heads
246
+ head_dim = tc.head_dim
247
+ print(f"\n Static KV cache per layer: 2 x [B, {num_kv_heads}, max_seq_len, {head_dim}]")
248
+ print(f" Total KV layers: {num_layers}")
249
+ print(f" Total KV cache (fp32, B=1, seq=2048): "
250
+ f"{2 * num_layers * num_kv_heads * 2048 * head_dim * 4 / 1e6:.1f} MB")
251
+
252
+ # --- 4e. Code Predictor Key Shapes ---
253
+ print("\n === Code Predictor Key Shapes ===")
254
+ cp = talker.code_predictor
255
+
256
+ print(f" small_to_mtp_projection: {type(cp.small_to_mtp_projection).__name__}")
257
+ if hasattr(cp.small_to_mtp_projection, 'weight'):
258
+ print(f" weight shape: {list(cp.small_to_mtp_projection.weight.shape)}")
259
+
260
+ print(f" lm_heads: {len(cp.lm_head)} heads")
261
+ for i, head in enumerate(cp.lm_head):
262
+ print(f" lm_head[{i}]: Linear({head.in_features} -> {head.out_features})")
263
+
264
+ print(f" codec_embeddings: {len(cp.model.codec_embedding)} embeddings")
265
+ for i, emb in enumerate(cp.model.codec_embedding):
266
+ print(f" codec_embedding[{i}]: {emb.weight.shape}")
267
+
268
+ cp_layers = cpc.num_hidden_layers
269
+ cp_kv_heads = cpc.num_key_value_heads
270
+ cp_head_dim = cpc.head_dim
271
+ print(f"\n Static KV cache per layer: 2 x [B, {cp_kv_heads}, max_seq_len, {cp_head_dim}]")
272
+ print(f" Total KV layers: {cp_layers}")
273
+
274
+ # --- 4f. Speech Tokenizer Decoder Key Shapes ---
275
+ print("\n === Speech Tokenizer Decoder Key Shapes ===")
276
+ if model.speech_tokenizer is not None:
277
+ st_dec = model.speech_tokenizer.model.decoder
278
+ print(f" Decoder type: {type(st_dec).__name__}")
279
+ print(f" Total params: {fmt(count_params(st_dec))}")
280
+
281
+ # Test decode with synthetic codes
282
+ # codes shape: [batch, num_quantizers, seq_len]
283
+ test_codes = torch.randint(0, 2048, (1, 16, 10))
284
+ with torch.no_grad():
285
+ test_wav = st_dec(test_codes)
286
+ print(f" Test input codes: {list(test_codes.shape)}")
287
+ print(f" Test output wav: {list(test_wav.shape)}")
288
+ upsample_factor = test_wav.shape[-1] // test_codes.shape[-1]
289
+ print(f" Upsample factor: {upsample_factor}x")
290
+
291
+ # ── 5. Export Blocker Analysis ───────────────────────────────────────
292
+
293
+ print("\n[5/5] Export Blocker Analysis")
294
+ print("-" * 60)
295
+
296
+ blockers = []
297
+
298
+ # Check speaker encoder
299
+ print("\n === Speaker Encoder Export Blockers ===")
300
+ se_issues = []
301
+ # Conv1d with padding="same" and padding_mode="reflect"
302
+ for name, mod in model.speaker_encoder.named_modules():
303
+ if isinstance(mod, nn.Conv1d):
304
+ if hasattr(mod, 'padding') and mod.padding == 'same':
305
+ se_issues.append(f"Conv1d '{name}' uses padding='same' (dynamic pad calc)")
306
+ if hasattr(mod, 'padding_mode') and mod.padding_mode == 'reflect':
307
+ se_issues.append(f"Conv1d '{name}' uses padding_mode='reflect'")
308
+
309
+ # AttentiveStatisticsPooling dynamic masking
310
+ se_issues.append("AttentiveStatisticsPooling: dynamic _length_to_mask(), .repeat(), masked_fill_")
311
+ se_issues.append("Res2NetBlock: torch.chunk + for loop (but fixed scale=8, should be OK)")
312
+
313
+ for issue in se_issues:
314
+ print(f" [!] {issue}")
315
+ blockers.extend([("speaker_encoder", i) for i in se_issues])
316
+
317
+ # Check talker
318
+ print("\n === Talker Export Blockers ===")
319
+ t_issues = []
320
+ t_issues.append("MROPE: 3D rotary embedding with sections [24,20,20] β€” need custom handling")
321
+ t_issues.append("DynamicCache: must replace with static KV cache tensors")
322
+ t_issues.append("create_causal_mask/create_sliding_window_causal_mask from transformers")
323
+ t_issues.append("Two embedding tables (text + codec) with interleaving logic")
324
+ t_issues.append("code_predictor.generate() called inside forward() β€” autoregressive sub-loop")
325
+ t_issues.append("trailing_text_hidden conditional addition in decode step")
326
+ t_issues.append("@can_return_tuple decorator")
327
+ t_issues.append("@use_kernel_forward_from_hub on RMSNorm")
328
+
329
+ for issue in t_issues:
330
+ print(f" [!] {issue}")
331
+ blockers.extend([("talker", i) for i in t_issues])
332
+
333
+ # Check code predictor
334
+ print("\n === Code Predictor Export Blockers ===")
335
+ cp_issues = []
336
+ cp_issues.append("Uses GenerationMixin.generate() β€” full autoregressive loop")
337
+ cp_issues.append("generation_steps counter used to index into lm_head ModuleList")
338
+ cp_issues.append("DynamicCache")
339
+ cp_issues.append("get_input_embeddings() returns ModuleList (indexed by generation step)")
340
+
341
+ for issue in cp_issues:
342
+ print(f" [!] {issue}")
343
+ blockers.extend([("code_predictor", i) for i in cp_issues])
344
+
345
+ # Check speech tokenizer
346
+ print("\n === Speech Tokenizer Export Blockers ===")
347
+ st_issues = []
348
+ if model.speech_tokenizer is not None:
349
+ st_issues.append("chunked_decode: while loop with dynamic chunk boundaries")
350
+ st_issues.append("ConvTranspose1d with dynamic slicing (right_pad removal)")
351
+ st_issues.append("CausalConv1d: dynamic padding calculation")
352
+ st_issues.append("SnakeBeta: custom activation (should be OK)")
353
+ st_issues.append("SplitResidualVectorQuantizer: F.embedding based (OK)")
354
+ st_issues.append("Transformer decoder with @dynamic_rope_update and torch.autocast")
355
+ st_issues.append("Sliding window attention (window=72)")
356
+
357
+ for issue in st_issues:
358
+ print(f" [!] {issue}")
359
+ blockers.extend([("speech_tokenizer", i) for i in st_issues])
360
+
361
+ # ── Summary ───────────────────────────────────────────────────────��──
362
+
363
+ print("\n" + "=" * 70)
364
+ print("SUMMARY")
365
+ print("=" * 70)
366
+
367
+ print(f"""
368
+ Model: Qwen3TTSForConditionalGeneration (1.7B-Base)
369
+ Total params: {fmt(count_params(model))}
370
+
371
+ Export Targets (4 modules):
372
+ 1. Speaker Encoder ({fmt(count_params(model.speaker_encoder))} params) β€” ECAPA-TDNN
373
+ 2. Talker (Main LM) ({fmt(count_params(model.talker.model))} + heads) β€” Qwen3 28L
374
+ 3. Code Predictor ({fmt(count_params(model.talker.code_predictor))} params) β€” 5L transformer
375
+ 4. Speech Tokenizer Dec ({fmt(count_params(model.speech_tokenizer.model.decoder)) if model.speech_tokenizer else 'N/A'} params) β€” Transformer + ConvTranspose
376
+
377
+ Voice Clone Pipeline:
378
+ ref_audio (24kHz)
379
+ -> mel_spectrogram -> [B, T, 128]
380
+ -> speaker_encoder -> x_vector [B, {sec.enc_dim}]
381
+
382
+ ref_audio -> speech_tokenizer.encode -> ref_codes [T, 16]
383
+
384
+ text -> tokenizer -> input_ids
385
+
386
+ [x_vector, ref_codes, input_ids]
387
+ -> talker.generate() -> codec_tokens [T', 16]
388
+ (internally calls code_predictor.generate() per step)
389
+
390
+ codec_tokens -> speech_tokenizer.decode -> PCM waveform
391
+
392
+ Key Dimensions:
393
+ Talker: hidden=2048, layers=28, heads=16, kv_heads=8, head_dim=128
394
+ Code Predictor: hidden=1024, layers=5, heads=16, kv_heads=8
395
+ Codec: vocab=3072 (talker), 2048 (code_predictor), 16 code groups
396
+ Speaker: enc_dim={sec.enc_dim}
397
+
398
+ Export Strategy:
399
+ Phase 2: Speaker encoder β€” fixed mel length, handle Conv1d padding
400
+ Phase 3: Talker β€” static KV cache, unrolled MROPE, separate prefill/decode
401
+ Phase 4: Code predictor β€” static KV, unroll 15-step generation
402
+ Phase 5: Vocoder (decoder only) β€” fixed code length, handle ConvTranspose1d
403
+ Phase 6: INT8 via torchao int8_weight_only (instant, no calibration)
404
+
405
+ Total export blockers found: {len(blockers)}
406
+ """)
407
+
408
+ print("Phase 1 analysis complete!")