acul3 commited on
Commit
ed7432b
Β·
verified Β·
1 Parent(s): 9611639

Upload scripts/validate_tokens_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/validate_tokens_v2.py +232 -0
scripts/validate_tokens_v2.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Token-by-token validation v2: Build TalkerForExport inline (no import of export script).
Compares generated tokens: Original HF talker vs Fixed wrapper (same as .pte source).
Runs on CPU, greedy decoding, 10 steps.
"""

import copy
import os
import sys
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

# Make the project package (qwen_tts) importable from its checkout location.
sys.path.insert(0, os.path.expanduser("~/Documents/Qwen3-TTS"))

# Talker architecture constants — must match the 1.7B-Base checkpoint.
MAX_SEQ_LEN = 2048                      # static KV-cache length for the export wrapper
NUM_LAYERS = 28                         # decoder layers
NUM_HEADS = 16                          # query attention heads
NUM_KV_HEADS = 8                        # key/value heads (GQA)
HEAD_DIM = 128                          # per-head dimension
HIDDEN_SIZE = 2048                      # hidden width
KV_GROUPS = NUM_HEADS // NUM_KV_HEADS   # query heads sharing each KV head
NUM_STEPS = 10                          # greedy decode steps to compare
20
+
21
+
22
class RMSNorm(nn.Module):
    """Root-mean-square layer norm, computed in float32 and cast back."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        xf = x.float()
        inv_rms = torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * (xf * inv_rms)).to(x.dtype)
