# LightOnOCR-2-1B-ExecuTorch / scripts/export_decoder.py
# Uploaded with huggingface_hub (commit 77cf118, by acul3)
#!/usr/bin/env python3
"""
Phase 3b: Text Decoder Export for ExecuTorch
Extracts language_model + lm_head into a standalone nn.Module
with static KV cache tensors for torch.export compatibility.
Architecture: Qwen3 decoder (28 layers, GQA 16/8 heads, head_dim=128)
Fixed max_seq_len: 4096
"""
import os
import sys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Model constants from config
HIDDEN_SIZE = 1024          # decoder hidden width
NUM_LAYERS = 28             # decoder depth
NUM_HEADS = 16              # query heads
NUM_KV_HEADS = 8            # key/value heads (GQA)
HEAD_DIM = 128              # per-head dimension
INTERMEDIATE_SIZE = 3072    # MLP inner width
VOCAB_SIZE = 151936         # vocabulary size (lm_head / embedding rows)
MAX_SEQ_LEN = 4096          # static KV-cache length baked into the export
RMS_EPS = 1e-6              # RMSNorm epsilon
ROPE_THETA = 1000000.0      # RoPE base frequency
NUM_KV_GROUPS = NUM_HEADS // NUM_KV_HEADS  # 2 query heads per KV head
MODEL_DIR = "./models/LightOnOCR-2-1B"
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = RMS_EPS) -> torch.Tensor:
    """Inline RMSNorm — avoids @use_kernel_forward_from_hub decorator.

    Normalizes over the last dimension in float32 for stability, casts back
    to the input dtype, then applies the learned per-channel scale.
    """
    original_dtype = x.dtype
    xf = x.to(torch.float32)
    inv_rms = torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return weight * (xf * inv_rms).to(original_dtype)
def precompute_rope_freqs(max_seq_len: int, head_dim: int, theta: float = ROPE_THETA):
    """Precompute RoPE cos/sin tables for every position.

    Returns (cos, sin), each [max_seq_len, head_dim]: the head_dim/2
    frequency bands are duplicated along the last axis so the tables cover
    the full head dimension.
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)                 # [head_dim/2]
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)             # [max_seq_len, head_dim/2]
    cos_half, sin_half = angles.cos(), angles.sin()
    cos = torch.cat([cos_half, cos_half], dim=-1)         # [max_seq_len, head_dim]
    sin = torch.cat([sin_half, sin_half], dim=-1)
    return cos, sin
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embeddings to query and key states.
    q, k: [batch, num_heads, seq_len, head_dim]
    cos, sin: [max_seq_len, head_dim]
    position_ids: [batch, seq_len]
    """
    def _rotate(t):
        # Swap the two halves with a sign flip: (x1, x2) -> (-x2, x1).
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)
    # Select the per-position tables and add a broadcast head axis.
    cos_pos = cos[position_ids].unsqueeze(1)  # [batch, 1, seq_len, head_dim]
    sin_pos = sin[position_ids].unsqueeze(1)  # [batch, 1, seq_len, head_dim]
    return (q * cos_pos) + (_rotate(q) * sin_pos), (k * cos_pos) + (_rotate(k) * sin_pos)
