"""
Phase 3b: Text Decoder Export for ExecuTorch
Extracts language_model + lm_head into a standalone nn.Module
with static KV cache tensors for torch.export compatibility.

Architecture: Qwen3 decoder (28 layers, GQA 16/8 heads, head_dim=128)
Fixed max_seq_len: 4096 (see MAX_SEQ_LEN below)
"""
|
|
| import os |
| import sys |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
# --- Qwen3 decoder architecture hyperparameters (must match the checkpoint) ---
HIDDEN_SIZE = 1024          # model/embedding dimension
NUM_LAYERS = 28             # number of decoder layers
NUM_HEADS = 16              # query attention heads
NUM_KV_HEADS = 8            # key/value heads (grouped-query attention)
HEAD_DIM = 128              # per-head feature dimension
INTERMEDIATE_SIZE = 3072    # MLP hidden size
VOCAB_SIZE = 151936         # vocabulary size of the LM head / embedding
MAX_SEQ_LEN = 4096          # fixed static KV-cache length used for export
RMS_EPS = 1e-6              # RMSNorm epsilon
ROPE_THETA = 1000000.0      # RoPE base frequency
NUM_KV_GROUPS = NUM_HEADS // NUM_KV_HEADS  # query heads sharing one KV head

# Local path of the pretrained checkpoint the weights are extracted from.
MODEL_DIR = "./models/LightOnOCR-2-1B"
|
|
|
|
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = RMS_EPS) -> torch.Tensor:
    """RMS-normalize the last dimension of *x*, then scale by *weight*.

    The reduction runs in float32 for numerical stability and the result is
    cast back to the caller's dtype before the learned scale is applied.
    Implemented inline to avoid the @use_kernel_forward_from_hub decorator.
    """
    orig_dtype = x.dtype
    xf = x.to(torch.float32)
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return weight * (xf * inv_rms).to(orig_dtype)
|
|
|
|
def precompute_rope_freqs(max_seq_len: int, head_dim: int, theta: float = ROPE_THETA):
    """Build RoPE cos/sin lookup tables, each of shape [max_seq_len, head_dim].

    The table covers head_dim // 2 rotary frequency pairs; the angle matrix
    is duplicated along the last axis so it lines up with the rotate_half
    convention used when the embedding is applied.
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)      # [max_seq_len, head_dim // 2]
    angles = torch.cat((angles, angles), dim=-1)   # duplicate for rotate_half
    return angles.cos(), angles.sin()
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate query/key states with precomputed RoPE tables.

    q, k:         [batch, num_heads, seq_len, head_dim]
    cos, sin:     [max_seq_len, head_dim] lookup tables
    position_ids: [batch, seq_len] absolute positions into the tables
    Returns the rotated (q, k) pair.
    """
    # Gather per-position table rows and insert a broadcast axis over the
    # head dimension: [batch, 1, seq_len, head_dim].
    cos_sel = cos[position_ids].unsqueeze(1)
    sin_sel = sin[position_ids].unsqueeze(1)

    def _rotate(t):
        # (x1, x2) -> (-x2, x1) on the two halves of the last dimension.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    return (
        q * cos_sel + _rotate(q) * sin_sel,
        k * cos_sel + _rotate(k) * sin_sel,
    )
