# Source: acul3 / Hugging Face upload of scripts/analyze_model.py (commit 40ba644)
#!/usr/bin/env python3
"""
Phase 1: Deep Architecture Analysis of Qwen3-TTS for ExecuTorch Export
======================================================================
Loads the model, maps all modules with parameter counts, traces a real
voice-clone inference to capture shapes, and identifies export blockers.
"""
import sys
import os
import time
import json
import numpy as np
import torch
import torch.nn as nn
# ── paths ────────────────────────────────────────────────────────────
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
# Make qwen_tts importable: prepend the venv site-packages, then the repo
# root (so the repo root ends up first on sys.path, as before).
for _extra_path in (VENV_SITE, QWEN_TTS_SRC):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
# ── helpers ──────────────────────────────────────────────────────────
def count_params(module: nn.Module) -> int:
    """Return the total number of parameters in *module* (children included)."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
def fmt(n: int) -> str:
    """Format a count as 1.2B / 3.4M / 5.6K, or the bare number below 1000."""
    for threshold, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if n >= threshold:
            return f"{n / threshold:.1f}{suffix}"
    return str(n)
def param_table(module: nn.Module, prefix: str = "", depth: int = 0, max_depth: int = 3):
    """Recursively print each submodule with its aggregate parameter count."""
    label = prefix if prefix else module.__class__.__name__
    print(f"{' ' * depth}{label}: {fmt(count_params(module))} params")
    if depth >= max_depth:
        return
    for child_name, child in module.named_children():
        qualified = f"{prefix}.{child_name}" if prefix else child_name
        param_table(child, qualified, depth + 1, max_depth)
# ── 1. Load Model ───────────────────────────────────────────────────
print("=" * 70)
print("PHASE 1: Deep Architecture Analysis β€” Qwen3-TTS 1.7B-Base")
print("=" * 70)
print("\n[1/5] Loading model from", MODEL_PATH)
load_start = time.time()
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import (
    Qwen3TTSForConditionalGeneration,
    mel_spectrogram,
)

config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
# SDPA attention is the export-friendly implementation (no custom flash-attn ops);
# fp32 on CPU keeps the trace deterministic for shape analysis.
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    config=config,
    torch_dtype=torch.float32,
    attn_implementation="sdpa",
    device_map="cpu",
)
model.eval()
print(f" Loaded in {time.time() - load_start:.1f}s")
# ── 2. Parameter Map ────────────────────────────────────────────────
print("\n[2/5] Parameter Map (hierarchical)")
print("-" * 60)
param_table(model, "Qwen3TTSForConditionalGeneration", max_depth=4)
print("\n--- Top-level component sizes ---")
# Fixed-order list of the components we care about for export sizing.
for label, submodule in (
    ("speaker_encoder", model.speaker_encoder),
    ("talker", model.talker),
    ("talker.model", model.talker.model),
    ("talker.text_projection", model.talker.text_projection),
    ("talker.codec_head", model.talker.codec_head),
    ("talker.code_predictor", model.talker.code_predictor),
):
    print(f" {label:40s}: {fmt(count_params(submodule)):>8s} params")
# The speech tokenizer is optional; report its encoder/decoder when present.
if model.speech_tokenizer is not None and hasattr(model.speech_tokenizer, 'model'):
    st = model.speech_tokenizer.model  # Qwen3TTSTokenizerV2Model (nn.Module)
    print(f" {'speech_tokenizer.model':40s}: {fmt(count_params(st)):>8s} params")
    for part in ("encoder", "decoder"):
        if hasattr(st, part):
            print(f" {f'speech_tokenizer.model.{part}':40s}: {fmt(count_params(getattr(st, part))):>8s} params")
# ── 3. Config Summary ───────────────────────────────────────────────
print("\n[3/5] Key Config Values")
print("-" * 60)
# These three config handles are reused by later sections of the script.
tc = config.talker_config
cpc = tc.code_predictor_config
sec = config.speaker_encoder_config
# Ordered (section, key/value pairs) so the printout is stable.
sections = [
    ("Speaker Encoder", [
        ("mel_dim", sec.mel_dim),
        ("enc_dim (output)", sec.enc_dim),
        ("enc_channels", sec.enc_channels),
        ("sample_rate", sec.sample_rate),
    ]),
    ("Talker (Main LM)", [
        ("hidden_size", tc.hidden_size),
        ("num_hidden_layers", tc.num_hidden_layers),
        ("num_attention_heads", tc.num_attention_heads),
        ("num_key_value_heads", tc.num_key_value_heads),
        ("head_dim", tc.head_dim),
        ("intermediate_size", tc.intermediate_size),
        ("text_vocab_size", tc.text_vocab_size),
        ("codec_vocab_size", tc.vocab_size),
        ("num_code_groups", tc.num_code_groups),
        ("max_position_embeddings", tc.max_position_embeddings),
        ("rope_scaling", tc.rope_scaling),
    ]),
    ("Code Predictor", [
        ("hidden_size", cpc.hidden_size),
        ("num_hidden_layers", cpc.num_hidden_layers),
        ("num_attention_heads", cpc.num_attention_heads),
        ("num_key_value_heads", cpc.num_key_value_heads),
        ("num_code_groups", cpc.num_code_groups),
        ("vocab_size", cpc.vocab_size),
    ]),
]
for section, pairs in sections:
    print(f"\n {section}:")
    for k, v in pairs:
        print(f" {k:35s}: {v}")
# ── 4. Trace Real Inference ─────────────────────────────────────────
print("\n[4/5] Tracing Real Voice-Clone Inference")
print("-" * 60)
# Synthetic reference audio: 3 s of low-amplitude white noise at 24 kHz.
ref_sr = 24000
ref_duration = 3.0
ref_audio = np.random.randn(int(ref_sr * ref_duration)).astype(np.float32) * 0.1
# --- 4a. Speaker Encoder ---
print("\n === Speaker Encoder ===")
# mel params mirror the model's training-time front end (128 mels, hop 256).
mel_input = mel_spectrogram(
    torch.from_numpy(ref_audio).unsqueeze(0),
    n_fft=1024,
    num_mels=128,
    sampling_rate=24000,
    hop_size=256,
    win_size=1024,
    fmin=0,
    fmax=12000,
).transpose(1, 2)
print(f" Mel input shape: {list(mel_input.shape)}")  # [1, T, 128]
with torch.no_grad():
    spk_embed = model.speaker_encoder(mel_input)
print(f" Speaker embedding shape: {list(spk_embed.shape)}")  # [1, enc_dim]
x_vector = spk_embed[0]
print(f" X-vector (per sample): {list(x_vector.shape)}")  # [enc_dim]
# --- 4b. Speech Tokenizer Encode (ref audio -> codes) ---
print("\n === Speech Tokenizer Encode ===")
if model.speech_tokenizer is None:
    print(" Speech tokenizer not loaded (will skip encode)")
    ref_codes = None
else:
    st_model = model.speech_tokenizer.model
    ref_wav_tensor = torch.from_numpy(ref_audio).float().unsqueeze(0)  # [1, samples]
    # All-ones mask: the synthetic clip has no padding.
    padding_mask = torch.ones_like(ref_wav_tensor, dtype=torch.long)
    with torch.no_grad():
        enc_out = st_model.encode(ref_wav_tensor, padding_mask=padding_mask, return_dict=True)
    ref_codes = enc_out.audio_codes
    print(f" Ref audio samples: {ref_wav_tensor.shape[1]}")
    print(f" Number of code tensors: {len(ref_codes)}")
    for idx, code_tensor in enumerate(ref_codes):
        print(f" ref_codes[{idx}] shape: {list(code_tensor.shape)}")  # [T, num_quantizers]
# --- 4c. Talker Prefill Input Construction ---
print("\n === Talker Input Construction ===")
# Simulate tokenized text: "<|im_start|>assistant\nHello world<|im_end|>\n<|im_start|>assistant\n"
# Using config token IDs
from transformers import AutoTokenizer

# Best-effort: prefer the real tokenizer, fall back to hand-built IDs so
# the rest of the analysis can still run.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    chat_text = "<|im_start|>assistant\nHello world.<|im_end|>\n<|im_start|>assistant\n"
    input_ids = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).input_ids
    print(f" Text input_ids shape: {list(input_ids.shape)}")
    print(f" Text input_ids: {input_ids[0].tolist()[:20]}...")
except Exception as e:
    print(f" Tokenizer load failed: {e}")
    # Fallback: synthetic token IDs approximating the same chat template.
    input_ids = torch.tensor([[config.im_start_token_id, 77091, 198, 9707, 1879, 13,
                               config.im_end_token_id, 198,
                               config.im_start_token_id, 77091, 198]])
    print(f" Fallback input_ids shape: {list(input_ids.shape)}")
# --- 4d. Talker Key Shapes ---
print("\n === Talker Architecture Key Shapes ===")
talker = model.talker
# Two separate embedding tables: text tokens vs codec tokens.
text_emb = talker.get_text_embeddings()
print(f" text_embedding: {text_emb.weight.shape}")  # [text_vocab, hidden]
codec_emb = talker.get_input_embeddings()
print(f" codec_embedding: {codec_emb.weight.shape}")  # [codec_vocab, hidden]
# text_projection (ResizeMLP): probe it with a single token to read shapes.
print(f" text_projection type: {type(talker.text_projection).__name__}")
with torch.no_grad():
    probe_hidden = text_emb(torch.tensor([[0]]))
    probe_out = talker.text_projection(probe_hidden)
print(f" text_projection in/out: {list(probe_hidden.shape)} -> {list(probe_out.shape)}")
print(f" codec_head: Linear({talker.codec_head.in_features} -> {talker.codec_head.out_features})")
# Static-KV-cache footprint estimate for the main LM (fp32, batch 1, seq 2048).
num_layers = tc.num_hidden_layers
num_kv_heads = tc.num_key_value_heads
head_dim = tc.head_dim
print(f"\n Static KV cache per layer: 2 x [B, {num_kv_heads}, max_seq_len, {head_dim}]")
print(f" Total KV layers: {num_layers}")
kv_bytes = 2 * num_layers * num_kv_heads * 2048 * head_dim * 4
print(f" Total KV cache (fp32, B=1, seq=2048): {kv_bytes / 1e6:.1f} MB")
# --- 4e. Code Predictor Key Shapes ---
print("\n === Code Predictor Key Shapes ===")
cp = talker.code_predictor
print(f" small_to_mtp_projection: {type(cp.small_to_mtp_projection).__name__}")
if hasattr(cp.small_to_mtp_projection, 'weight'):
    print(f" weight shape: {list(cp.small_to_mtp_projection.weight.shape)}")
# One lm_head per generation step; one codec embedding table per group.
print(f" lm_heads: {len(cp.lm_head)} heads")
for head_idx, head in enumerate(cp.lm_head):
    print(f" lm_head[{head_idx}]: Linear({head.in_features} -> {head.out_features})")
print(f" codec_embeddings: {len(cp.model.codec_embedding)} embeddings")
for emb_idx, emb in enumerate(cp.model.codec_embedding):
    print(f" codec_embedding[{emb_idx}]: {emb.weight.shape}")
cp_layers = cpc.num_hidden_layers
cp_kv_heads = cpc.num_key_value_heads
cp_head_dim = cpc.head_dim
print(f"\n Static KV cache per layer: 2 x [B, {cp_kv_heads}, max_seq_len, {cp_head_dim}]")
print(f" Total KV layers: {cp_layers}")
# --- 4f. Speech Tokenizer Decoder Key Shapes ---
print("\n === Speech Tokenizer Decoder Key Shapes ===")
if model.speech_tokenizer is not None:
    st_dec = model.speech_tokenizer.model.decoder
    print(f" Decoder type: {type(st_dec).__name__}")
    print(f" Total params: {fmt(count_params(st_dec))}")
    # Smoke-test decode with synthetic codes of shape [batch, num_quantizers, seq_len].
    test_codes = torch.randint(0, 2048, (1, 16, 10))
    with torch.no_grad():
        test_wav = st_dec(test_codes)
    print(f" Test input codes: {list(test_codes.shape)}")
    print(f" Test output wav: {list(test_wav.shape)}")
    print(f" Upsample factor: {test_wav.shape[-1] // test_codes.shape[-1]}x")
# ── 5. Export Blocker Analysis ───────────────────────────────────────
print("\n[5/5] Export Blocker Analysis")
print("-" * 60)
blockers = []

def _log_blockers(component, issues):
    """Print each issue with a [!] marker and record it as (component, issue)."""
    for issue in issues:
        print(f" [!] {issue}")
    blockers.extend((component, issue) for issue in issues)

# Speaker encoder: scan Conv1d layers for export-hostile padding settings,
# then add the known dynamic-shape hazards.
print("\n === Speaker Encoder Export Blockers ===")
se_issues = []
for name, mod in model.speaker_encoder.named_modules():
    if isinstance(mod, nn.Conv1d):
        if getattr(mod, 'padding', None) == 'same':
            se_issues.append(f"Conv1d '{name}' uses padding='same' (dynamic pad calc)")
        if getattr(mod, 'padding_mode', None) == 'reflect':
            se_issues.append(f"Conv1d '{name}' uses padding_mode='reflect'")
se_issues.append("AttentiveStatisticsPooling: dynamic _length_to_mask(), .repeat(), masked_fill_")
se_issues.append("Res2NetBlock: torch.chunk + for loop (but fixed scale=8, should be OK)")
_log_blockers("speaker_encoder", se_issues)

# Talker: known graph-break sources in the main LM.
print("\n === Talker Export Blockers ===")
_log_blockers("talker", [
    "MROPE: 3D rotary embedding with sections [24,20,20] β€” need custom handling",
    "DynamicCache: must replace with static KV cache tensors",
    "create_causal_mask/create_sliding_window_causal_mask from transformers",
    "Two embedding tables (text + codec) with interleaving logic",
    "code_predictor.generate() called inside forward() β€” autoregressive sub-loop",
    "trailing_text_hidden conditional addition in decode step",
    "@can_return_tuple decorator",
    "@use_kernel_forward_from_hub on RMSNorm",
])

# Code predictor: the generation loop itself is the main blocker.
print("\n === Code Predictor Export Blockers ===")
_log_blockers("code_predictor", [
    "Uses GenerationMixin.generate() β€” full autoregressive loop",
    "generation_steps counter used to index into lm_head ModuleList",
    "DynamicCache",
    "get_input_embeddings() returns ModuleList (indexed by generation step)",
])

# Speech tokenizer: only reportable when the tokenizer was actually loaded.
print("\n === Speech Tokenizer Export Blockers ===")
st_issues = []
if model.speech_tokenizer is not None:
    st_issues = [
        "chunked_decode: while loop with dynamic chunk boundaries",
        "ConvTranspose1d with dynamic slicing (right_pad removal)",
        "CausalConv1d: dynamic padding calculation",
        "SnakeBeta: custom activation (should be OK)",
        "SplitResidualVectorQuantizer: F.embedding based (OK)",
        "Transformer decoder with @dynamic_rope_update and torch.autocast",
        "Sliding window attention (window=72)",
    ]
_log_blockers("speech_tokenizer", st_issues)
# ── Summary ──────────────────────────────────────────────────────────
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
# Consolidated report: parameter totals, the four export targets, the
# voice-clone data flow, key dimensions, and the phased export plan.
# NOTE(review): leading whitespace inside this triple-quoted f-string is
# part of the printed output β€” original indentation may have been lost in
# transit; confirm against the upstream file before reformatting.
print(f"""
Model: Qwen3TTSForConditionalGeneration (1.7B-Base)
Total params: {fmt(count_params(model))}
Export Targets (4 modules):
1. Speaker Encoder ({fmt(count_params(model.speaker_encoder))} params) β€” ECAPA-TDNN
2. Talker (Main LM) ({fmt(count_params(model.talker.model))} + heads) β€” Qwen3 28L
3. Code Predictor ({fmt(count_params(model.talker.code_predictor))} params) β€” 5L transformer
4. Speech Tokenizer Dec ({fmt(count_params(model.speech_tokenizer.model.decoder)) if model.speech_tokenizer else 'N/A'} params) β€” Transformer + ConvTranspose
Voice Clone Pipeline:
ref_audio (24kHz)
-> mel_spectrogram -> [B, T, 128]
-> speaker_encoder -> x_vector [B, {sec.enc_dim}]
ref_audio -> speech_tokenizer.encode -> ref_codes [T, 16]
text -> tokenizer -> input_ids
[x_vector, ref_codes, input_ids]
-> talker.generate() -> codec_tokens [T', 16]
(internally calls code_predictor.generate() per step)
codec_tokens -> speech_tokenizer.decode -> PCM waveform
Key Dimensions:
Talker: hidden=2048, layers=28, heads=16, kv_heads=8, head_dim=128
Code Predictor: hidden=1024, layers=5, heads=16, kv_heads=8
Codec: vocab=3072 (talker), 2048 (code_predictor), 16 code groups
Speaker: enc_dim={sec.enc_dim}
Export Strategy:
Phase 2: Speaker encoder β€” fixed mel length, handle Conv1d padding
Phase 3: Talker β€” static KV cache, unrolled MROPE, separate prefill/decode
Phase 4: Code predictor β€” static KV, unroll 15-step generation
Phase 5: Vocoder (decoder only) β€” fixed code length, handle ConvTranspose1d
Phase 6: INT8 via torchao int8_weight_only (instant, no calibration)
Total export blockers found: {len(blockers)}
""")
print("Phase 1 analysis complete!")