def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
class Qwen3AttentionFixed(nn.Module):
    """
    Fixed Qwen3 attention with static KV cache, inline QK-norm, and
    no dynamic dispatch. Designed for torch.export compatibility.

    GQA layout: NUM_HEADS query heads share NUM_KV_HEADS key/value heads,
    so each KV head is reused by NUM_KV_GROUPS query heads.
    """
    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # Standard 1/sqrt(head_dim) softmax scaling.
        self.scaling = HEAD_DIM ** -0.5
        # Projections (all bias-free, as in Qwen3)
        self.q_proj = nn.Linear(HIDDEN_SIZE, NUM_HEADS * HEAD_DIM, bias=False)
        self.k_proj = nn.Linear(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, bias=False)
        self.v_proj = nn.Linear(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, bias=False)
        self.o_proj = nn.Linear(NUM_HEADS * HEAD_DIM, HIDDEN_SIZE, bias=False)
        # QK-norm weights (RMSNorm applied per head over head_dim)
        self.q_norm_weight = nn.Parameter(torch.ones(HEAD_DIM))
        self.k_norm_weight = nn.Parameter(torch.ones(HEAD_DIM))
    def forward(
        self,
        hidden_states: torch.Tensor,    # [batch, seq_len, hidden_size]
        cos: torch.Tensor,              # [max_seq_len, head_dim]
        sin: torch.Tensor,              # [max_seq_len, head_dim]
        position_ids: torch.Tensor,     # [batch, seq_len]
        attention_mask: torch.Tensor,   # [batch, 1, seq_len, cache_len+seq_len]
        k_cache: torch.Tensor,          # [batch, num_kv_heads, max_seq_len, head_dim]
        v_cache: torch.Tensor,          # [batch, num_kv_heads, max_seq_len, head_dim]
        cache_position: torch.Tensor,   # [seq_len] — positions to write into cache
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (output, updated_k_cache, updated_v_cache).

        NOTE(review): attention scores span the ENTIRE static cache, so
        attention_mask must assign -inf to every cache slot not yet written;
        otherwise slots still holding their initial values would receive
        softmax weight.
        """
        batch, seq_len, _ = hidden_states.shape
        # Project Q, K, V
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        # Split heads: [batch, seq_len, heads*head_dim] -> [batch, seq_len, heads, head_dim]
        q = q.view(batch, seq_len, NUM_HEADS, HEAD_DIM)
        k = k.view(batch, seq_len, NUM_KV_HEADS, HEAD_DIM)
        v = v.view(batch, seq_len, NUM_KV_HEADS, HEAD_DIM)
        # Apply QK-norm (RMSNorm per head, inline) — normalizes over head_dim,
        # done before the head/seq transpose so the weight broadcasts cleanly.
        q = rms_norm(q, self.q_norm_weight)
        k = rms_norm(k, self.k_norm_weight)
        q = q.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]
        v = v.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]
        # Apply RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
        # Update KV cache. clone() keeps the update functional (out-of-place)
        # so torch.export does not see a mutation of a forward() input.
        k_cache = k_cache.clone()
        v_cache = v_cache.clone()
        k_cache[:, :, cache_position, :] = k
        v_cache[:, :, cache_position, :] = v
        # Expand KV heads for GQA: repeat each KV head for its group of Q heads
        cache_len = k_cache.shape[2]  # dynamic, works for any MAX_SEQ_LEN
        k_expanded = k_cache.unsqueeze(2).expand(-1, -1, NUM_KV_GROUPS, -1, -1)
        k_expanded = k_expanded.reshape(batch, NUM_HEADS, cache_len, HEAD_DIM)
        v_expanded = v_cache.unsqueeze(2).expand(-1, -1, NUM_KV_GROUPS, -1, -1)
        v_expanded = v_expanded.reshape(batch, NUM_HEADS, cache_len, HEAD_DIM)
        # Attention: Q @ K^T / sqrt(head_dim)
        attn_weights = torch.matmul(q, k_expanded.transpose(2, 3)) * self.scaling
        # Additive mask: 0 for visible positions, -inf for masked ones.
        attn_weights = attn_weights + attention_mask
        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        # Attention output
        attn_output = torch.matmul(attn_weights, v_expanded)
        # Merge heads: [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, hidden_size]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch, seq_len, -1)
        # Output projection
        attn_output = self.o_proj(attn_output)
        return attn_output, k_cache, v_cache
class Qwen3MLPFixed(nn.Module):
    """Fixed Qwen3 gated MLP: down_proj(silu(gate_proj(x)) * up_proj(x))."""
    def __init__(self):
        super().__init__()
        self.gate_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
        self.up_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
        self.down_proj = nn.Linear(INTERMEDIATE_SIZE, HIDDEN_SIZE, bias=False)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating: the SiLU-activated gate modulates the up projection.
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class Qwen3DecoderLayerFixed(nn.Module):
    """Fixed Qwen3 decoder layer.

    Pre-norm attention and MLP sub-blocks, each wrapped in a residual
    connection; the static KV cache tensors are threaded through and
    returned updated.
    """
    def __init__(self, layer_idx: int):
        super().__init__()
        self.self_attn = Qwen3AttentionFixed(layer_idx)
        self.mlp = Qwen3MLPFixed()
        self.input_layernorm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))
        self.post_attention_layernorm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))
    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Attention sub-block (pre-norm + residual).
        attn_out, k_cache, v_cache = self.self_attn(
            rms_norm(hidden_states, self.input_layernorm_weight),
            cos, sin, position_ids, attention_mask,
            k_cache, v_cache, cache_position,
        )
        hidden_states = hidden_states + attn_out
        # MLP sub-block (pre-norm + residual).
        mlp_out = self.mlp(rms_norm(hidden_states, self.post_attention_layernorm_weight))
        return hidden_states + mlp_out, k_cache, v_cache
class TextDecoderFixed(nn.Module):
    """
    Complete text decoder for ExecuTorch export.

    Wraps token embedding, all NUM_LAYERS decoder layers with static KV
    caches, and the LM head behind a flat all-tensor signature so that
    torch.export sees only plain tensors in and out.

    Calling conventions:
      * prefill: input_ids has seq_len > 1, cache_position starts at 0
      * decode:  input_ids has seq_len = 1, cache_position holds the
        current absolute position
    """
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
        self.layers = nn.ModuleList(
            Qwen3DecoderLayerFixed(idx) for idx in range(NUM_LAYERS)
        )
        self.norm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))
        self.lm_head = nn.Linear(HIDDEN_SIZE, VOCAB_SIZE, bias=False)
        # RoPE tables are position-only constants, so bake them in as buffers.
        cos, sin = precompute_rope_freqs(MAX_SEQ_LEN, HEAD_DIM, ROPE_THETA)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)
    def forward(
        self,
        input_ids: torch.Tensor,        # [batch, seq_len]
        attention_mask: torch.Tensor,   # [batch, 1, seq_len, max_seq_len]
        position_ids: torch.Tensor,     # [batch, seq_len]
        cache_position: torch.Tensor,   # [seq_len]
        *kv_caches: torch.Tensor,       # NUM_LAYERS * (k_cache, v_cache), flattened
    ) -> tuple:
        """
        Run one forward step and return (logits, *updated_kv_caches).

        kv_caches holds 2 * NUM_LAYERS tensors (k then v for each layer),
        each shaped [batch, num_kv_heads, max_seq_len, head_dim].
        """
        hidden = self.embed_tokens(input_ids)
        new_caches = []
        for idx, layer in enumerate(self.layers):
            hidden, k_upd, v_upd = layer(
                hidden,
                self.rope_cos, self.rope_sin,
                position_ids, attention_mask,
                kv_caches[2 * idx], kv_caches[2 * idx + 1],
                cache_position,
            )
            new_caches += [k_upd, v_upd]
        # Final norm
        hidden = rms_norm(hidden, self.norm_weight)
        # Only the last position's logits are needed for next-token selection.
        logits = self.lm_head(hidden[:, -1:, :])  # [batch, 1, vocab_size]
        return (logits, *new_caches)
def load_original_model():
    """Load the original model with proper weight remapping.

    The safetensors checkpoint uses older vision-module names; keys are
    renamed to the current HF layout before being loaded (non-strict).
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file
    print("Loading original model...")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="cpu",
    )
    raw_state = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    remapped = {
        key.replace("model.vision_encoder.", "model.vision_tower.")
           .replace("model.vision_projection.", "model.multi_modal_projector."): tensor
        for key, tensor in raw_state.items()
    }
    model.load_state_dict(remapped, strict=False)
    return model