|
|
|
|
def rotate_half(x):
    """Map (x1, x2) -> (-x2, x1), where x1/x2 are the halves of the last dim."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
|
|
|
|
class Qwen3AttentionFixed(nn.Module):
    """
    Fixed Qwen3 attention with static KV cache, inline QK-norm, and
    no dynamic dispatch. Designed for torch.export compatibility.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx  # kept for parity with the HF layer; unused in forward
        # Standard 1/sqrt(head_dim) attention scale.
        self.scaling = HEAD_DIM ** -0.5

        # Fused per-head projections; no biases in Qwen3 attention.
        self.q_proj = nn.Linear(HIDDEN_SIZE, NUM_HEADS * HEAD_DIM, bias=False)
        self.k_proj = nn.Linear(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, bias=False)
        self.v_proj = nn.Linear(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, bias=False)
        self.o_proj = nn.Linear(NUM_HEADS * HEAD_DIM, HIDDEN_SIZE, bias=False)

        # QK-norm weights held as raw parameters so the inline rms_norm
        # helper can be used instead of a hub-decorated norm module.
        self.q_norm_weight = nn.Parameter(torch.ones(HEAD_DIM))
        self.k_norm_weight = nn.Parameter(torch.ones(HEAD_DIM))

    def forward(
        self,
        hidden_states: torch.Tensor,    # [batch, seq_len, HIDDEN_SIZE]
        cos: torch.Tensor,              # [max_seq_len, HEAD_DIM] RoPE table
        sin: torch.Tensor,              # [max_seq_len, HEAD_DIM] RoPE table
        position_ids: torch.Tensor,     # [batch, seq_len]
        attention_mask: torch.Tensor,   # additive mask broadcastable to scores
        k_cache: torch.Tensor,          # [batch, NUM_KV_HEADS, max_seq_len, HEAD_DIM]
        v_cache: torch.Tensor,          # [batch, NUM_KV_HEADS, max_seq_len, HEAD_DIM]
        cache_position: torch.Tensor,   # [seq_len] cache slots receiving this step's k/v
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (output, updated_k_cache, updated_v_cache)"""
        batch, seq_len, _ = hidden_states.shape

        # Project to queries/keys/values.
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split fused projections into per-head slices.
        q = q.view(batch, seq_len, NUM_HEADS, HEAD_DIM)
        k = k.view(batch, seq_len, NUM_KV_HEADS, HEAD_DIM)
        v = v.view(batch, seq_len, NUM_KV_HEADS, HEAD_DIM)

        # Qwen3 QK-norm: RMS-normalize each head's features before RoPE.
        q = rms_norm(q, self.q_norm_weight)
        k = rms_norm(k, self.k_norm_weight)

        # To [batch, heads, seq_len, head_dim] layout for attention.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Rotary position embedding on queries and keys.
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        # Functional cache update: clone first so the exported graph has no
        # in-place mutation of its inputs, then scatter this step's k/v into
        # the slots named by cache_position.
        k_cache = k_cache.clone()
        v_cache = v_cache.clone()
        k_cache[:, :, cache_position, :] = k
        v_cache[:, :, cache_position, :] = v

        # GQA: repeat each KV head NUM_KV_GROUPS times so every query head
        # has a matching key/value stream. Attention always spans the full
        # static cache, so unwritten slots must be blocked by attention_mask.
        cache_len = k_cache.shape[2]
        k_expanded = k_cache.unsqueeze(2).expand(-1, -1, NUM_KV_GROUPS, -1, -1)
        k_expanded = k_expanded.reshape(batch, NUM_HEADS, cache_len, HEAD_DIM)
        v_expanded = v_cache.unsqueeze(2).expand(-1, -1, NUM_KV_GROUPS, -1, -1)
        v_expanded = v_expanded.reshape(batch, NUM_HEADS, cache_len, HEAD_DIM)

        # Scaled dot-product scores: [batch, heads, seq_len, cache_len].
        attn_weights = torch.matmul(q, k_expanded.transpose(2, 3)) * self.scaling

        # Additive mask (0 where visible, -inf where blocked).
        attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the compute dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)

        # Weighted sum over values.
        attn_output = torch.matmul(attn_weights, v_expanded)

        # Back to [batch, seq_len, heads * head_dim] and project out.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch, seq_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, k_cache, v_cache
|
|
|
|
class Qwen3MLPFixed(nn.Module):
    """Fixed Qwen3 MLP: SiLU-gated (SwiGLU-style) feed-forward block."""

    def __init__(self):
        super().__init__()
        # Gate and up project hidden -> intermediate; down projects back.
        self.gate_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
        self.up_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
        self.down_proj = nn.Linear(INTERMEDIATE_SIZE, HIDDEN_SIZE, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # down( silu(gate(x)) * up(x) )
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
class Qwen3DecoderLayerFixed(nn.Module):
    """Fixed Qwen3 decoder layer with static KV cache."""

    def __init__(self, layer_idx: int):
        super().__init__()
        self.self_attn = Qwen3AttentionFixed(layer_idx)
        self.mlp = Qwen3MLPFixed()
        # RMSNorm weights kept as raw parameters (the norm is inlined via
        # rms_norm rather than wrapped in a module).
        self.input_layernorm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))
        self.post_attention_layernorm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Pre-norm transformer block: attention then MLP, each wrapped in a
        residual connection. Returns (hidden_states, k_cache, v_cache)."""
        # Self-attention sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = rms_norm(hidden_states, self.input_layernorm_weight)
        hidden_states, k_cache, v_cache = self.self_attn(
            hidden_states, cos, sin, position_ids, attention_mask,
            k_cache, v_cache, cache_position
        )
        hidden_states = residual + hidden_states

        # MLP sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = rms_norm(hidden_states, self.post_attention_layernorm_weight)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, k_cache, v_cache
|
|
|
|
class TextDecoderFixed(nn.Module):
    """
    Complete text decoder for ExecuTorch export.
    Includes embedding, all decoder layers with static KV cache, and LM head.

    For prefill: input_ids has seq_len > 1, cache_position starts at 0
    For decode: input_ids has seq_len = 1, cache_position = current position
    """

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
        self.layers = nn.ModuleList([
            Qwen3DecoderLayerFixed(i) for i in range(NUM_LAYERS)
        ])
        # Final RMSNorm weight (applied inline via rms_norm).
        self.norm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))
        self.lm_head = nn.Linear(HIDDEN_SIZE, VOCAB_SIZE, bias=False)

        # Precompute RoPE cos/sin once and register as buffers so they are
        # captured with the module state at export time.
        cos, sin = precompute_rope_freqs(MAX_SEQ_LEN, HEAD_DIM, ROPE_THETA)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        cache_position: torch.Tensor,
        *kv_caches: torch.Tensor,
    ) -> tuple:
        """
        Returns: (logits, *updated_kv_caches)
        kv_caches: 56 tensors total (28 layers * 2 for k,v)
        Each cache: [batch, num_kv_heads, max_seq_len, head_dim]
        """
        hidden_states = self.embed_tokens(input_ids)

        # Thread the functional KV caches through every layer, collecting
        # the updated tensors to hand back to the caller.
        updated_caches = []
        for i, layer in enumerate(self.layers):
            k_cache = kv_caches[i * 2]
            v_cache = kv_caches[i * 2 + 1]
            hidden_states, new_k, new_v = layer(
                hidden_states,
                self.rope_cos, self.rope_sin,
                position_ids, attention_mask,
                k_cache, v_cache, cache_position
            )
            updated_caches.append(new_k)
            updated_caches.append(new_v)

        # Final norm before the LM head.
        hidden_states = rms_norm(hidden_states, self.norm_weight)

        # Only the last position's logits are needed for generation, so the
        # LM head runs on a single position to keep the output small.
        logits = self.lm_head(hidden_states[:, -1:, :])

        return (logits, *updated_caches)