29
+
30
+
31
def rotate_half(x):
    """RoPE helper: map (x1, x2) halves of the last dim to (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
34
+
35
+
36
class FixedAttn(nn.Module):
    """Self-attention rebuilt for export: explicit RoPE application,
    functional (clone-then-write) KV-cache update, GQA head expansion,
    and SDPA with an externally supplied additive mask."""

    def __init__(self, orig):
        super().__init__()
        # Deep-copy the original projection modules so this wrapper owns
        # independent weights (safe even after the source model is freed).
        self.q_proj = copy.deepcopy(orig.q_proj)
        self.k_proj = copy.deepcopy(orig.k_proj)
        self.v_proj = copy.deepcopy(orig.v_proj)
        self.o_proj = copy.deepcopy(orig.o_proj)
        # Per-head Q/K norms; copy the learned scale weights from the original.
        self.q_norm = RMSNorm(HEAD_DIM); self.q_norm.weight = nn.Parameter(orig.q_norm.weight.clone())
        self.k_norm = RMSNorm(HEAD_DIM); self.k_norm.weight = nn.Parameter(orig.k_norm.weight.clone())
        self.scale = HEAD_DIM ** -0.5

    def forward(self, h, cos, sin, cp, kc, vc, am):
        """h: hidden states [B, S, hidden]; cos/sin: RoPE tables broadcastable
        over heads; cp: cache positions to write this step; kc/vc: full-length
        KV caches; am: additive attention mask over the whole cache
        (0 = attend, -inf = masked). Returns (attn_out, new_kc, new_vc)."""
        B, S, _ = h.shape
        q = self.q_norm(self.q_proj(h).view(B, S, NUM_HEADS, HEAD_DIM)).transpose(1, 2)
        k = self.k_norm(self.k_proj(h).view(B, S, NUM_KV_HEADS, HEAD_DIM)).transpose(1, 2)
        v = self.v_proj(h).view(B, S, NUM_KV_HEADS, HEAD_DIM).transpose(1, 2)
        # RoPE
        q = q * cos + rotate_half(q) * sin
        k = k * cos + rotate_half(k) * sin
        # Update KV cache — cloned first so the caller's cache tensors are not
        # mutated in place (presumably required for export tracing; confirm).
        kc = kc.clone(); vc = vc.clone()
        kc[:, :, cp, :] = k; vc[:, :, cp, :] = v
        # GQA expand: repeat each KV head for its group of query heads.
        cache_len = kc.shape[2]
        ke = kc.unsqueeze(2).repeat(1, 1, KV_GROUPS, 1, 1).reshape(B, NUM_HEADS, cache_len, HEAD_DIM)
        ve = vc.unsqueeze(2).repeat(1, 1, KV_GROUPS, 1, 1).reshape(B, NUM_HEADS, cache_len, HEAD_DIM)
        # Attention over the full static cache; `am` masks unwritten slots.
        o = F.scaled_dot_product_attention(q, ke, ve, attn_mask=am, scale=self.scale)
        return self.o_proj(o.transpose(1, 2).reshape(B, S, -1)), kc, vc
65
+
66
+
67
class FixedLayer(nn.Module):
    """One decoder layer: pre-norm attention and pre-norm SwiGLU MLP,
    each wrapped in a residual connection."""

    def __init__(self, orig):
        super().__init__()
        self.attn = FixedAttn(orig.self_attn)
        self.n1 = RMSNorm(HIDDEN_SIZE)
        self.n1.weight = nn.Parameter(orig.input_layernorm.weight.clone())
        self.n2 = RMSNorm(HIDDEN_SIZE)
        self.n2.weight = nn.Parameter(orig.post_attention_layernorm.weight.clone())
        self.gp = copy.deepcopy(orig.mlp.gate_proj)
        self.up = copy.deepcopy(orig.mlp.up_proj)
        self.dp = copy.deepcopy(orig.mlp.down_proj)

    def forward(self, h, cos, sin, cp, kc, vc, am):
        # Attention sub-block + residual.
        attn_out, kc, vc = self.attn(self.n1(h), cos, sin, cp, kc, vc, am)
        h = h + attn_out
        # SwiGLU MLP sub-block + residual.
        normed = self.n2(h)
        h = h + self.dp(F.silu(self.gp(normed)) * self.up(normed))
        return h, kc, vc
81
+
82
+
83
class FixedTalker(nn.Module):
    """Export wrapper around the HF talker: rebuilt layers, in-model RoPE
    table computation, and a flat KV-cache argument protocol
    (k0, v0, k1, v1, ...) instead of a cache object."""

    def __init__(self, orig_talker):
        super().__init__()
        self.layers = nn.ModuleList([FixedLayer(l) for l in orig_talker.model.layers])
        self.norm = RMSNorm(HIDDEN_SIZE); self.norm.weight = nn.Parameter(orig_talker.model.norm.weight.clone())
        self.codec_head = copy.deepcopy(orig_talker.codec_head)
        # RoPE inverse frequencies copied from the original rotary embedding.
        self.register_buffer("inv_freq", orig_talker.model.rotary_emb.inv_freq.clone())
        # attention_scaling may be absent on some rotary implementations;
        # fall back to 1.0 (no scaling) in that case.
        self.rs = getattr(orig_talker.model.rotary_emb, 'attention_scaling', 1.0)

    def forward(self, ie, pid, cp, am, *kv):
        """ie: input embeds; pid: 3-plane position ids (only pid[0] is used —
        assumes all planes are equal for this input; TODO confirm against
        the original MRoPE usage); cp: cache positions; am: additive mask;
        kv: flat per-layer caches. Returns (logits, updated k/v per layer)."""
        # Build cos/sin tables once, shared by every layer.
        pos = pid[0].float()
        freqs = pos.unsqueeze(-1) * self.inv_freq.float().unsqueeze(0).unsqueeze(0)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = (emb.cos() * self.rs).to(ie.dtype).unsqueeze(1)
        sin = (emb.sin() * self.rs).to(ie.dtype).unsqueeze(1)
        h = ie
        ukv = []
        for i, l in enumerate(self.layers):
            # Layer i consumes cache slots (2i, 2i+1) from the flat kv list.
            h, nk, nv = l(h, cos, sin, cp, kv[i*2], kv[i*2+1], am)
            ukv.append(nk); ukv.append(nv)
        # Final norm, codec-vocabulary logits, then the updated caches.
        return (self.codec_head(self.norm(h)), *ukv)
104
+
105
+
106
def main():
    """Run NUM_STEPS greedy decode steps on the original HF talker and on
    the FixedTalker export wrapper, then print a step-by-step token
    comparison and a match-rate verdict."""
    print("="*60)
    print(f"Token-by-Token Validation (v2, {NUM_STEPS} steps, greedy)")
    print("="*60)

    # Project-local imports — require the sys.path entry set at module level.
    from qwen_tts import Qwen3TTSModel
    from transformers import AutoTokenizer
    from transformers.cache_utils import DynamicCache

    print("\n[1] Loading model...")
    model = Qwen3TTSModel.from_pretrained(
        os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base"),
        device_map="cpu", dtype=torch.float32, attn_implementation="sdpa")
    talker = model.model.talker
    talker.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base"))

    # Build input
    text = "Hi"
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    print(f" Text: '{text}' β†’ {text_ids}")

    # Embeddings — keep raw weight tensors so codec_w stays usable after
    # the talker module itself is deleted below.
    emb_w = talker.model.text_embedding.weight.data
    codec_w = talker.model.codec_embedding.weight.data
    proj = talker.text_projection

    raw = F.embedding(torch.tensor(text_ids), emb_w)
    with torch.no_grad():
        text_embeds = proj(raw)
    inputs_embeds = text_embeds.unsqueeze(0)  # [1, T, 2048]
    seq_len = inputs_embeds.shape[1]

    # ── Original talker ──
    print(f"\n[2] Original talker ({NUM_STEPS} steps)...")
    orig_tokens = []
    with torch.no_grad():
        past_kv = DynamicCache()
        # 3-plane position ids (MRoPE-style API); all planes identical here.
        pos_ids = torch.arange(seq_len).unsqueeze(0).unsqueeze(0).expand(3, 1, -1)
        cache_pos = torch.arange(seq_len)
        # Prefill over the whole prompt.
        out = talker.model(input_ids=None, inputs_embeds=inputs_embeds,
                           position_ids=pos_ids, cache_position=cache_pos,
                           attention_mask=torch.ones(1, seq_len),
                           past_key_values=past_kv, use_cache=True)
        logits = talker.codec_head(out.last_hidden_state)
        next_token = logits[0, -1].argmax().item()  # greedy pick
        orig_tokens.append(next_token)
        past_kv = out.past_key_values

        # One token per step, feeding the previous codec token back in.
        for step in range(NUM_STEPS - 1):
            te = F.embedding(torch.tensor([[next_token]]), codec_w)
            pi = torch.tensor([[[seq_len + step]]]).expand(3, 1, 1)
            cp = torch.tensor([seq_len + step])
            out = talker.model(input_ids=None, inputs_embeds=te,
                               position_ids=pi, cache_position=cp,
                               attention_mask=torch.ones(1, seq_len + step + 1),
                               past_key_values=past_kv, use_cache=True)
            logits = talker.codec_head(out.last_hidden_state)
            next_token = logits[0, -1].argmax().item()
            orig_tokens.append(next_token)
            past_kv = out.past_key_values
    print(f" Tokens: {orig_tokens}")

    # ── Fixed talker wrapper ──
    print(f"\n[3] Building FixedTalker wrapper...")
    t0 = time.time()
    fixed = FixedTalker(talker)
    fixed.eval()
    print(f" Built in {time.time()-t0:.1f}s")

    # Free original to save RAM (the wrapper deep-copied what it needs).
    del talker, model
    import gc; gc.collect()

    print(f"\n[4] Fixed talker ({NUM_STEPS} steps)...")
    # Zero-initialized static caches, one k and one v tensor per layer.
    kv = [torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM) for _ in range(NUM_LAYERS * 2)]
    pid = torch.arange(seq_len).unsqueeze(0).unsqueeze(0).expand(3, 1, -1)
    cp = torch.arange(seq_len)
    # Causal additive mask over the full static cache: 0 = visible, -inf = masked.
    mask = torch.full((1, 1, seq_len, MAX_SEQ_LEN), float('-inf'))
    for i in range(seq_len):
        mask[0, 0, i, :i+1] = 0.0

    fixed_tokens = []
    with torch.no_grad():
        t0 = time.time()
        # Prefill: returns logits followed by the updated flat kv list.
        result = fixed(inputs_embeds, pid, cp, mask, *kv)
        logits = result[0]; kv = list(result[1:])
        next_token = logits[0, -1].argmax().item()
        fixed_tokens.append(next_token)
        print(f" Prefill: {time.time()-t0:.1f}s, token={next_token}", flush=True)

        for step in range(NUM_STEPS - 1):
            cur = seq_len + step
            te = F.embedding(torch.tensor([[next_token]]), codec_w)
            pi = torch.tensor([[[cur]]]).expand(3, 1, 1)
            cp = torch.tensor([cur])
            # Single-row decode mask: attend to all positions written so far.
            dm = torch.full((1, 1, 1, MAX_SEQ_LEN), float('-inf'))
            dm[0, 0, 0, :cur+1] = 0.0
            t1 = time.time()
            result = fixed(te, pi, cp, dm, *kv)
            logits = result[0]; kv = list(result[1:])
            next_token = logits[0, -1].argmax().item()
            fixed_tokens.append(next_token)
            print(f" Step {step+1}: {time.time()-t1:.1f}s, token={next_token}", flush=True)

    # ── Compare ──
    print("\n" + "="*60)
    print("COMPARISON")
    print("="*60)
    match = 0
    for i in range(NUM_STEPS):
        m = orig_tokens[i] == fixed_tokens[i]
        if m: match += 1
        print(f" Step {i+1:2d}: orig={orig_tokens[i]:5d} fixed={fixed_tokens[i]:5d} {'βœ…' if m else '❌'}")
    print(f"\n Match: {match}/{NUM_STEPS} ({100*match/NUM_STEPS:.0f}%)")
    if match == NUM_STEPS:
        print(" πŸŽ‰ PERFECT β€” Fixed wrapper produces identical tokens!")
    elif match >= NUM_STEPS * 0.8:
        print(" βœ… NEAR-PERFECT β€” minor numerical drift")
    else:
        print(" ❌ DIVERGENCE β€” needs investigation")
229
+
230
+
231
# Script entry point.
if __name__ == "__main__":
    main()