def build_decoder_module(original_model):
    """Build the fixed decoder module from the original model's weights.

    Copies embedding, per-layer attention/MLP/norm weights, the final norm,
    and the LM head from the HF model into a fresh TextDecoderFixed.

    Args:
        original_model: loaded HF model exposing `model.language_model`
            (Qwen3 decoder stack) and `lm_head`.

    Returns:
        TextDecoderFixed in eval mode with all weights copied in.
    """
    print("\nBuilding fixed text decoder...")
    orig_lm = original_model.model.language_model
    orig_lm_head = original_model.lm_head
    decoder = TextDecoderFixed()
    # Copy embedding weights
    decoder.embed_tokens.weight.data.copy_(orig_lm.embed_tokens.weight.data)
    # Copy final norm weight
    decoder.norm_weight.data.copy_(orig_lm.norm.weight.data)
    # Copy LM head from the actual lm_head module. (Previously this copied
    # embed_tokens and left orig_lm_head unused — correct only when the
    # weights are tied; copying lm_head directly is correct in both the
    # tied and untied case.)
    decoder.lm_head.weight.data.copy_(orig_lm_head.weight.data)
    # Copy layer weights
    for i in range(NUM_LAYERS):
        orig_layer = orig_lm.layers[i]
        fixed_layer = decoder.layers[i]
        # Attention projections
        fixed_layer.self_attn.q_proj.weight.data.copy_(orig_layer.self_attn.q_proj.weight.data)
        fixed_layer.self_attn.k_proj.weight.data.copy_(orig_layer.self_attn.k_proj.weight.data)
        fixed_layer.self_attn.v_proj.weight.data.copy_(orig_layer.self_attn.v_proj.weight.data)
        fixed_layer.self_attn.o_proj.weight.data.copy_(orig_layer.self_attn.o_proj.weight.data)
        # QK-norm weights
        fixed_layer.self_attn.q_norm_weight.data.copy_(orig_layer.self_attn.q_norm.weight.data)
        fixed_layer.self_attn.k_norm_weight.data.copy_(orig_layer.self_attn.k_norm.weight.data)
        # Layer norms
        fixed_layer.input_layernorm_weight.data.copy_(orig_layer.input_layernorm.weight.data)
        fixed_layer.post_attention_layernorm_weight.data.copy_(orig_layer.post_attention_layernorm.weight.data)
        # MLP
        fixed_layer.mlp.gate_proj.weight.data.copy_(orig_layer.mlp.gate_proj.weight.data)
        fixed_layer.mlp.up_proj.weight.data.copy_(orig_layer.mlp.up_proj.weight.data)
        fixed_layer.mlp.down_proj.weight.data.copy_(orig_layer.mlp.down_proj.weight.data)
    decoder.eval()
    total_params = sum(p.numel() for p in decoder.parameters())
    print(f" Decoder parameters: {total_params/1e6:.2f}M")
    return decoder