|
|
|
|
def load_original_model():
    """Load the original model with proper weight remapping.

    Loads the HF image-text-to-text model from MODEL_DIR in bfloat16 on CPU,
    then re-loads the raw safetensors with the checkpoint's vision-module
    names remapped to the transformers attribute names.
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("Loading original model...")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="cpu",
    )

    # The checkpoint stores vision weights under names differing from the
    # transformers class attributes; remap before loading.
    # NOTE(review): strict=False silently skips any keys that still do not
    # match — consider inspecting load_state_dict's missing/unexpected keys.
    state_dict = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    remapped = {}
    for k, v in state_dict.items():
        new_k = k.replace("model.vision_encoder.", "model.vision_tower.")
        new_k = new_k.replace("model.vision_projection.", "model.multi_modal_projector.")
        remapped[new_k] = v
    model.load_state_dict(remapped, strict=False)

    return model
|
|
|
|
def build_decoder_module(original_model):
    """Build the fixed decoder module from the original model's weights.

    Copies every language-model tensor from the HF model into the flat
    TextDecoderFixed layout (norm weights become raw parameters).
    """
    print("\nBuilding fixed text decoder...")

    orig_lm = original_model.model.language_model
    # NOTE(review): fetched but unused — the LM head below is filled from
    # embed_tokens instead. This assumes the checkpoint ties lm_head to the
    # embedding matrix; confirm, otherwise copy orig_lm_head.weight.
    orig_lm_head = original_model.lm_head

    decoder = TextDecoderFixed()

    # Token embedding.
    decoder.embed_tokens.weight.data.copy_(orig_lm.embed_tokens.weight.data)

    # Final RMSNorm.
    decoder.norm_weight.data.copy_(orig_lm.norm.weight.data)

    # LM head — tied to the embedding weights (see note above).
    decoder.lm_head.weight.data.copy_(orig_lm.embed_tokens.weight.data)

    # Per-layer weights.
    for i in range(NUM_LAYERS):
        orig_layer = orig_lm.layers[i]
        fixed_layer = decoder.layers[i]

        # Attention projections.
        fixed_layer.self_attn.q_proj.weight.data.copy_(orig_layer.self_attn.q_proj.weight.data)
        fixed_layer.self_attn.k_proj.weight.data.copy_(orig_layer.self_attn.k_proj.weight.data)
        fixed_layer.self_attn.v_proj.weight.data.copy_(orig_layer.self_attn.v_proj.weight.data)
        fixed_layer.self_attn.o_proj.weight.data.copy_(orig_layer.self_attn.o_proj.weight.data)

        # QK-norm weights (module -> raw parameter).
        fixed_layer.self_attn.q_norm_weight.data.copy_(orig_layer.self_attn.q_norm.weight.data)
        fixed_layer.self_attn.k_norm_weight.data.copy_(orig_layer.self_attn.k_norm.weight.data)

        # Layer norms (module -> raw parameter).
        fixed_layer.input_layernorm_weight.data.copy_(orig_layer.input_layernorm.weight.data)
        fixed_layer.post_attention_layernorm_weight.data.copy_(orig_layer.post_attention_layernorm.weight.data)

        # MLP projections.
        fixed_layer.mlp.gate_proj.weight.data.copy_(orig_layer.mlp.gate_proj.weight.data)
        fixed_layer.mlp.up_proj.weight.data.copy_(orig_layer.mlp.up_proj.weight.data)
        fixed_layer.mlp.down_proj.weight.data.copy_(orig_layer.mlp.down_proj.weight.data)

    decoder.eval()
    total_params = sum(p.numel() for p in decoder.parameters())
    print(f" Decoder parameters: {total_params/1e6:.2f}M")

    return decoder
|
|
|
|
def create_empty_kv_caches(batch_size: int = 1, dtype=torch.float32, device="cpu"):
    """Allocate zeroed static KV caches, two per layer (k then v, in order).

    Returns a tuple of 2 * NUM_LAYERS tensors, each of shape
    [batch_size, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM].
    """
    shape = (batch_size, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM)
    return tuple(
        torch.zeros(shape, dtype=dtype, device=device)
        for _ in range(2 * NUM_LAYERS)
    )
|
|
|
|
def create_causal_mask(seq_len: int, cache_len=None, dtype=torch.float32, start_pos: int = 0):
    """Build an additive causal attention mask over a static KV cache.

    Query row i corresponds to absolute position ``start_pos + i``; it may
    attend cache slots 0 .. start_pos + i and is blocked (-inf) everywhere
    else — including the not-yet-written tail of the static cache, whose
    zero-valued k/v entries would otherwise leak into the softmax.

    Args:
        seq_len: number of query positions in this forward pass.
        cache_len: total static cache length; defaults to MAX_SEQ_LEN.
        dtype: dtype of the returned mask.
        start_pos: absolute position of the first query token (0 for
            prefill; the current cache position for single-token decode).

    Returns:
        [1, 1, seq_len, cache_len] tensor of 0 / -inf values.
    """
    if cache_len is None:
        cache_len = MAX_SEQ_LEN
    mask = torch.full((seq_len, cache_len), float("-inf"), dtype=dtype)
    # Keep -inf exactly where the key index exceeds the query's absolute
    # position: entry (i, j) is blocked iff j > start_pos + i.
    # BUGFIX: the previous diagonal (cache_len - seq_len + 1) right-aligned
    # the queries at the cache tail, but callers write k/v with
    # cache_position = arange(seq_len) (left-aligned), so prefill queries
    # could both see future tokens and attend empty zero-filled cache
    # slots, and a seq_len=1 decode step was not masked at all.
    mask = torch.triu(mask, diagonal=start_pos + 1)
    return mask.unsqueeze(0).unsqueeze(0)
|
|
|
|
def test_decoder_module(decoder, original_model):
    """Test that the fixed decoder produces same output as original.

    Runs a 5-token prefill through the fixed decoder (static KV cache) and
    through the HF model without cache, then compares last-token logits and
    the top-5 next-token candidates.
    """
    print("\nTesting decoder output consistency...")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    decoder = decoder.to(device).to(torch.bfloat16)
    original_model = original_model.to(device)

    # Tiny fixed prompt; positions and cache slots 0..4.
    input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=device)
    seq_len = input_ids.shape[1]
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
    cache_position = torch.arange(seq_len, device=device)

    # Causal mask spanning the full static cache.
    mask = create_causal_mask(seq_len, dtype=torch.bfloat16).to(device)

    # Fresh zeroed KV caches for all layers.
    kv_caches = create_empty_kv_caches(1, torch.bfloat16, device)

    with torch.no_grad():
        # Prefill through the fixed decoder; logits for the last token only.
        result = decoder(input_ids, mask, position_ids, cache_position, *kv_caches)
        fixed_logits = result[0]
        print(f" Fixed decoder output shape: {fixed_logits.shape}")

        # Reference: HF forward without KV cache; keep last-position logits.
        orig_outputs = original_model(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
            use_cache=False,
        )
        orig_logits = orig_outputs.logits[:, -1:, :]
        print(f" Original model output shape: {orig_logits.shape}")

        # Element-wise agreement, compared in float32.
        diff = (fixed_logits.float() - orig_logits.float()).abs()
        print(f" Max absolute difference: {diff.max().item():.6f}")
        print(f" Mean absolute difference: {diff.mean().item():.6f}")

        # Ranking agreement: overlap of the top-5 next-token candidates.
        fixed_topk = fixed_logits.float().topk(5, dim=-1)
        orig_topk = orig_logits.float().topk(5, dim=-1)
        print(f" Fixed top-5 token IDs: {fixed_topk.indices[0, 0].tolist()}")
        print(f" Original top-5 token IDs: {orig_topk.indices[0, 0].tolist()}")
        matching = sum(1 for t in fixed_topk.indices[0, 0].tolist() if t in orig_topk.indices[0, 0].tolist())
        print(f" Top-5 overlap: {matching}/5")
|
|
|
|
def try_torch_export(decoder):
    """Attempt torch.export.export() on the decoder.

    Exports a single-token decode step (seq_len=1) with the full set of
    static KV-cache example inputs; falls back to torch.jit.trace if export
    fails. Returns the exported/traced artifact, or None if both fail.
    """
    print("\n" + "=" * 60)
    print("ATTEMPTING torch.export.export() on decoder")
    print("=" * 60)

    # Export in float32 on CPU.
    # NOTE(review): fp32 is presumably required by the downstream XNNPACK
    # lowering — confirm against the targeted ExecuTorch backend.
    decoder = decoder.to("cpu").to(torch.float32)
    decoder.eval()

    batch_size = 1
    seq_len = 1

    # Example inputs for a decode step at cache position 0.
    input_ids = torch.randint(0, VOCAB_SIZE, (batch_size, seq_len))
    attention_mask = create_causal_mask(seq_len, MAX_SEQ_LEN, torch.float32)
    position_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
    cache_position = torch.zeros(seq_len, dtype=torch.long)
    kv_caches = create_empty_kv_caches(batch_size, torch.float32, "cpu")

    example_args = (input_ids, attention_mask, position_ids, cache_position, *kv_caches)

    try:
        print(f" Exporting with seq_len={seq_len}, max_cache={MAX_SEQ_LEN}...")
        print(f" Number of input tensors: {len(example_args)} (4 + {NUM_LAYERS}*2 KV caches)")
        # strict=False uses the non-strict tracer, which tolerates more
        # Python constructs than strict mode.
        exported = torch.export.export(
            decoder,
            example_args,
            strict=False,
        )
        print(" SUCCESS! torch.export completed!")
        return exported

    except Exception as e:
        print(f" FAILED: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()

        # Fallback: TorchScript tracing. The result cannot be lowered to
        # .pte, but shows whether the module is traceable at all.
        print("\n Trying torch.jit.trace as fallback...")
        try:
            traced = torch.jit.trace(decoder, example_args)
            print(" torch.jit.trace succeeded!")
            return traced
        except Exception as e2:
            print(f" torch.jit.trace also failed: {type(e2).__name__}: {e2}")

    return None
|
|
|
|
def export_to_pte(exported_model):
    """Convert exported model to .pte using XNNPACK backend.

    Lowers a torch.export artifact through ExecuTorch's edge dialect with
    the XNNPACK partitioner and writes text_decoder.pte. Returns the output
    path, or None on any failure (including a missing executorch install).
    """
    print("\n" + "=" * 60)
    print("EXPORTING DECODER TO .pte (XNNPACK)")
    print("=" * 60)

    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        # A torch.jit.trace fallback result carries no export graph and
        # cannot be lowered to .pte.
        if not hasattr(exported_model, 'graph_module'):
            print(" Need torch.export.export() result for .pte export")
            return None

        print(" Running to_edge_transform_and_lower...")
        edge = to_edge_transform_and_lower(
            exported_model,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )

        print(" Running to_executorch()...")
        pte = edge.to_executorch()

        output_path = "text_decoder.pte"
        with open(output_path, "wb") as f:
            f.write(pte.buffer)

        file_size = os.path.getsize(output_path) / (1024 * 1024)
        print(f" Saved to {output_path} ({file_size:.1f} MB)")
        return output_path

    except ImportError as e:
        print(f" ExecuTorch import failed: {e}")
        return None
    except Exception as e:
        print(f" Export failed: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
def main():
    """Run the full pipeline: load -> rebuild -> verify -> export -> save."""
    print("=" * 60)
    print("Text Decoder Export for ExecuTorch")
    print(f"Architecture: Qwen3 {NUM_LAYERS}L, {NUM_HEADS}H/{NUM_KV_HEADS}KV, dim={HIDDEN_SIZE}")
    print(f"Max seq len: {MAX_SEQ_LEN}")
    print(f"KV cache size per layer: {NUM_KV_HEADS}x{MAX_SEQ_LEN}x{HEAD_DIM} = {NUM_KV_HEADS*MAX_SEQ_LEN*HEAD_DIM/1e6:.2f}M elements")
    print("=" * 60)

    # Load the pretrained checkpoint and rebuild it as the fixed decoder.
    original_model = load_original_model()
    decoder = build_decoder_module(original_model)

    # Sanity-check the rebuilt decoder against the original model.
    test_decoder_module(decoder, original_model)

    # Free the original model before export to reduce peak memory.
    del original_model
    # Idiom fix: this was a conditional expression used as a statement
    # (`torch.cuda.empty_cache() if ... else None`); a plain `if` is clearer.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Attempt torch.export and, if it succeeds, lower to a .pte file.
    exported = try_torch_export(decoder)
    if exported is not None:
        export_to_pte(exported)

    # Always save the plain state dict as a fallback artifact.
    torch.save(decoder.state_dict(), "text_decoder_fixed.pt")
    print("\nSaved fixed decoder state dict to text_decoder_fixed.pt")
    print("Decoder export script complete!")
|
|
|
|
# Script entry point: run the full extract -> verify -> export pipeline.
if __name__ == "__main__":
    main()
|
|