acul3 commited on
Commit
77cf118
·
verified ·
1 Parent(s): 846eac7

Upload scripts/export_decoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/export_decoder.py +551 -0
scripts/export_decoder.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Phase 3b: Text Decoder Export for ExecuTorch
4
+ Extracts language_model + lm_head into a standalone nn.Module
5
+ with static KV cache tensors for torch.export compatibility.
6
+
7
+ Architecture: Qwen3 decoder (28 layers, GQA 16/8 heads, head_dim=128)
8
+ Fixed max_seq_len: 512
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ # Model constants from config
19
+ HIDDEN_SIZE = 1024
20
+ NUM_LAYERS = 28
21
+ NUM_HEADS = 16
22
+ NUM_KV_HEADS = 8
23
+ HEAD_DIM = 128
24
+ INTERMEDIATE_SIZE = 3072
25
+ VOCAB_SIZE = 151936
26
+ MAX_SEQ_LEN = 4096
27
+ RMS_EPS = 1e-6
28
+ ROPE_THETA = 1000000.0
29
+ NUM_KV_GROUPS = NUM_HEADS // NUM_KV_HEADS # 2
30
+
31
+ MODEL_DIR = "./models/LightOnOCR-2-1B"
32
+
33
+
34
+ def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = RMS_EPS) -> torch.Tensor:
35
+ """Inline RMSNorm — avoids @use_kernel_forward_from_hub decorator."""
36
+ input_dtype = x.dtype
37
+ x = x.to(torch.float32)
38
+ variance = x.pow(2).mean(-1, keepdim=True)
39
+ x = x * torch.rsqrt(variance + eps)
40
+ return weight * x.to(input_dtype)
41
+
42
+
43
+ def precompute_rope_freqs(max_seq_len: int, head_dim: int, theta: float = ROPE_THETA):
44
+ """Precompute RoPE cos/sin for all positions up to max_seq_len."""
45
+ freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
46
+ t = torch.arange(max_seq_len, dtype=torch.float32)
47
+ freqs = torch.outer(t, freqs)
48
+ cos = freqs.cos()
49
+ sin = freqs.sin()
50
+ # Duplicate for full head_dim: [seq_len, head_dim/2] -> [seq_len, head_dim]
51
+ cos = torch.cat([cos, cos], dim=-1)
52
+ sin = torch.cat([sin, sin], dim=-1)
53
+ return cos, sin # [max_seq_len, head_dim]
54
+
55
+
56
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
57
+ """
58
+ Apply rotary position embeddings to query and key states.
59
+ q, k: [batch, num_heads, seq_len, head_dim]
60
+ cos, sin: [max_seq_len, head_dim]
61
+ position_ids: [batch, seq_len]
62
+ """
63
+ # Gather cos/sin for the given positions
64
+ cos = cos[position_ids].unsqueeze(1) # [batch, 1, seq_len, head_dim]
65
+ sin = sin[position_ids].unsqueeze(1) # [batch, 1, seq_len, head_dim]
66
+
67
+ # Rotate
68
+ q_embed = (q * cos) + (rotate_half(q) * sin)
69
+ k_embed = (k * cos) + (rotate_half(k) * sin)
70
+ return q_embed, k_embed
71
+
72
+
73
+ def rotate_half(x):
74
+ """Rotates half the hidden dims of the input."""
75
+ x1 = x[..., : x.shape[-1] // 2]
76
+ x2 = x[..., x.shape[-1] // 2 :]
77
+ return torch.cat((-x2, x1), dim=-1)
78
+
79
+
80
+ class Qwen3AttentionFixed(nn.Module):
81
+ """
82
+ Fixed Qwen3 attention with static KV cache, inline QK-norm, and
83
+ no dynamic dispatch. Designed for torch.export compatibility.
84
+ """
85
+
86
+ def __init__(self, layer_idx: int):
87
+ super().__init__()
88
+ self.layer_idx = layer_idx
89
+ self.scaling = HEAD_DIM ** -0.5
90
+
91
+ # Projections
92
+ self.q_proj = nn.Linear(HIDDEN_SIZE, NUM_HEADS * HEAD_DIM, bias=False)
93
+ self.k_proj = nn.Linear(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, bias=False)
94
+ self.v_proj = nn.Linear(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, bias=False)
95
+ self.o_proj = nn.Linear(NUM_HEADS * HEAD_DIM, HIDDEN_SIZE, bias=False)
96
+
97
+ # QK-norm weights (RMSNorm per head)
98
+ self.q_norm_weight = nn.Parameter(torch.ones(HEAD_DIM))
99
+ self.k_norm_weight = nn.Parameter(torch.ones(HEAD_DIM))
100
+
101
+ def forward(
102
+ self,
103
+ hidden_states: torch.Tensor, # [batch, seq_len, hidden_size]
104
+ cos: torch.Tensor, # [max_seq_len, head_dim]
105
+ sin: torch.Tensor, # [max_seq_len, head_dim]
106
+ position_ids: torch.Tensor, # [batch, seq_len]
107
+ attention_mask: torch.Tensor, # [batch, 1, seq_len, cache_len+seq_len]
108
+ k_cache: torch.Tensor, # [batch, num_kv_heads, max_seq_len, head_dim]
109
+ v_cache: torch.Tensor, # [batch, num_kv_heads, max_seq_len, head_dim]
110
+ cache_position: torch.Tensor, # [seq_len] — positions to write into cache
111
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
112
+ """Returns (output, updated_k_cache, updated_v_cache)"""
113
+ batch, seq_len, _ = hidden_states.shape
114
+
115
+ # Project Q, K, V
116
+ q = self.q_proj(hidden_states)
117
+ k = self.k_proj(hidden_states)
118
+ v = self.v_proj(hidden_states)
119
+
120
+ # Reshape: [batch, seq_len, num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim]
121
+ q = q.view(batch, seq_len, NUM_HEADS, HEAD_DIM)
122
+ k = k.view(batch, seq_len, NUM_KV_HEADS, HEAD_DIM)
123
+ v = v.view(batch, seq_len, NUM_KV_HEADS, HEAD_DIM)
124
+
125
+ # Apply QK-norm (RMSNorm per head, inline)
126
+ q = rms_norm(q, self.q_norm_weight)
127
+ k = rms_norm(k, self.k_norm_weight)
128
+
129
+ q = q.transpose(1, 2) # [batch, num_heads, seq_len, head_dim]
130
+ k = k.transpose(1, 2) # [batch, num_kv_heads, seq_len, head_dim]
131
+ v = v.transpose(1, 2) # [batch, num_kv_heads, seq_len, head_dim]
132
+
133
+ # Apply RoPE
134
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
135
+
136
+ # Update KV cache using scatter (index_put)
137
+ # cache_position: [seq_len] — the positions to update
138
+ # k_cache shape: [batch, num_kv_heads, max_seq_len, head_dim]
139
+ k_cache = k_cache.clone()
140
+ v_cache = v_cache.clone()
141
+ k_cache[:, :, cache_position, :] = k
142
+ v_cache[:, :, cache_position, :] = v
143
+
144
+ # Expand KV heads for GQA: repeat each KV head for its group of Q heads
145
+ cache_len = k_cache.shape[2] # dynamic, works for any MAX_SEQ_LEN
146
+ k_expanded = k_cache.unsqueeze(2).expand(-1, -1, NUM_KV_GROUPS, -1, -1)
147
+ k_expanded = k_expanded.reshape(batch, NUM_HEADS, cache_len, HEAD_DIM)
148
+ v_expanded = v_cache.unsqueeze(2).expand(-1, -1, NUM_KV_GROUPS, -1, -1)
149
+ v_expanded = v_expanded.reshape(batch, NUM_HEADS, cache_len, HEAD_DIM)
150
+
151
+ # Attention: Q @ K^T / sqrt(head_dim)
152
+ attn_weights = torch.matmul(q, k_expanded.transpose(2, 3)) * self.scaling
153
+
154
+ # Apply attention mask
155
+ attn_weights = attn_weights + attention_mask
156
+
157
+ # Softmax
158
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
159
+
160
+ # Attention output
161
+ attn_output = torch.matmul(attn_weights, v_expanded)
162
+
163
+ # Reshape back: [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, hidden_size]
164
+ attn_output = attn_output.transpose(1, 2).contiguous()
165
+ attn_output = attn_output.reshape(batch, seq_len, -1)
166
+
167
+ # Output projection
168
+ attn_output = self.o_proj(attn_output)
169
+
170
+ return attn_output, k_cache, v_cache
171
+
172
+
173
+ class Qwen3MLPFixed(nn.Module):
174
+ """Fixed Qwen3 MLP (SiLU gate + up projection)."""
175
+
176
+ def __init__(self):
177
+ super().__init__()
178
+ self.gate_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
179
+ self.up_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
180
+ self.down_proj = nn.Linear(INTERMEDIATE_SIZE, HIDDEN_SIZE, bias=False)
181
+
182
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
183
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
184
+
185
+
186
+ class Qwen3DecoderLayerFixed(nn.Module):
187
+ """Fixed Qwen3 decoder layer with static KV cache."""
188
+
189
+ def __init__(self, layer_idx: int):
190
+ super().__init__()
191
+ self.self_attn = Qwen3AttentionFixed(layer_idx)
192
+ self.mlp = Qwen3MLPFixed()
193
+ self.input_layernorm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))
194
+ self.post_attention_layernorm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))
195
+
196
+ def forward(
197
+ self,
198
+ hidden_states: torch.Tensor,
199
+ cos: torch.Tensor,
200
+ sin: torch.Tensor,
201
+ position_ids: torch.Tensor,
202
+ attention_mask: torch.Tensor,
203
+ k_cache: torch.Tensor,
204
+ v_cache: torch.Tensor,
205
+ cache_position: torch.Tensor,
206
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
207
+ # Pre-norm + self attention
208
+ residual = hidden_states
209
+ hidden_states = rms_norm(hidden_states, self.input_layernorm_weight)
210
+ hidden_states, k_cache, v_cache = self.self_attn(
211
+ hidden_states, cos, sin, position_ids, attention_mask,
212
+ k_cache, v_cache, cache_position
213
+ )
214
+ hidden_states = residual + hidden_states
215
+
216
+ # Pre-norm + MLP
217
+ residual = hidden_states
218
+ hidden_states = rms_norm(hidden_states, self.post_attention_layernorm_weight)
219
+ hidden_states = self.mlp(hidden_states)
220
+ hidden_states = residual + hidden_states
221
+
222
+ return hidden_states, k_cache, v_cache
223
+
224
+
225
+ class TextDecoderFixed(nn.Module):
226
+ """
227
+ Complete text decoder for ExecuTorch export.
228
+ Includes embedding, all decoder layers with static KV cache, and LM head.
229
+
230
+ For prefill: input_ids has seq_len > 1, cache_position starts at 0
231
+ For decode: input_ids has seq_len = 1, cache_position = current position
232
+ """
233
+
234
+ def __init__(self):
235
+ super().__init__()
236
+ self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
237
+ self.layers = nn.ModuleList([
238
+ Qwen3DecoderLayerFixed(i) for i in range(NUM_LAYERS)
239
+ ])
240
+ self.norm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))
241
+ self.lm_head = nn.Linear(HIDDEN_SIZE, VOCAB_SIZE, bias=False)
242
+
243
+ # Pre-compute RoPE frequencies
244
+ cos, sin = precompute_rope_freqs(MAX_SEQ_LEN, HEAD_DIM, ROPE_THETA)
245
+ self.register_buffer("rope_cos", cos)
246
+ self.register_buffer("rope_sin", sin)
247
+
248
+ def forward(
249
+ self,
250
+ input_ids: torch.Tensor, # [batch, seq_len]
251
+ attention_mask: torch.Tensor, # [batch, 1, seq_len, max_seq_len]
252
+ position_ids: torch.Tensor, # [batch, seq_len]
253
+ cache_position: torch.Tensor, # [seq_len]
254
+ *kv_caches: torch.Tensor, # 28 * (k_cache, v_cache) flattened
255
+ ) -> tuple:
256
+ """
257
+ Returns: (logits, *updated_kv_caches)
258
+ kv_caches: 56 tensors total (28 layers * 2 for k,v)
259
+ Each cache: [batch, num_kv_heads, max_seq_len, head_dim]
260
+ """
261
+ # Embed tokens
262
+ hidden_states = self.embed_tokens(input_ids)
263
+
264
+ # Process through all layers, updating KV caches
265
+ updated_caches = []
266
+ for i, layer in enumerate(self.layers):
267
+ k_cache = kv_caches[i * 2]
268
+ v_cache = kv_caches[i * 2 + 1]
269
+ hidden_states, new_k, new_v = layer(
270
+ hidden_states,
271
+ self.rope_cos, self.rope_sin,
272
+ position_ids, attention_mask,
273
+ k_cache, v_cache, cache_position
274
+ )
275
+ updated_caches.append(new_k)
276
+ updated_caches.append(new_v)
277
+
278
+ # Final norm
279
+ hidden_states = rms_norm(hidden_states, self.norm_weight)
280
+
281
+ # LM head — only compute logits for the last token
282
+ logits = self.lm_head(hidden_states[:, -1:, :]) # [batch, 1, vocab_size]
283
+
284
+ return (logits, *updated_caches)
285
+
286
+
287
+ def load_original_model():
288
+ """Load the original model with proper weight remapping."""
289
+ from transformers import AutoModelForImageTextToText
290
+ from safetensors.torch import load_file
291
+
292
+ print("Loading original model...")
293
+ model = AutoModelForImageTextToText.from_pretrained(
294
+ MODEL_DIR,
295
+ dtype=torch.bfloat16,
296
+ attn_implementation="sdpa",
297
+ device_map="cpu",
298
+ )
299
+
300
+ state_dict = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
301
+ remapped = {}
302
+ for k, v in state_dict.items():
303
+ new_k = k.replace("model.vision_encoder.", "model.vision_tower.")
304
+ new_k = new_k.replace("model.vision_projection.", "model.multi_modal_projector.")
305
+ remapped[new_k] = v
306
+ model.load_state_dict(remapped, strict=False)
307
+
308
+ return model
309
+
310
+
311
+ def build_decoder_module(original_model):
312
+ """Build the fixed decoder module from the original model's weights."""
313
+ print("\nBuilding fixed text decoder...")
314
+
315
+ orig_lm = original_model.model.language_model
316
+ orig_lm_head = original_model.lm_head
317
+
318
+ decoder = TextDecoderFixed()
319
+
320
+ # Copy embedding weights
321
+ decoder.embed_tokens.weight.data.copy_(orig_lm.embed_tokens.weight.data)
322
+
323
+ # Copy final norm weight
324
+ decoder.norm_weight.data.copy_(orig_lm.norm.weight.data)
325
+
326
+ # Copy LM head (tied with embeddings)
327
+ decoder.lm_head.weight.data.copy_(orig_lm.embed_tokens.weight.data)
328
+
329
+ # Copy layer weights
330
+ for i in range(NUM_LAYERS):
331
+ orig_layer = orig_lm.layers[i]
332
+ fixed_layer = decoder.layers[i]
333
+
334
+ # Attention projections
335
+ fixed_layer.self_attn.q_proj.weight.data.copy_(orig_layer.self_attn.q_proj.weight.data)
336
+ fixed_layer.self_attn.k_proj.weight.data.copy_(orig_layer.self_attn.k_proj.weight.data)
337
+ fixed_layer.self_attn.v_proj.weight.data.copy_(orig_layer.self_attn.v_proj.weight.data)
338
+ fixed_layer.self_attn.o_proj.weight.data.copy_(orig_layer.self_attn.o_proj.weight.data)
339
+
340
+ # QK-norm weights
341
+ fixed_layer.self_attn.q_norm_weight.data.copy_(orig_layer.self_attn.q_norm.weight.data)
342
+ fixed_layer.self_attn.k_norm_weight.data.copy_(orig_layer.self_attn.k_norm.weight.data)
343
+
344
+ # Layer norms
345
+ fixed_layer.input_layernorm_weight.data.copy_(orig_layer.input_layernorm.weight.data)
346
+ fixed_layer.post_attention_layernorm_weight.data.copy_(orig_layer.post_attention_layernorm.weight.data)
347
+
348
+ # MLP
349
+ fixed_layer.mlp.gate_proj.weight.data.copy_(orig_layer.mlp.gate_proj.weight.data)
350
+ fixed_layer.mlp.up_proj.weight.data.copy_(orig_layer.mlp.up_proj.weight.data)
351
+ fixed_layer.mlp.down_proj.weight.data.copy_(orig_layer.mlp.down_proj.weight.data)
352
+
353
+ decoder.eval()
354
+ total_params = sum(p.numel() for p in decoder.parameters())
355
+ print(f" Decoder parameters: {total_params/1e6:.2f}M")
356
+
357
+ return decoder
358
+
359
+
360
+ def create_empty_kv_caches(batch_size: int = 1, dtype=torch.float32, device="cpu"):
361
+ """Create empty KV cache tensors for all layers."""
362
+ caches = []
363
+ for _ in range(NUM_LAYERS):
364
+ k = torch.zeros(batch_size, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
365
+ v = torch.zeros(batch_size, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
366
+ caches.extend([k, v])
367
+ return tuple(caches)
368
+
369
+
370
+ def create_causal_mask(seq_len: int, cache_len: int = MAX_SEQ_LEN, dtype=torch.float32):
371
+ """Create causal attention mask."""
372
+ mask = torch.full((seq_len, cache_len), float("-inf"), dtype=dtype)
373
+ mask = torch.triu(mask, diagonal=cache_len - seq_len + 1)
374
+ return mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, cache_len]
375
+
376
+
377
+ def test_decoder_module(decoder, original_model):
378
+ """Test that the fixed decoder produces same output as original."""
379
+ print("\nTesting decoder output consistency...")
380
+
381
+ device = "cuda" if torch.cuda.is_available() else "cpu"
382
+ decoder = decoder.to(device).to(torch.bfloat16)
383
+ original_model = original_model.to(device)
384
+
385
+ # Test input
386
+ input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=device)
387
+ seq_len = input_ids.shape[1]
388
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
389
+ cache_position = torch.arange(seq_len, device=device)
390
+
391
+ # Causal mask
392
+ mask = create_causal_mask(seq_len, dtype=torch.bfloat16).to(device)
393
+
394
+ # Empty KV caches
395
+ kv_caches = create_empty_kv_caches(1, torch.bfloat16, device)
396
+
397
+ with torch.no_grad():
398
+ # Fixed decoder
399
+ result = decoder(input_ids, mask, position_ids, cache_position, *kv_caches)
400
+ fixed_logits = result[0]
401
+ print(f" Fixed decoder output shape: {fixed_logits.shape}")
402
+
403
+ # Original model (text-only, no image)
404
+ orig_outputs = original_model(
405
+ input_ids=input_ids,
406
+ attention_mask=torch.ones_like(input_ids),
407
+ use_cache=False,
408
+ )
409
+ orig_logits = orig_outputs.logits[:, -1:, :]
410
+ print(f" Original model output shape: {orig_logits.shape}")
411
+
412
+ # Compare
413
+ diff = (fixed_logits.float() - orig_logits.float()).abs()
414
+ print(f" Max absolute difference: {diff.max().item():.6f}")
415
+ print(f" Mean absolute difference: {diff.mean().item():.6f}")
416
+
417
+ # Check top-k predictions match
418
+ fixed_topk = fixed_logits.float().topk(5, dim=-1)
419
+ orig_topk = orig_logits.float().topk(5, dim=-1)
420
+ print(f" Fixed top-5 token IDs: {fixed_topk.indices[0, 0].tolist()}")
421
+ print(f" Original top-5 token IDs: {orig_topk.indices[0, 0].tolist()}")
422
+ matching = sum(1 for t in fixed_topk.indices[0, 0].tolist() if t in orig_topk.indices[0, 0].tolist())
423
+ print(f" Top-5 overlap: {matching}/5")
424
+
425
+
426
+ def try_torch_export(decoder):
427
+ """Attempt torch.export.export() on the decoder."""
428
+ print("\n" + "=" * 60)
429
+ print("ATTEMPTING torch.export.export() on decoder")
430
+ print("=" * 60)
431
+
432
+ # Export on CPU with float32 for XNNPACK
433
+ decoder = decoder.to("cpu").to(torch.float32)
434
+ decoder.eval()
435
+
436
+ batch_size = 1
437
+ seq_len = 1 # Export for single-token decode step (simpler)
438
+
439
+ input_ids = torch.randint(0, VOCAB_SIZE, (batch_size, seq_len))
440
+ attention_mask = create_causal_mask(seq_len, MAX_SEQ_LEN, torch.float32)
441
+ position_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
442
+ cache_position = torch.zeros(seq_len, dtype=torch.long)
443
+ kv_caches = create_empty_kv_caches(batch_size, torch.float32, "cpu")
444
+
445
+ example_args = (input_ids, attention_mask, position_ids, cache_position, *kv_caches)
446
+
447
+ try:
448
+ print(f" Exporting with seq_len={seq_len}, max_cache={MAX_SEQ_LEN}...")
449
+ print(f" Number of input tensors: {len(example_args)} (4 + {NUM_LAYERS}*2 KV caches)")
450
+ exported = torch.export.export(
451
+ decoder,
452
+ example_args,
453
+ strict=False,
454
+ )
455
+ print(" SUCCESS! torch.export completed!")
456
+ return exported
457
+
458
+ except Exception as e:
459
+ print(f" FAILED: {type(e).__name__}: {e}")
460
+ import traceback
461
+ traceback.print_exc()
462
+
463
+ # Try with trace as fallback
464
+ print("\n Trying torch.jit.trace as fallback...")
465
+ try:
466
+ traced = torch.jit.trace(decoder, example_args)
467
+ print(" torch.jit.trace succeeded!")
468
+ return traced
469
+ except Exception as e2:
470
+ print(f" torch.jit.trace also failed: {type(e2).__name__}: {e2}")
471
+
472
+ return None
473
+
474
+
475
+ def export_to_pte(exported_model):
476
+ """Convert exported model to .pte using XNNPACK backend."""
477
+ print("\n" + "=" * 60)
478
+ print("EXPORTING DECODER TO .pte (XNNPACK)")
479
+ print("=" * 60)
480
+
481
+ try:
482
+ from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
483
+ from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
484
+
485
+ if not hasattr(exported_model, 'graph_module'):
486
+ print(" Need torch.export.export() result for .pte export")
487
+ return None
488
+
489
+ print(" Running to_edge_transform_and_lower...")
490
+ edge = to_edge_transform_and_lower(
491
+ exported_model,
492
+ compile_config=EdgeCompileConfig(_check_ir_validity=False),
493
+ partitioner=[XnnpackPartitioner()],
494
+ )
495
+
496
+ print(" Running to_executorch()...")
497
+ pte = edge.to_executorch()
498
+
499
+ output_path = "text_decoder.pte"
500
+ with open(output_path, "wb") as f:
501
+ f.write(pte.buffer)
502
+
503
+ file_size = os.path.getsize(output_path) / (1024 * 1024)
504
+ print(f" Saved to {output_path} ({file_size:.1f} MB)")
505
+ return output_path
506
+
507
+ except ImportError as e:
508
+ print(f" ExecuTorch import failed: {e}")
509
+ return None
510
+ except Exception as e:
511
+ print(f" Export failed: {type(e).__name__}: {e}")
512
+ import traceback
513
+ traceback.print_exc()
514
+ return None
515
+
516
+
517
+ def main():
518
+ print("=" * 60)
519
+ print("Text Decoder Export for ExecuTorch")
520
+ print(f"Architecture: Qwen3 {NUM_LAYERS}L, {NUM_HEADS}H/{NUM_KV_HEADS}KV, dim={HIDDEN_SIZE}")
521
+ print(f"Max seq len: {MAX_SEQ_LEN}")
522
+ print(f"KV cache size per layer: {NUM_KV_HEADS}x{MAX_SEQ_LEN}x{HEAD_DIM} = {NUM_KV_HEADS*MAX_SEQ_LEN*HEAD_DIM/1e6:.2f}M elements")
523
+ print("=" * 60)
524
+
525
+ # Load original model
526
+ original_model = load_original_model()
527
+
528
+ # Build fixed decoder
529
+ decoder = build_decoder_module(original_model)
530
+
531
+ # Test consistency
532
+ test_decoder_module(decoder, original_model)
533
+
534
+ # Free original model memory
535
+ del original_model
536
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
537
+
538
+ # Try torch.export
539
+ exported = try_torch_export(decoder)
540
+
541
+ if exported is not None:
542
+ export_to_pte(exported)
543
+
544
+ # Save the PyTorch module for later use
545
+ torch.save(decoder.state_dict(), "text_decoder_fixed.pt")
546
+ print(f"\nSaved fixed decoder state dict to text_decoder_fixed.pt")
547
+ print("Decoder export script complete!")
548
+
549
+
550
+ if __name__ == "__main__":
551
+ main()