def create_empty_kv_caches(batch_size: int = 1, dtype=torch.float32, device="cpu"):
    """Create empty KV cache tensors for all layers.

    Returns a tuple of 2 * NUM_LAYERS zero tensors (k then v per layer),
    each shaped [batch_size, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM].
    """
    cache_shape = (batch_size, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM)
    return tuple(
        torch.zeros(cache_shape, dtype=dtype, device=device)
        for _ in range(2 * NUM_LAYERS)
    )
def create_causal_mask(seq_len: int, cache_len: int = MAX_SEQ_LEN, dtype=torch.float32,
                       start_pos: int = 0):
    """Create an additive causal mask over a static KV cache.

    Query i sits at absolute cache position ``start_pos + i`` and may attend
    only slots ``j <= start_pos + i``; every later slot — including the
    never-written tail of the static cache — gets -inf.

    BUG FIX: the previous ``triu(diagonal=cache_len - seq_len + 1)`` mask
    assumed the query block occupied the LAST seq_len slots of the cache,
    but callers write K/V at cache_position starting from 0, so thousands
    of empty (all-zero key) slots were left unmasked and received softmax
    weight, corrupting the attention output. ``start_pos`` additionally
    generalizes the mask to decode steps at arbitrary positions.

    Returns:
        [1, 1, seq_len, cache_len] tensor of 0.0 (visible) / -inf (masked).
    """
    query_pos = start_pos + torch.arange(seq_len).unsqueeze(1)  # [seq_len, 1]
    key_pos = torch.arange(cache_len).unsqueeze(0)              # [1, cache_len]
    mask = torch.full((seq_len, cache_len), float("-inf"), dtype=dtype)
    mask = mask.masked_fill(key_pos <= query_pos, 0.0)
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, cache_len]
def test_decoder_module(decoder, original_model):
    """Test that the fixed decoder produces same output as original.

    Runs a 5-token text-only prefill through both models and prints the
    max/mean logit difference and the top-5 token-ID overlap for the last
    position. Diagnostics only — does not raise on mismatch.
    """
    print("\nTesting decoder output consistency...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Match the bfloat16 checkpoint dtype so the comparison is apples-to-apples.
    decoder = decoder.to(device).to(torch.bfloat16)
    original_model = original_model.to(device)
    # Test input: 5 arbitrary token IDs, prefill starting at position 0.
    input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=device)
    seq_len = input_ids.shape[1]
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
    cache_position = torch.arange(seq_len, device=device)
    # Causal mask over the full static cache width.
    mask = create_causal_mask(seq_len, dtype=torch.bfloat16).to(device)
    # Empty KV caches (zeros) for the fixed decoder.
    kv_caches = create_empty_kv_caches(1, torch.bfloat16, device)
    with torch.no_grad():
        # Fixed decoder: result[0] is the last-position logits.
        result = decoder(input_ids, mask, position_ids, cache_position, *kv_caches)
        fixed_logits = result[0]
        print(f" Fixed decoder output shape: {fixed_logits.shape}")
        # Original model (text-only, no image)
        orig_outputs = original_model(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
            use_cache=False,
        )
        orig_logits = orig_outputs.logits[:, -1:, :]
        print(f" Original model output shape: {orig_logits.shape}")
    # Compare in float32 so bf16 rounding doesn't distort the diff itself.
    diff = (fixed_logits.float() - orig_logits.float()).abs()
    print(f" Max absolute difference: {diff.max().item():.6f}")
    print(f" Mean absolute difference: {diff.mean().item():.6f}")
    # Check top-k predictions match (order-insensitive set overlap).
    fixed_topk = fixed_logits.float().topk(5, dim=-1)
    orig_topk = orig_logits.float().topk(5, dim=-1)
    print(f" Fixed top-5 token IDs: {fixed_topk.indices[0, 0].tolist()}")
    print(f" Original top-5 token IDs: {orig_topk.indices[0, 0].tolist()}")
    matching = sum(1 for t in fixed_topk.indices[0, 0].tolist() if t in orig_topk.indices[0, 0].tolist())
    print(f" Top-5 overlap: {matching}/5")
def try_torch_export(decoder):
    """Attempt torch.export.export() on the decoder.

    Exports a single-token decode step (seq_len=1) on CPU in float32 for the
    XNNPACK backend, falling back to torch.jit.trace if export fails.

    Returns:
        The ExportedProgram on success, a TorchScript module if only the
        trace fallback succeeds, or None if both fail.
    """
    print("\n" + "=" * 60)
    print("ATTEMPTING torch.export.export() on decoder")
    print("=" * 60)
    # Export on CPU with float32 for XNNPACK
    decoder = decoder.to("cpu").to(torch.float32)
    decoder.eval()
    batch_size = 1
    seq_len = 1  # Export for single-token decode step (simpler)
    # Example inputs: one random token at position 0 with empty caches.
    input_ids = torch.randint(0, VOCAB_SIZE, (batch_size, seq_len))
    attention_mask = create_causal_mask(seq_len, MAX_SEQ_LEN, torch.float32)
    position_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
    cache_position = torch.zeros(seq_len, dtype=torch.long)
    kv_caches = create_empty_kv_caches(batch_size, torch.float32, "cpu")
    example_args = (input_ids, attention_mask, position_ids, cache_position, *kv_caches)
    try:
        print(f" Exporting with seq_len={seq_len}, max_cache={MAX_SEQ_LEN}...")
        print(f" Number of input tensors: {len(example_args)} (4 + {NUM_LAYERS}*2 KV caches)")
        # strict=False = non-strict mode: traces via the normal Python
        # interpreter, which tolerates more constructs than TorchDynamo.
        exported = torch.export.export(
            decoder,
            example_args,
            strict=False,
        )
        print(" SUCCESS! torch.export completed!")
        return exported
    except Exception as e:
        print(f" FAILED: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        # Try with trace as fallback (note: a traced module cannot be
        # lowered to .pte downstream, but is still useful for debugging).
        print("\n Trying torch.jit.trace as fallback...")
        try:
            traced = torch.jit.trace(decoder, example_args)
            print(" torch.jit.trace succeeded!")
            return traced
        except Exception as e2:
            print(f" torch.jit.trace also failed: {type(e2).__name__}: {e2}")
            return None
def export_to_pte(exported_model):
    """Convert exported model to .pte using XNNPACK backend.

    Expects a torch.export ExportedProgram (the torch.jit.trace fallback is
    rejected via the graph_module check). Writes ``text_decoder.pte`` to
    the working directory.

    Returns:
        The output path on success, None on any failure.
    """
    print("\n" + "=" * 60)
    print("EXPORTING DECODER TO .pte (XNNPACK)")
    print("=" * 60)
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
        # A torch.jit.trace result has no graph_module attribute — only a
        # real ExportedProgram can go through the Edge/ExecuTorch pipeline.
        if not hasattr(exported_model, 'graph_module'):
            print(" Need torch.export.export() result for .pte export")
            return None
        print(" Running to_edge_transform_and_lower...")
        edge = to_edge_transform_and_lower(
            exported_model,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        print(" Running to_executorch()...")
        pte = edge.to_executorch()
        output_path = "text_decoder.pte"
        with open(output_path, "wb") as f:
            f.write(pte.buffer)
        file_size = os.path.getsize(output_path) / (1024 * 1024)
        print(f" Saved to {output_path} ({file_size:.1f} MB)")
        return output_path
    except ImportError as e:
        # ExecuTorch not installed — treated as a soft failure.
        print(f" ExecuTorch import failed: {e}")
        return None
    except Exception as e:
        print(f" Export failed: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return None
def main():
    """Run the full pipeline: load, rebuild, verify, export, save."""
    print("=" * 60)
    print("Text Decoder Export for ExecuTorch")
    print(f"Architecture: Qwen3 {NUM_LAYERS}L, {NUM_HEADS}H/{NUM_KV_HEADS}KV, dim={HIDDEN_SIZE}")
    print(f"Max seq len: {MAX_SEQ_LEN}")
    print(f"KV cache size per layer: {NUM_KV_HEADS}x{MAX_SEQ_LEN}x{HEAD_DIM} = {NUM_KV_HEADS*MAX_SEQ_LEN*HEAD_DIM/1e6:.2f}M elements")
    print("=" * 60)
    # Load the reference model, rebuild the export-friendly decoder, and
    # sanity-check the two against each other.
    original_model = load_original_model()
    decoder = build_decoder_module(original_model)
    test_decoder_module(decoder, original_model)
    # Drop the reference model before the memory-hungry export step.
    del original_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Attempt torch.export and, if it succeeds, lower to a .pte file.
    exported = try_torch_export(decoder)
    if exported is not None:
        export_to_pte(exported)
    # Always keep a plain state dict around for debugging / re-export.
    torch.save(decoder.state_dict(), "text_decoder_fixed.pt")
    print("\nSaved fixed decoder state dict to text_decoder_fixed.pt")
    print("Decoder export script complete!")
if __name__ == "__main__":
    # Entry point: run the full build/test/export pipeline as a script.
    main()