surya-ocr-2-coreml-runtime / scripts /export_surya_coreml_runtime.py
Reza2kn's picture
Upload Surya OCR 2 CoreML runtime canary
92c0d8d verified
Raw
History Blame Contribute Delete
60.6 kB
#!/usr/bin/env python3
import argparse
import json
from pathlib import Path
import coremltools as ct
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from qwen_vl_utils import process_vision_info
from transformers import AutoConfig, AutoProcessor, Qwen3_5ForConditionalGeneration
# Instruction text placed next to the synthetic image in the chat message built by build_sample().
PROMPT = "OCR this image to HTML."
def load_model(model_id: str, dtype: torch.dtype):
    """Load the conditional-generation model on CPU with eager attention everywhere.

    Args:
        model_id: Identifier or path handed to ``from_pretrained``.
        dtype: Torch dtype requested for the loaded weights.

    Returns:
        The model in eval mode, with ``_attn_implementation`` forced to
        ``"eager"`` on the top-level config and on its text sub-config.
    """
    eager = "eager"

    cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    cfg._attn_implementation = eager
    # Sub-configs carry their own attention setting; override each one that exists.
    for sub_name in ("text_config", "vision_config"):
        if hasattr(cfg, sub_name):
            getattr(cfg, sub_name)._attn_implementation = eager

    load_kwargs = dict(
        config=cfg,
        torch_dtype=dtype,
        device_map="cpu",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation=eager,
    )
    loaded = Qwen3_5ForConditionalGeneration.from_pretrained(model_id, **load_kwargs).eval()

    # Re-apply after loading so the instantiated model's config agrees as well.
    loaded.config._attn_implementation = eager
    if hasattr(loaded.config, "text_config"):
        loaded.config.text_config._attn_implementation = eager
    return loaded
def build_sample(processor):
    """Build processor inputs for a fixed synthetic 512x512 image plus the OCR prompt.

    Args:
        processor: Object providing ``apply_chat_template`` and a ``__call__``
            accepting ``text``/``images``/``videos``/``return_tensors``.

    Returns:
        Whatever ``processor(...)`` returns for the single-message conversation.
    """
    canvas = Image.new("RGB", (512, 512), "white")
    ImageDraw.Draw(canvas).text((40, 80), "Invoice 123\nTotal $42.00", fill="black")

    user_content = [{"type": "image", "image": canvas}, {"type": "text", "text": PROMPT}]
    conversation = [{"role": "user", "content": user_content}]

    chat_text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    images, videos = process_vision_info(conversation)
    return processor(text=[chat_text], images=images, videos=videos, return_tensors="pt")
class SuryaCoreMLDecodeStep(torch.nn.Module):
    """One-token decode step with every cache tensor passed in and returned explicitly.

    The wrapped ``language_model`` must expose ``layers`` (each with
    ``layer_type``, ``input_layernorm``, ``post_attention_layernorm``, ``mlp``
    and either ``linear_attn`` or ``self_attn``) and a final ``norm``.
    All reshapes in this class hard-code batch size 1 and a single new token.
    """

    def __init__(self, language_model, lm_head):
        super().__init__()
        self.language_model = language_model
        self.lm_head = lm_head

    @staticmethod
    def rotate_half_static(x, rotary_dim):
        """Swap the two halves of the first ``rotary_dim`` channels, negating the second."""
        half = rotary_dim // 2
        front = x.narrow(-1, 0, half)
        back = x.narrow(-1, half, half)
        return torch.cat((-back, front), dim=-1)

    def apply_rotary_static(self, q, k, cos, sin):
        """Apply rotary embedding to the leading ``cos.shape[-1]`` channels of ``q`` and ``k``.

        Channels beyond the rotary width are passed through unchanged.
        """
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        rotary_dim = int(cos.shape[-1])

        def rotate(states):
            head = states.narrow(-1, 0, rotary_dim)
            tail = states.narrow(-1, rotary_dim, int(states.shape[-1]) - rotary_dim)
            mixed = (head * cos) + (self.rotate_half_static(head, rotary_dim) * sin)
            return torch.cat((mixed, tail), dim=-1)

        return rotate(q), rotate(k)

    @staticmethod
    def repeat_kv_static(hidden_states, n_rep):
        """Repeat each key/value head ``n_rep`` times along the head axis (dim 1)."""
        if n_rep == 1:
            return hidden_states
        shape = hidden_states.shape
        # int(...) keeps the sizes as plain Python numbers.
        batch, kv_heads, seq_len, head_dim = int(shape[0]), int(shape[1]), int(shape[2]), int(shape[3])
        expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
        return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)

    def _decode_qkv(self, attn, hidden_states, cos, sin):
        """Project one token to rotary-encoded query/key, value, and the output gate."""
        head_dim = attn.head_dim
        # q_proj emits query and gate interleaved per head; split them on the last axis.
        query_states, gate = torch.chunk(attn.q_proj(hidden_states).view(1, 1, -1, head_dim * 2), 2, dim=-1)
        gate = gate.reshape(1, 1, -1)
        query_states = attn.q_norm(query_states.view(1, 1, -1, head_dim)).transpose(1, 2)
        key_states = attn.k_norm(attn.k_proj(hidden_states).view(1, 1, -1, head_dim)).transpose(1, 2)
        value_states = attn.v_proj(hidden_states).view(1, 1, -1, head_dim).transpose(1, 2)
        query_states, key_states = self.apply_rotary_static(query_states, key_states, cos, sin)
        return query_states, key_states, value_states, gate

    def _gated_attention(self, attn, query_states, all_key, all_value, gate, attention_mask=None):
        """Softmax attention over ``all_key``/``all_value`` followed by sigmoid gating and ``o_proj``."""
        keys = self.repeat_kv_static(all_key, attn.num_key_value_groups)
        values = self.repeat_kv_static(all_value, attn.num_key_value_groups)
        scores = torch.matmul(query_states, keys.transpose(2, 3)) * attn.scaling
        if attention_mask is not None:
            scores = scores + attention_mask.to(scores.dtype)
        # Softmax is evaluated in float32 and cast back to the query dtype.
        probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
        context = torch.matmul(probs, values)
        context = context.transpose(1, 2).contiguous().reshape(1, 1, -1)
        context = context * torch.sigmoid(gate)
        return attn.o_proj(context)

    def full_attention_decode(self, attn, hidden_states, cos, sin, past_key, past_value):
        """Decode one token against a growing cache.

        Returns:
            ``(output, all_key, all_value)`` where the caches are the inputs
            with the new token's key/value concatenated on dim 2.
        """
        query_states, key_states, value_states, gate = self._decode_qkv(attn, hidden_states, cos, sin)
        all_key = torch.cat((past_key, key_states), dim=2)
        all_value = torch.cat((past_value, value_states), dim=2)
        output = self._gated_attention(attn, query_states, all_key, all_value, gate)
        return output, all_key, all_value

    def full_attention_decode_fixed(self, attn, hidden_states, cos, sin, past_key, past_value, attention_mask):
        """Decode one token against a fixed-size cache, masking with ``attention_mask``.

        Returns:
            ``(output, key_states, value_states)`` - only the new token's
            key/value, not the concatenated cache.
        """
        query_states, key_states, value_states, gate = self._decode_qkv(attn, hidden_states, cos, sin)
        all_key = torch.cat((past_key, key_states), dim=2)
        all_value = torch.cat((past_value, value_states), dim=2)
        output = self._gated_attention(attn, query_states, all_key, all_value, gate, attention_mask)
        return output, key_states, value_states

    @staticmethod
    def l2norm_static(x):
        """Scale ``x`` by the reciprocal square root of its squared sum (plus 1e-6) over the last axis."""
        squared_sum = (x * x).sum(dim=-1, keepdim=True)
        return x * torch.rsqrt(squared_sum + 1e-6)

    def recurrent_gated_delta_decode(self, query, key, value, g, beta, recurrent_state):
        """Advance the gated delta-rule recurrence by one token.

        Inputs are transposed on dims 1 and 2 and only index 0 of the resulting
        dim 2 is consumed. Computation runs in float32.

        Returns:
            ``(core_attn_out, recurrent_state)``; the output is cast back to
            the dtype of ``query``, the state stays float32.
        """
        out_dtype = query.dtype

        def heads_first(tensor):
            return tensor.transpose(1, 2).contiguous().float()

        query = self.l2norm_static(query)
        key = self.l2norm_static(key)
        scale = 1 / (query.shape[-1] ** 0.5)
        query = heads_first(query) * scale
        key = heads_first(key)
        value = heads_first(value)
        beta = heads_first(beta)
        g = heads_first(g)

        q_t, k_t, v_t = query[:, :, 0], key[:, :, 0], value[:, :, 0]
        decay = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, 0].unsqueeze(-1)

        # Decay the old state, then write the beta-weighted prediction error for this key.
        state = recurrent_state.float() * decay
        predicted = (state * k_t.unsqueeze(-1)).sum(dim=-2)
        delta = (v_t - predicted) * beta_t
        state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)

        readout = (state * q_t.unsqueeze(-1)).sum(dim=-2)
        readout = readout.unsqueeze(2).transpose(1, 2).contiguous().to(out_dtype)
        return readout, state

    def gated_delta_decode(self, linear_attn, hidden_states, conv_state, recurrent_state):
        """Run one linear-attention layer for one token.

        Returns:
            ``(output, new_conv_state, new_recurrent_state)``.
        """
        raw_qkv = linear_attn.in_proj_qkv(hidden_states).transpose(1, 2)
        window = torch.cat((conv_state, raw_qkv), dim=-1)

        conv_weight = linear_attn.conv1d.weight
        conv_out = F.conv1d(
            window.to(conv_weight.dtype),
            conv_weight,
            linear_attn.conv1d.bias,
            padding=0,
            groups=int(window.shape[1]),
        )
        # Only the last convolution output corresponds to the new token.
        activated = F.silu(conv_out[:, :, -1:]).to(raw_qkv.dtype)
        new_conv_state = window[:, :, -linear_attn.conv_kernel_size :]
        activated = activated.transpose(1, 2)

        head_k_dim = linear_attn.head_k_dim
        head_v_dim = linear_attn.head_v_dim
        z = linear_attn.in_proj_z(hidden_states).reshape(1, 1, -1, head_v_dim)
        b = linear_attn.in_proj_b(hidden_states)
        a = linear_attn.in_proj_a(hidden_states)

        split_sizes = [linear_attn.key_dim, linear_attn.key_dim, linear_attn.value_dim]
        query, key, value = torch.split(activated, split_sizes, dim=-1)
        query = query.reshape(1, 1, -1, head_k_dim)
        key = key.reshape(1, 1, -1, head_k_dim)
        value = value.reshape(1, 1, -1, head_v_dim)

        beta = b.sigmoid()
        g = -linear_attn.A_log.float().exp() * F.softplus(a.float() + linear_attn.dt_bias)

        core, new_recurrent_state = self.recurrent_gated_delta_decode(query, key, value, g, beta, recurrent_state)
        gated = linear_attn.norm(core.reshape(-1, head_v_dim), z.reshape(-1, head_v_dim))
        gated = gated.reshape(1, 1, -1)
        return linear_attn.out_proj(gated), new_conv_state, new_recurrent_state

    @staticmethod
    def _residual_mlp(layer, residual, mixer_out):
        """Add the token-mixer output to ``residual``, then apply the post-norm MLP with its own residual."""
        hidden = residual + mixer_out
        return hidden + layer.mlp(layer.post_attention_layernorm(hidden))

    def decode_lists(self, inputs_embeds, cos, sin, full_keys, full_values, conv_states, recurrent_states):
        """Run every layer for one token using growing full-attention caches.

        The cache lists are consumed in layer order, separately for
        linear-attention layers and all other layers.

        Returns:
            ``(logits, next_full_keys, next_full_values, next_conv_states,
            next_recurrent_states)``.
        """
        hidden = inputs_embeds
        next_full_keys, next_full_values = [], []
        next_conv_states, next_recurrent_states = [], []

        for layer in self.language_model.layers:
            normed = layer.input_layernorm(hidden)
            if layer.layer_type == "linear_attention":
                # The number of states produced so far is the index of this layer's input state.
                slot = len(next_conv_states)
                mixed, conv, recurrent = self.gated_delta_decode(
                    layer.linear_attn, normed, conv_states[slot], recurrent_states[slot]
                )
                next_conv_states.append(conv)
                next_recurrent_states.append(recurrent)
            else:
                slot = len(next_full_keys)
                mixed, grown_key, grown_value = self.full_attention_decode(
                    layer.self_attn, normed, cos, sin, full_keys[slot], full_values[slot]
                )
                next_full_keys.append(grown_key)
                next_full_values.append(grown_value)
            hidden = self._residual_mlp(layer, hidden, mixed)

        hidden = self.language_model.norm(hidden)
        logits = self.lm_head(hidden[:, -1:, :])
        return logits, next_full_keys, next_full_values, next_conv_states, next_recurrent_states
class SuryaCoreMLDecodeStepFlat(SuryaCoreMLDecodeStep):
    """Decode step with a flat positional signature and fixed-size full-attention caches.

    ``forward`` takes the caches as trailing positional tensors ordered as
    full keys, full values, conv states, recurrent states, and returns the
    logits followed by the new states in that same order.
    """

    # Counts used to slice the flat state tuple.
    full_layers = 6
    linear_layers = 18

    def forward(self, inputs_embeds, cos, sin, attention_mask, *states):
        """Unflatten ``states``, run one decode step, and flatten the results into a tuple."""
        n_full, n_linear = self.full_layers, self.linear_layers
        bounds = (0, n_full, 2 * n_full, 2 * n_full + n_linear, 2 * n_full + 2 * n_linear)
        full_keys, full_values, conv_states, recurrent_states = (
            list(states[start:stop]) for start, stop in zip(bounds, bounds[1:])
        )
        logits, new_keys, new_values, new_convs, new_recurrents = self.decode_lists_fixed(
            inputs_embeds,
            cos,
            sin,
            full_keys,
            full_values,
            attention_mask,
            conv_states,
            recurrent_states,
        )
        return (logits, *new_keys, *new_values, *new_convs, *new_recurrents)

    def decode_lists_fixed(self, inputs_embeds, cos, sin, full_keys, full_values, attention_mask, conv_states, recurrent_states):
        """Run every layer for one token against fixed-size, masked full-attention caches.

        Unlike ``decode_lists``, the returned key/value entries are whatever
        ``full_attention_decode_fixed`` returns for the new token only.

        Returns:
            ``(logits, next_full_keys, next_full_values, next_conv_states,
            next_recurrent_states)``.
        """
        hidden = inputs_embeds
        next_full_keys, next_full_values = [], []
        next_conv_states, next_recurrent_states = [], []

        for layer in self.language_model.layers:
            skip = hidden
            normed = layer.input_layernorm(hidden)
            if layer.layer_type == "linear_attention":
                # The number of outputs collected so far is this layer's index into the input lists.
                slot = len(next_conv_states)
                mixed, conv, recurrent = self.gated_delta_decode(
                    layer.linear_attn,
                    normed,
                    conv_states[slot],
                    recurrent_states[slot],
                )
                next_conv_states.append(conv)
                next_recurrent_states.append(recurrent)
            else:
                slot = len(next_full_keys)
                mixed, token_key, token_value = self.full_attention_decode_fixed(
                    layer.self_attn,
                    normed,
                    cos,
                    sin,
                    full_keys[slot],
                    full_values[slot],
                    attention_mask,
                )
                next_full_keys.append(token_key)
                next_full_values.append(token_value)

            hidden = skip + mixed
            skip = hidden
            hidden = layer.mlp(layer.post_attention_layernorm(hidden))
            hidden = skip + hidden

        hidden = self.language_model.norm(hidden)
        logits = self.lm_head(hidden[:, -1:, :])
        return logits, next_full_keys, next_full_values, next_conv_states, next_recurrent_states
class SuryaCoreMLPrefillFlat(SuryaCoreMLDecodeStep):
    """Fixed-length prefill pass that emits logits plus every cache tensor as a flat tuple.

    Output order is: logits, full-attention keys, full-attention values,
    conv states, recurrent states. Full-attention keys/values are zero-padded
    along dim 2 to ``max_cache_length``. All reshapes hard-code batch size 1.
    """

    # Counts used by the spec/slicing helpers that consume the flat output tuple.
    full_layers = 6
    linear_layers = 18

    def __init__(self, language_model, lm_head, seq_len, max_cache_length):
        """Store the static sequence/cache lengths and build the causal mask buffer.

        Raises:
            ValueError: If ``seq_len`` is larger than ``max_cache_length``; the
                padded caches could not hold the prompt in that case.
        """
        super().__init__(language_model, lm_head)
        if seq_len > max_cache_length:
            raise ValueError(f"seq_len={seq_len} exceeds max_cache_length={max_cache_length}")
        self.seq_len = seq_len
        self.max_cache_length = max_cache_length
        # Strictly-upper-triangular mask of float32 min; the diagonal and below are zero.
        mask = torch.full((1, 1, seq_len, seq_len), torch.finfo(torch.float32).min)
        self.register_buffer("causal_mask", torch.triu(mask, diagonal=1))

    def full_attention_prefill(self, attn, hidden_states, cos, sin):
        """Causal gated attention over the whole prompt.

        Returns:
            ``(output, key_pad, value_pad)`` where the key/value tensors are
            zero-padded along dim 2 to ``max_cache_length``.
        """
        query_states, gate = torch.chunk(
            attn.q_proj(hidden_states).view(1, self.seq_len, -1, attn.head_dim * 2),
            2,
            dim=-1,
        )
        gate = gate.reshape(1, self.seq_len, -1)
        query_states = attn.q_norm(query_states.view(1, self.seq_len, -1, attn.head_dim)).transpose(1, 2)
        key_states = attn.k_norm(attn.k_proj(hidden_states).view(1, self.seq_len, -1, attn.head_dim)).transpose(1, 2)
        value_states = attn.v_proj(hidden_states).view(1, self.seq_len, -1, attn.head_dim).transpose(1, 2)
        query_states, key_states = self.apply_rotary_static(query_states, key_states, cos, sin)
        key_for_attn = self.repeat_kv_static(key_states, attn.num_key_value_groups)
        value_for_attn = self.repeat_kv_static(value_states, attn.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_for_attn.transpose(2, 3)) * attn.scaling
        attn_weights = attn_weights + self.causal_mask.to(attn_weights.dtype)
        # Softmax is evaluated in float32 and cast back to the query dtype.
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_for_attn)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(1, self.seq_len, -1)
        attn_output = attn_output * torch.sigmoid(gate)
        key_pad = torch.zeros(
            1,
            int(key_states.shape[1]),
            self.max_cache_length,
            int(key_states.shape[3]),
            dtype=key_states.dtype,
            device=key_states.device,
        )
        value_pad = torch.zeros(
            1,
            int(value_states.shape[1]),
            self.max_cache_length,
            int(value_states.shape[3]),
            dtype=value_states.dtype,
            device=value_states.device,
        )
        key_pad[:, :, : self.seq_len, :] = key_states
        value_pad[:, :, : self.seq_len, :] = value_states
        return attn.o_proj(attn_output), key_pad, value_pad

    def chunk_gated_delta_rule_prefill(self, query, key, value, g, beta):
        """Chunked gated delta rule over the whole prompt, starting from a zero state.

        Head count and head widths are read from the input shapes rather than
        hard-coded. ``query``/``key`` must share a shape, and ``value`` must
        have the same head count as ``query`` (a mismatch fails at reshape).

        Returns:
            ``(core_attn_out, last_recurrent_state)``; the output is cut back
            to ``seq_len`` tokens and cast to the dtype of ``query``, while the
            state stays float32.
        """
        chunk_size = 64
        sequence_length = self.seq_len
        # Pad the sequence up to a whole number of chunks.
        pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
        total_sequence_length = sequence_length + pad_size
        num_chunks = total_sequence_length // chunk_size
        initial_dtype = query.dtype
        query = self.l2norm_static(query)
        key = self.l2norm_static(key)
        query = query.transpose(1, 2).contiguous().float()
        key = key.transpose(1, 2).contiguous().float()
        value = value.transpose(1, 2).contiguous().float()
        beta = beta.transpose(1, 2).contiguous().float()
        g = g.transpose(1, 2).contiguous().float()

        # After the transposes dim 1 is the head axis and dim -1 is the head width.
        num_heads = int(query.shape[1])
        head_k_dim = int(query.shape[-1])
        head_v_dim = int(value.shape[-1])

        query = F.pad(query, (0, 0, 0, pad_size))
        key = F.pad(key, (0, 0, 0, pad_size))
        value = F.pad(value, (0, 0, 0, pad_size))
        beta = F.pad(beta, (0, pad_size))
        g = F.pad(g, (0, pad_size))
        query = query * (1 / (head_k_dim**0.5))
        v_beta = value * beta.unsqueeze(-1)
        k_beta = key * beta.unsqueeze(-1)
        query = query.reshape(1, num_heads, num_chunks, chunk_size, head_k_dim)
        key = key.reshape(1, num_heads, num_chunks, chunk_size, head_k_dim)
        value = value.reshape(1, num_heads, num_chunks, chunk_size, head_v_dim)
        k_beta = k_beta.reshape(1, num_heads, num_chunks, chunk_size, head_k_dim)
        v_beta = v_beta.reshape(1, num_heads, num_chunks, chunk_size, head_v_dim)
        g = g.reshape(1, num_heads, num_chunks, chunk_size)
        tri0 = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
        tri1 = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
        g_cum = g.cumsum(dim=-1)
        decay_mask = ((g_cum.unsqueeze(-1) - g_cum.unsqueeze(-2)).tril().exp().float()).tril()
        attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(tri0, 0)
        # Row-by-row update of the strictly-lower-triangular matrix, rebuilt with
        # cat each iteration instead of being modified in place.
        for i in range(1, chunk_size):
            row = attn[..., i, :i].clone()
            sub = attn[..., :i, :i].clone()
            update = row + (row.unsqueeze(-1) * sub).sum(-2)
            new_row = torch.cat((update, attn[..., i, i:]), dim=-1)
            attn = torch.cat((attn[..., :i, :], new_row.unsqueeze(-2), attn[..., i + 1 :, :]), dim=-2)
        attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=query.device)
        value = attn @ v_beta
        k_cumdecay = attn @ (k_beta * g_cum.exp().unsqueeze(-1))
        last_recurrent_state = torch.zeros(
            1, num_heads, head_k_dim, head_v_dim, dtype=value.dtype, device=value.device
        )
        outs = []
        for i in range(num_chunks):
            q_i = query[:, :, i]
            k_i = key[:, :, i]
            v_i = value[:, :, i]
            attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill(tri1, 0)
            v_prime = k_cumdecay[:, :, i] @ last_recurrent_state
            v_new = v_i - v_prime
            attn_inter = (q_i * g_cum[:, :, i, :, None].exp()) @ last_recurrent_state
            outs.append(attn_inter + attn_i @ v_new)
            last_recurrent_state = (
                last_recurrent_state * g_cum[:, :, i, -1, None, None].exp()
                + (k_i * (g_cum[:, :, i, -1, None] - g_cum[:, :, i]).exp()[..., None]).transpose(-1, -2)
                @ v_new
            )
        core_attn_out = torch.stack(outs, dim=2).reshape(1, num_heads, total_sequence_length, head_v_dim)
        core_attn_out = core_attn_out[:, :, :sequence_length].transpose(1, 2).contiguous().to(initial_dtype)
        return core_attn_out, last_recurrent_state

    def gated_delta_prefill(self, linear_attn, hidden_states):
        """Run one linear-attention layer over the whole prompt.

        Returns:
            ``(output, conv_state, recurrent_state)``. ``conv_state`` holds the
            last ``conv_kernel_size`` pre-convolution columns, left-padded with
            zeros when the prompt is shorter than the kernel.
        """
        mixed_qkv_raw = linear_attn.in_proj_qkv(hidden_states).transpose(1, 2)
        z = linear_attn.in_proj_z(hidden_states).reshape(1, self.seq_len, -1, linear_attn.head_v_dim)
        b = linear_attn.in_proj_b(hidden_states)
        a = linear_attn.in_proj_a(hidden_states)
        # Keep only the first seq_len convolution outputs.
        mixed_qkv = F.silu(linear_attn.conv1d(mixed_qkv_raw)[:, :, : self.seq_len]).transpose(1, 2)
        query, key, value = torch.split(mixed_qkv, [linear_attn.key_dim, linear_attn.key_dim, linear_attn.value_dim], dim=-1)
        query = query.reshape(1, self.seq_len, -1, linear_attn.head_k_dim)
        key = key.reshape(1, self.seq_len, -1, linear_attn.head_k_dim)
        value = value.reshape(1, self.seq_len, -1, linear_attn.head_v_dim)
        beta = b.sigmoid()
        g = -linear_attn.A_log.float().exp() * F.softplus(a.float() + linear_attn.dt_bias)
        core_attn_out, recurrent_state = self.chunk_gated_delta_rule_prefill(query, key, value, g, beta)
        core_attn_out = core_attn_out.reshape(-1, linear_attn.head_v_dim)
        z = z.reshape(-1, linear_attn.head_v_dim)
        normed = linear_attn.norm(core_attn_out, z).reshape(1, self.seq_len, -1)
        kernel_size = linear_attn.conv_kernel_size
        conv_state = mixed_qkv_raw[:, :, -kernel_size:]
        if self.seq_len < kernel_size:
            # The decode step concatenates one column onto this state before its
            # convolution, so the state must always be kernel_size wide.
            conv_state = F.pad(conv_state, (kernel_size - self.seq_len, 0))
        return linear_attn.out_proj(normed), conv_state, recurrent_state

    def forward(self, inputs_embeds, cos, sin):
        """Run every layer over the prompt and return ``(logits, *caches)`` as a flat tuple."""
        hidden_states = inputs_embeds
        full_keys = []
        full_values = []
        conv_states = []
        recurrent_states = []
        for layer in self.language_model.layers:
            residual = hidden_states
            hidden_states = layer.input_layernorm(hidden_states)
            if layer.layer_type == "linear_attention":
                hidden_states, conv_state, recurrent_state = self.gated_delta_prefill(layer.linear_attn, hidden_states)
                conv_states.append(conv_state)
                recurrent_states.append(recurrent_state)
            else:
                hidden_states, key_state, value_state = self.full_attention_prefill(layer.self_attn, hidden_states, cos, sin)
                full_keys.append(key_state)
                full_values.append(value_state)
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = layer.post_attention_layernorm(hidden_states)
            hidden_states = layer.mlp(hidden_states)
            hidden_states = residual + hidden_states
        hidden_states = self.language_model.norm(hidden_states)
        # Only the last position's logits are produced.
        logits = self.lm_head(hidden_states[:, -1:, :])
        return tuple([logits] + full_keys + full_values + conv_states + recurrent_states)
def extract_cache_lists(cache):
    """Split a layered cache into four lists of detached tensor copies.

    A layer exposing a ``keys`` attribute contributes its ``keys``/``values``;
    any other layer contributes its ``conv_states``/``recurrent_states``.

    Returns:
        ``(full_keys, full_values, conv_states, recurrent_states)`` in layer order.
    """

    def snapshot(tensor):
        return tensor.detach().clone()

    full_keys, full_values = [], []
    conv_states, recurrent_states = [], []
    for layer in cache.layers:
        if not hasattr(layer, "keys"):
            conv_states.append(snapshot(layer.conv_states))
            recurrent_states.append(snapshot(layer.recurrent_states))
            continue
        full_keys.append(snapshot(layer.keys))
        full_values.append(snapshot(layer.values))
    return full_keys, full_values, conv_states, recurrent_states
def pad_full_caches(full_keys, full_values, max_cache_length):
padded_keys = []
padded_values = []
for key, value in zip(full_keys, full_values, strict=True):
if key.shape[2] > max_cache_length:
raise ValueError(f"full attention cache length {key.shape[2]} exceeds max_cache_length={max_cache_length}")
key_pad = torch.zeros(key.shape[0], key.shape[1], max_cache_length, key.shape[3], dtype=key.dtype)
value_pad = torch.zeros(value.shape[0], value.shape[1], max_cache_length, value.shape[3], dtype=value.dtype)
key_pad[:, :, : key.shape[2], :] = key
value_pad[:, :, : value.shape[2], :] = value
padded_keys.append(key_pad)
padded_values.append(value_pad)
return padded_keys, padded_values
def decode_attention_mask(cache_length, max_cache_length, dtype=torch.float32):
    """Build the additive mask for one decode step over a fixed-size cache.

    The mask has ``max_cache_length + 1`` slots on its last axis: the first
    ``cache_length`` slots and the final slot are 0, every other slot holds
    the most negative finite value of ``dtype``.

    Args:
        cache_length: Number of leading cache slots to leave unmasked.
        max_cache_length: Size of the padded cache.
        dtype: Floating dtype of the returned mask.

    Returns:
        Tensor of shape ``(1, 1, 1, max_cache_length + 1)``.

    Raises:
        ValueError: If ``cache_length`` is negative or exceeds ``max_cache_length``;
            slicing would otherwise silently unmask the wrong slots.
    """
    if not 0 <= cache_length <= max_cache_length:
        raise ValueError(f"cache_length={cache_length} must be within [0, max_cache_length={max_cache_length}]")
    mask = torch.full((1, 1, 1, max_cache_length + 1), torch.finfo(dtype).min, dtype=dtype)
    mask[:, :, :, :cache_length] = 0
    # The final slot is always left unmasked.
    mask[:, :, :, -1:] = 0
    return mask
def parity_check(args) -> None:
    """Compare one native decode step against the custom growing- and fixed-cache steps.

    Runs a prefill on the synthetic sample, decodes the argmax token with the
    native model and with both custom decode paths, then writes a JSON report
    of argmax tokens and logit differences to ``args.output`` and prints it.

    Args:
        args: Namespace providing ``dtype``, ``model_id``, ``max_cache_length``
            and ``output`` (a path-like with ``parent`` and ``write_text``).
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample = build_sample(processor)
    prompt_tokens = int(sample["input_ids"].shape[1])

    with torch.no_grad():
        prefill = model(**sample, use_cache=True, return_dict=True)
        first_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
        # Copy the caches before the native decode step below consumes past_key_values.
        full_keys, full_values, conv_states, recurrent_states = extract_cache_lists(prefill.past_key_values)
        native = model(input_ids=first_token, past_key_values=prefill.past_key_values, use_cache=True, return_dict=True)

    # decode_lists_fixed only exists on the Flat subclass, which also inherits
    # decode_lists, so one instance serves both comparisons.
    step = SuryaCoreMLDecodeStepFlat(model.model.language_model, model.lm_head).eval()
    padded_full_keys, padded_full_values = pad_full_caches(full_keys, full_values, args.max_cache_length)
    attention_mask = decode_attention_mask(prompt_tokens, args.max_cache_length)

    with torch.no_grad():
        inputs_embeds = model.model.language_model.embed_tokens(first_token)
        # The new token's position index equals the prompt length, on all three position axes.
        position_ids = torch.full((3, 1, 1), prompt_tokens, dtype=torch.long)
        cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
        logits, next_full_keys, next_full_values, next_conv_states, next_recurrent_states = step.decode_lists(
            inputs_embeds,
            cos,
            sin,
            full_keys,
            full_values,
            conv_states,
            recurrent_states,
        )
        fixed_logits, fixed_next_keys, fixed_next_values, fixed_next_conv, fixed_next_recurrent = step.decode_lists_fixed(
            inputs_embeds,
            cos,
            sin,
            padded_full_keys,
            padded_full_values,
            attention_mask,
            conv_states,
            recurrent_states,
        )

    native_last = native.logits[:, -1:, :]
    report = {
        "prompt_tokens": prompt_tokens,
        "first_token": int(first_token.item()),
        "native_next_argmax": int(native.logits[:, -1, :].argmax(dim=-1).item()),
        "custom_next_argmax": int(logits[:, -1, :].argmax(dim=-1).item()),
        "logits_max_abs_diff": float((native_last - logits).abs().max().item()),
        "logits_mean_abs_diff": float((native_last - logits).abs().mean().item()),
        "fixed_cache_length": args.max_cache_length,
        "fixed_native_max_abs_diff": float((native_last - fixed_logits).abs().max().item()),
        "fixed_native_mean_abs_diff": float((native_last - fixed_logits).abs().mean().item()),
        "fixed_growing_max_abs_diff": float((logits - fixed_logits).abs().max().item()),
        "fixed_custom_next_argmax": int(fixed_logits[:, -1, :].argmax(dim=-1).item()),
        "full_layers": len(next_full_keys),
        "linear_layers": len(next_conv_states),
        "next_full_key_shape": list(next_full_keys[0].shape),
        "next_conv_shape": list(next_conv_states[0].shape),
        "next_recurrent_shape": list(next_recurrent_states[0].shape),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)
def build_decode_examples(model, processor, max_cache_length):
    """Produce example inputs for the flat decode step from a real prefill.

    Returns:
        ``(sample, inputs)`` where ``inputs`` is the tuple
        ``(inputs_embeds, cos, sin, attention_mask, *padded_keys,
        *padded_values, *conv_states, *recurrent_states)``.
    """
    sample = build_sample(processor)
    prompt_tokens = int(sample["input_ids"].shape[1])

    with torch.no_grad():
        prefill = model(**sample, use_cache=True, return_dict=True)
    first_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()

    keys, values, convs, recurrents = extract_cache_lists(prefill.past_key_values)
    keys, values = pad_full_caches(keys, values, max_cache_length)

    language_model = model.model.language_model
    inputs_embeds = language_model.embed_tokens(first_token)
    # The new token's position index equals the prompt length, on all three position axes.
    position_ids = torch.full((3, 1, 1), prompt_tokens, dtype=torch.long)
    cos, sin = language_model.rotary_emb(inputs_embeds, position_ids)
    attention_mask = decode_attention_mask(prompt_tokens, max_cache_length)

    return sample, (inputs_embeds, cos, sin, attention_mask, *keys, *values, *convs, *recurrents)
def decode_input_specs(example):
    """Describe the decode model's inputs as float32 ``ct.TensorType`` specs.

    Names are paired positionally with the tensors in ``example``; shapes are
    taken from those tensors. ``zip(strict=True)`` raises if the counts differ.
    """
    n_full = SuryaCoreMLDecodeStepFlat.full_layers
    n_linear = SuryaCoreMLDecodeStepFlat.linear_layers
    state_groups = (
        ("full_key", n_full),
        ("full_value", n_full),
        ("conv_state", n_linear),
        ("recurrent_state", n_linear),
    )
    names = ["inputs_embeds", "cos", "sin", "attention_mask"]
    names.extend(f"{prefix}_{index}" for prefix, count in state_groups for index in range(count))

    specs = []
    for name, tensor in zip(names, example, strict=True):
        specs.append(ct.TensorType(name=name, shape=tuple(tensor.shape), dtype=np.float32))
    return specs
def decode_output_specs():
    """Name the decode model's outputs: ``logits`` followed by the ``new_*`` state tensors."""
    n_full = SuryaCoreMLDecodeStepFlat.full_layers
    n_linear = SuryaCoreMLDecodeStepFlat.linear_layers
    state_groups = (
        ("new_full_key", n_full),
        ("new_full_value", n_full),
        ("new_conv_state", n_linear),
        ("new_recurrent_state", n_linear),
    )
    names = ["logits"]
    names.extend(f"{prefix}_{index}" for prefix, count in state_groups for index in range(count))
    return [ct.TensorType(name=output_name) for output_name in names]
def build_prefill_example(model, processor):
    """Build prefill example inputs using image features from the model's own vision path.

    Returns:
        The ``(sample, (inputs_embeds, cos, sin))`` pair produced by
        ``build_prefill_example_from_image_embeds``.
    """
    sample = build_sample(processor)
    model_dtype = next(model.parameters()).dtype
    with torch.no_grad():
        features = model.model.get_image_features(
            sample["pixel_values"].to(model_dtype),
            sample["image_grid_thw"],
            return_dict=True,
        )
        # pooler_output is concatenated along dim 0 into a single tensor.
        image_embeds = torch.cat(features.pooler_output, dim=0)
    return build_prefill_example_from_image_embeds(model, sample, image_embeds)
def build_prefill_example_from_image_embeds(model, sample, image_embeds):
    """Assemble prefill inputs from processor output and precomputed image embeddings.

    Token embeddings are looked up for ``input_ids``, the image placeholder
    positions are overwritten with ``image_embeds``, and rotary ``cos``/``sin``
    are computed from the model's 3-D position ids.

    Returns:
        ``(sample, (inputs_embeds, cos, sin))``.
    """
    backbone = model.model
    token_ids = sample["input_ids"]

    with torch.no_grad():
        inputs_embeds = backbone.get_input_embeddings()(token_ids)
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        placeholder_mask, _ = backbone.get_placeholder_mask(
            token_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
        )
        inputs_embeds = inputs_embeds.masked_scatter(placeholder_mask, image_embeds)
        position_ids = backbone.compute_3d_position_ids(
            input_ids=token_ids,
            image_grid_thw=sample["image_grid_thw"],
            video_grid_thw=None,
            inputs_embeds=inputs_embeds,
            attention_mask=sample["attention_mask"],
            past_key_values=None,
            mm_token_type_ids=sample["mm_token_type_ids"],
        )
        cos, sin = backbone.language_model.rotary_emb(inputs_embeds, position_ids)

    return sample, (inputs_embeds, cos, sin)
def prefill_input_specs(example):
    """Describe the three prefill inputs as float32 ``ct.TensorType`` specs shaped like ``example``."""
    specs = []
    for input_name, tensor in zip(("inputs_embeds", "cos", "sin"), example, strict=True):
        specs.append(ct.TensorType(name=input_name, shape=tuple(tensor.shape), dtype=np.float32))
    return specs
def prefill_output_specs():
    """Name the prefill model's outputs: ``logits`` followed by every cache tensor."""
    n_full = SuryaCoreMLPrefillFlat.full_layers
    n_linear = SuryaCoreMLPrefillFlat.linear_layers
    state_groups = (
        ("full_key", n_full),
        ("full_value", n_full),
        ("conv_state", n_linear),
        ("recurrent_state", n_linear),
    )
    names = ["logits"]
    names.extend(f"{prefix}_{index}" for prefix, count in state_groups for index in range(count))
    return [ct.TensorType(name=output_name) for output_name in names]
def check_prefill_parity(args) -> None:
    """Compare the custom prefill wrapper against the native model on the synthetic sample.

    Writes a JSON report (argmax tokens, logit differences, and first-layer
    cache differences) to ``args.output`` and prints it.

    Args:
        args: Namespace providing ``dtype``, ``model_id``, ``max_cache_length``
            and ``output`` (a path-like with ``parent`` and ``write_text``).
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample, example = build_prefill_example(model, processor)
    wrapper = SuryaCoreMLPrefillFlat(
        model.model.language_model,
        model.lm_head,
        int(example[0].shape[1]),
        args.max_cache_length,
    ).eval()
    with torch.no_grad():
        native = model(**sample, use_cache=True, return_dict=True)
        custom_outputs = wrapper(*example)
    native_full_keys, native_full_values, native_conv, native_recurrent = extract_cache_lists(native.past_key_values)
    # Pad the native caches so they are comparable with the wrapper's padded output.
    native_full_keys, native_full_values = pad_full_caches(native_full_keys, native_full_values, args.max_cache_length)

    # Slice the flat output tuple with the wrapper's own layer counts so the
    # bounds cannot drift from the class definition.
    n_full = SuryaCoreMLPrefillFlat.full_layers
    n_linear = SuryaCoreMLPrefillFlat.linear_layers
    keys_end = 1 + n_full
    values_end = keys_end + n_full
    conv_end = values_end + n_linear
    recurrent_end = conv_end + n_linear
    custom_logits = custom_outputs[0]
    custom_full_keys = list(custom_outputs[1:keys_end])
    custom_full_values = list(custom_outputs[keys_end:values_end])
    custom_conv = list(custom_outputs[values_end:conv_end])
    custom_recurrent = list(custom_outputs[conv_end:recurrent_end])

    native_last = native.logits[:, -1:, :]
    report = {
        "prompt_tokens": int(sample["input_ids"].shape[1]),
        "max_cache_length": args.max_cache_length,
        "native_argmax": int(native.logits[:, -1, :].argmax(dim=-1).item()),
        "custom_argmax": int(custom_logits[:, -1, :].argmax(dim=-1).item()),
        "logits_max_abs_diff": float((native_last - custom_logits).abs().max().item()),
        "logits_mean_abs_diff": float((native_last - custom_logits).abs().mean().item()),
        "full_key0_max_abs_diff": float((native_full_keys[0] - custom_full_keys[0]).abs().max().item()),
        "full_value0_max_abs_diff": float((native_full_values[0] - custom_full_values[0]).abs().max().item()),
        "conv0_max_abs_diff": float((native_conv[0] - custom_conv[0]).abs().max().item()),
        "recurrent0_max_abs_diff": float((native_recurrent[0] - custom_recurrent[0]).abs().max().item()),
        "full_layers": len(custom_full_keys),
        "linear_layers": len(custom_conv),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)
def export_prefill(args) -> None:
    """Trace the prefill wrapper, convert it to an ML Program package, and write a manifest.

    Both the ``.mlpackage`` and a sibling ``.json`` manifest are written into
    ``args.output_dir``; the manifest is also printed.

    Args:
        args: Namespace providing ``output_dir``, ``dtype``, ``model_id`` and
            ``max_cache_length``.
    """
    output_dir = args.output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    _sample, example = build_prefill_example(model, processor)

    seq_len = int(example[0].shape[1])
    cache_len = args.max_cache_length
    wrapper = SuryaCoreMLPrefillFlat(model.model.language_model, model.lm_head, seq_len, cache_len).eval()
    with torch.no_grad():
        traced = torch.jit.trace(wrapper, example, strict=False, check_trace=False)

    # The package and its manifest share one file stem.
    stem = f"surya_prefill_fp16_seq{seq_len}_cache{cache_len}"
    package_path = output_dir / f"{stem}.mlpackage"
    mlmodel = ct.convert(
        traced,
        convert_to="mlprogram",
        minimum_deployment_target=ct.target.macOS14,
        compute_precision=ct.precision.FLOAT16,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        skip_model_load=True,
        inputs=prefill_input_specs(example),
        outputs=prefill_output_specs(),
    )
    mlmodel.save(str(package_path))

    input_names = [spec.name for spec in prefill_input_specs(example)]
    output_names = [spec.name for spec in prefill_output_specs()]
    manifest = {
        "model_id": args.model_id,
        "mode": "prefill",
        # NOTE(review): hard-coded, while the model above is loaded as float32
        # or float16 per args.dtype - confirm what this field is meant to record.
        "source_dtype": "bf16",
        "coreml_compute_precision": "fp16",
        "seq_len": seq_len,
        "max_cache_length": cache_len,
        "package": str(package_path),
        "inputs": input_names,
        "outputs": output_names,
    }
    manifest_text = json.dumps(manifest, indent=2)
    (output_dir / f"{stem}.json").write_text(manifest_text + "\n", encoding="utf-8")
    print(manifest_text, flush=True)
def smoke_prefill(args) -> None:
    """Run the exported prefill package once and compare its logits with the torch wrapper.

    Writes a JSON report to ``args.output`` and prints it.

    Args:
        args: Namespace providing ``dtype``, ``model_id``, ``max_cache_length``,
            ``package`` (path to the ``.mlpackage``) and ``output``.
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample, example = build_prefill_example(model, processor)

    seq_len = int(example[0].shape[1])
    wrapper = SuryaCoreMLPrefillFlat(model.model.language_model, model.lm_head, seq_len, args.max_cache_length).eval()
    with torch.no_grad():
        torch_outputs = wrapper(*example)

    package_path = str(args.package.expanduser().resolve())
    mlmodel = ct.models.MLModel(package_path, compute_units=ct.ComputeUnit.CPU_ONLY)

    # Feed every input as float32, matching the dtype declared in the input specs.
    feed = {}
    for spec, tensor in zip(prefill_input_specs(example), example, strict=True):
        feed[spec.name] = tensor.detach().cpu().numpy().astype(np.float32)
    coreml_outputs = mlmodel.predict(feed)

    torch_logits = torch_outputs[0].detach().cpu().numpy()
    coreml_logits = coreml_outputs["logits"]
    abs_diff = np.abs(torch_logits - coreml_logits)
    report = {
        "package": package_path,
        "prompt_tokens": int(sample["input_ids"].shape[1]),
        "torch_argmax": int(torch_logits[:, -1, :].argmax(axis=-1)[0]),
        "coreml_argmax": int(coreml_logits[:, -1, :].argmax(axis=-1)[0]),
        "logits_max_abs_diff": float(np.max(abs_diff)),
        "logits_mean_abs_diff": float(np.mean(abs_diff)),
        "outputs_seen": sorted(coreml_outputs.keys()),
    }
    report_text = json.dumps(report, indent=2)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(report_text + "\n", encoding="utf-8")
    print(report_text, flush=True)
def vision_combined_runtime_smoke(args) -> None:
    """Chain the vision, prefill and decode CoreML packages and compare with PyTorch.

    The sample from ``build_sample`` is encoded by the CoreML vision package;
    the resulting embeddings go through
    ``build_prefill_example_from_image_embeds`` into the CoreML prefill
    package, whose outputs seed a greedy decode loop on the CoreML decode
    package. The native model is run on the same sample for the same number
    of steps. A JSON report comparing both token streams, plus the vision
    embedding error against the PyTorch vision tower, is written to
    ``args.output`` and printed.

    Args:
        args: Parsed CLI namespace providing ``dtype``, ``model_id``,
            ``vision_package``, ``prefill_package``, ``decode_package``,
            ``max_cache_length``, ``steps`` and ``output``.

    Raises:
        ValueError: If a decode step starts with ``cache_len`` at or beyond
            ``args.max_cache_length``.
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample = build_sample(processor)

    vision_path = str(args.vision_package.expanduser().resolve())
    prefill_path = str(args.prefill_package.expanduser().resolve())
    decode_path = str(args.decode_package.expanduser().resolve())
    vision_ml = ct.models.MLModel(vision_path, compute_units=ct.ComputeUnit.CPU_ONLY)
    prefill_ml = ct.models.MLModel(prefill_path, compute_units=ct.ComputeUnit.CPU_ONLY)
    decode_ml = ct.models.MLModel(decode_path, compute_units=ct.ComputeUnit.CPU_ONLY)

    # Vision stage: CoreML embeddings drive the rest of the CoreML pipeline.
    pixel_values = sample["pixel_values"].detach().cpu().numpy().astype(np.float32)
    vision_outputs = vision_ml.predict({"pixel_values": pixel_values})
    image_embeds_np = vision_outputs["image_embeds"]
    image_embeds = torch.from_numpy(image_embeds_np)
    prefill_example_sample, prefill_example = build_prefill_example_from_image_embeds(model, sample, image_embeds)

    # PyTorch vision tower output, used only for the embedding-error metrics.
    with torch.no_grad():
        torch_visual = model.model.visual(
            sample["pixel_values"].to(next(model.parameters()).dtype),
            sample["image_grid_thw"],
        ).pooler_output.detach().cpu().numpy()

    # Prefill stage on CoreML.
    prefill_specs = prefill_input_specs(prefill_example)
    prefill_feed = {
        spec.name: tensor.detach().cpu().numpy().astype(np.float32)
        for spec, tensor in zip(prefill_specs, prefill_example, strict=True)
    }
    prefill_outputs = prefill_ml.predict(prefill_feed)
    prefill_logits = prefill_outputs["logits"]
    current_token = torch.tensor([[int(prefill_logits[:, -1, :].argmax(axis=-1)[0])]], dtype=torch.long)
    full_keys = [torch.from_numpy(prefill_outputs[f"full_key_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    full_values = [torch.from_numpy(prefill_outputs[f"full_value_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    conv_states = [torch.from_numpy(prefill_outputs[f"conv_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)]
    recurrent_states = [
        torch.from_numpy(prefill_outputs[f"recurrent_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)
    ]

    # Native prefill provides the reference first token and cache.
    with torch.no_grad():
        native_prefill = model(**sample, use_cache=True, return_dict=True)
    native_token = native_prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
    native_cache = native_prefill.past_key_values

    prompt_tokens = int(prefill_example_sample["input_ids"].shape[1])
    cache_len = prompt_tokens
    generated_coreml = [int(current_token.item())]
    generated_native = [int(native_token.item())]
    for _ in range(args.steps):
        # Fail before building inputs or calling CoreML for a step whose
        # cache slot does not exist.
        if cache_len >= args.max_cache_length:
            raise ValueError(f"cache_len={cache_len} reached max_cache_length={args.max_cache_length}")

        # Host-side work: embedding lookup, rotary tables and attention mask.
        # NOTE(review): the decode position is taken to be the number of cached
        # tokens; confirm this matches the position ids used at prefill.
        with torch.no_grad():
            inputs_embeds = model.model.language_model.embed_tokens(current_token)
            position_ids = torch.full((3, 1, 1), cache_len, dtype=torch.long)
            cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
        attention_mask = decode_attention_mask(cache_len, args.max_cache_length)
        example = (inputs_embeds, cos, sin, attention_mask) + tuple(full_keys + full_values + conv_states + recurrent_states)
        specs = decode_input_specs(example)
        feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)}
        decode_outputs = decode_ml.predict(feed)
        logits = decode_outputs["logits"]
        next_token = int(logits[:, -1, :].argmax(axis=-1)[0])

        # The host writes the new key/value into slot ``cache_len`` of each
        # full-attention cache; conv/recurrent states are replaced wholesale.
        for i in range(SuryaCoreMLDecodeStepFlat.full_layers):
            full_keys[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_key_{i}"])
            full_values[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_value_{i}"])
        conv_states = [torch.from_numpy(decode_outputs[f"new_conv_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
        recurrent_states = [
            torch.from_numpy(decode_outputs[f"new_recurrent_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)
        ]
        cache_len += 1
        current_token = torch.tensor([[next_token]], dtype=torch.long)
        generated_coreml.append(next_token)

        # Advance the native model by one greedy step from its own last token.
        with torch.no_grad():
            native = model(input_ids=native_token, past_key_values=native_cache, use_cache=True, return_dict=True)
        native_token = native.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
        native_cache = native.past_key_values
        generated_native.append(int(native_token.item()))

    # Length of the leading run of identical tokens; agreements after the
    # first divergence are not part of the common prefix.
    prefix_match = 0
    for coreml_id, native_id in zip(generated_coreml, generated_native, strict=True):
        if coreml_id != native_id:
            break
        prefix_match += 1

    vision_abs_diff = np.abs(torch_visual - image_embeds_np)
    report = {
        "vision_package": vision_path,
        "prefill_package": prefill_path,
        "decode_package": decode_path,
        "prompt_tokens": prompt_tokens,
        "steps": args.steps,
        "vision_shape": list(image_embeds_np.shape),
        "vision_max_abs_diff_vs_torch": float(np.max(vision_abs_diff)),
        "vision_mean_abs_diff_vs_torch": float(np.mean(vision_abs_diff)),
        "prefill_coreml_first_token": generated_coreml[0],
        "prefill_native_first_token": generated_native[0],
        "coreml_tokens": generated_coreml,
        "native_tokens": generated_native,
        "prefix_match_tokens": prefix_match,
        "all_tokens_match": generated_coreml == generated_native,
        "coreml_text": processor.decode(generated_coreml, skip_special_tokens=False),
        "native_text": processor.decode(generated_native, skip_special_tokens=False),
        "host_responsibilities": [
            "tokenization",
            "initial text token embedding lookup",
            "image placeholder insertion",
            "rotary position embedding generation",
            "generated-token embedding lookup",
            "full-attention KV cache insertion",
        ],
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)
def export_decode_step(args) -> None:
    """Trace the flat decode-step wrapper and export it as a CoreML package.

    The wrapper is traced with the example from ``build_decode_examples``,
    converted to an ``mlprogram`` with FLOAT16 compute precision, and saved
    under ``args.output_dir``. A JSON manifest describing the package is
    written next to it and printed.

    Args:
        args: Parsed CLI namespace providing ``output_dir``, ``dtype``,
            ``model_id`` and ``max_cache_length``.
    """
    target_dir = args.output_dir.expanduser().resolve()
    target_dir.mkdir(parents=True, exist_ok=True)

    torch_dtype = torch.float16 if args.dtype != "float32" else torch.float32
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, torch_dtype)
    sample, example = build_decode_examples(model, processor, args.max_cache_length)

    step_module = SuryaCoreMLDecodeStepFlat(model.model.language_model, model.lm_head).eval()
    with torch.no_grad():
        traced_step = torch.jit.trace(step_module, example, strict=False, check_trace=False)

    package_path = target_dir / f"surya_decode_step_fp16_cache{args.max_cache_length}.mlpackage"
    manifest_path = target_dir / f"surya_decode_step_fp16_cache{args.max_cache_length}.json"

    conversion_options = {
        "convert_to": "mlprogram",
        "minimum_deployment_target": ct.target.macOS14,
        "compute_precision": ct.precision.FLOAT16,
        "compute_units": ct.ComputeUnit.CPU_ONLY,
        "skip_model_load": True,
        "inputs": decode_input_specs(example),
        "outputs": decode_output_specs(),
    }
    coreml_model = ct.convert(traced_step, **conversion_options)
    coreml_model.save(str(package_path))

    # Describes which parts of the decode state the host must maintain itself.
    state_contract = {
        "full_attention_layers": SuryaCoreMLDecodeStepFlat.full_layers,
        "linear_attention_layers": SuryaCoreMLDecodeStepFlat.linear_layers,
        "host_updates_full_kv_cache": True,
        "host_updates_token_position_and_attention_mask": True,
    }
    manifest = {
        "model_id": args.model_id,
        "mode": "decode_step",
        "source_dtype": "bf16",
        "coreml_compute_precision": "fp16",
        "max_cache_length": args.max_cache_length,
        "prompt_tokens_for_trace": int(sample["input_ids"].shape[1]),
        "package": str(package_path),
        "inputs": [spec.name for spec in decode_input_specs(example)],
        "outputs": [spec.name for spec in decode_output_specs()],
        "state_contract": state_contract,
    }

    rendered = json.dumps(manifest, indent=2)
    manifest_path.write_text(rendered + "\n", encoding="utf-8")
    print(rendered, flush=True)
def smoke_decode_step(args) -> None:
    """Compare an exported CoreML decode-step package against the PyTorch wrapper.

    The example inputs from ``build_decode_examples`` are run through
    ``SuryaCoreMLDecodeStepFlat`` and through the package at ``args.package``
    (loaded with CPU-only compute units). A JSON report with the
    last-position argmax of both runs and the absolute logit differences is
    written to ``args.output`` and echoed to stdout.

    Args:
        args: Parsed CLI namespace providing ``dtype``, ``model_id``,
            ``max_cache_length``, ``package`` and ``output``.
    """
    torch_dtype = torch.float16 if args.dtype != "float32" else torch.float32
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, torch_dtype)
    sample, example = build_decode_examples(model, processor, args.max_cache_length)

    # PyTorch reference: the flat decode-step wrapper evaluated on the example.
    reference = SuryaCoreMLDecodeStepFlat(model.model.language_model, model.lm_head).eval()
    with torch.no_grad():
        reference_outputs = reference(*example)

    # CoreML run: every input is fed as float32 under its spec name.
    package_path = str(args.package.expanduser().resolve())
    coreml_model = ct.models.MLModel(package_path, compute_units=ct.ComputeUnit.CPU_ONLY)
    feed = {}
    for spec, tensor in zip(decode_input_specs(example), example, strict=True):
        feed[spec.name] = tensor.detach().cpu().numpy().astype(np.float32)
    coreml_outputs = coreml_model.predict(feed)

    actual_logits = coreml_outputs["logits"]
    expected_logits = reference_outputs[0].detach().cpu().numpy()
    abs_diff = np.abs(expected_logits - actual_logits)
    report = {
        "package": package_path,
        "prompt_tokens": int(sample["input_ids"].shape[1]),
        "torch_argmax": int(expected_logits[:, -1, :].argmax(axis=-1)[0]),
        "coreml_argmax": int(actual_logits[:, -1, :].argmax(axis=-1)[0]),
        "logits_max_abs_diff": float(np.max(abs_diff)),
        "logits_mean_abs_diff": float(np.mean(abs_diff)),
        "outputs_seen": sorted(coreml_outputs.keys()),
    }

    rendered = json.dumps(report, indent=2)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(rendered + "\n", encoding="utf-8")
    print(rendered, flush=True)
def iterative_decode_smoke(args) -> None:
    """Greedy-decode with the CoreML decode package from a native PyTorch prefill.

    The native model prefills the sample from ``build_sample``; its cache is
    split by ``extract_cache_lists`` and the full-attention caches are padded
    by ``pad_full_caches`` to ``args.max_cache_length``. The CoreML decode
    package and the native model are then stepped side by side for
    ``args.steps`` steps, and a JSON report comparing the two token streams
    is written to ``args.output`` and printed.

    Args:
        args: Parsed CLI namespace providing ``dtype``, ``model_id``,
            ``package``, ``max_cache_length``, ``steps`` and ``output``.

    Raises:
        ValueError: If a decode step starts with ``cache_len`` at or beyond
            ``args.max_cache_length``.
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample = build_sample(processor)

    # Both token streams start from the same native prefill.
    with torch.no_grad():
        prefill = model(**sample, use_cache=True, return_dict=True)
    current_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
    full_keys, full_values, conv_states, recurrent_states = extract_cache_lists(prefill.past_key_values)
    full_keys, full_values = pad_full_caches(full_keys, full_values, args.max_cache_length)

    package_path = str(args.package.expanduser().resolve())
    mlmodel = ct.models.MLModel(package_path, compute_units=ct.ComputeUnit.CPU_ONLY)

    prompt_tokens = int(sample["input_ids"].shape[1])
    generated_coreml = [int(current_token.item())]
    generated_native = [int(current_token.item())]
    cache_len = prompt_tokens
    native_cache = prefill.past_key_values
    native_token = current_token.clone()
    for _ in range(args.steps):
        # Fail before building inputs or calling CoreML for a step whose
        # cache slot does not exist.
        if cache_len >= args.max_cache_length:
            raise ValueError(f"cache_len={cache_len} reached max_cache_length={args.max_cache_length}")

        # Host-side work: embedding lookup, rotary tables and attention mask.
        # NOTE(review): the decode position is taken to be the number of cached
        # tokens; confirm this matches the position ids used at prefill.
        with torch.no_grad():
            inputs_embeds = model.model.language_model.embed_tokens(current_token)
            position_ids = torch.full((3, 1, 1), cache_len, dtype=torch.long)
            cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
        attention_mask = decode_attention_mask(cache_len, args.max_cache_length)
        example = (inputs_embeds, cos, sin, attention_mask) + tuple(full_keys + full_values + conv_states + recurrent_states)
        specs = decode_input_specs(example)
        feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)}
        coreml_outputs = mlmodel.predict(feed)
        logits = coreml_outputs["logits"]
        next_token = int(logits[:, -1, :].argmax(axis=-1)[0])

        # The host writes the new key/value into slot ``cache_len`` of each
        # full-attention cache; conv/recurrent states are replaced wholesale.
        for i in range(SuryaCoreMLDecodeStepFlat.full_layers):
            full_keys[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(coreml_outputs[f"new_full_key_{i}"])
            full_values[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(coreml_outputs[f"new_full_value_{i}"])
        conv_states = [torch.from_numpy(coreml_outputs[f"new_conv_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
        recurrent_states = [
            torch.from_numpy(coreml_outputs[f"new_recurrent_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)
        ]
        cache_len += 1
        current_token = torch.tensor([[next_token]], dtype=torch.long)
        generated_coreml.append(next_token)

        # Advance the native model by one greedy step from its own last token.
        with torch.no_grad():
            native = model(input_ids=native_token, past_key_values=native_cache, use_cache=True, return_dict=True)
        native_token = native.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
        native_cache = native.past_key_values
        generated_native.append(int(native_token.item()))

    # Length of the leading run of identical tokens; agreements after the
    # first divergence are not part of the common prefix.
    prefix_match = 0
    for coreml_id, native_id in zip(generated_coreml, generated_native, strict=True):
        if coreml_id != native_id:
            break
        prefix_match += 1

    report = {
        "package": package_path,
        "prompt_tokens": prompt_tokens,
        "steps": args.steps,
        "coreml_tokens": generated_coreml,
        "native_tokens": generated_native,
        "prefix_match_tokens": prefix_match,
        "all_tokens_match": generated_coreml == generated_native,
        "coreml_text": processor.decode(generated_coreml, skip_special_tokens=False),
        "native_text": processor.decode(generated_native, skip_special_tokens=False),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)
def combined_runtime_smoke(args) -> None:
    """Chain the prefill and decode CoreML packages and compare with PyTorch.

    The example from ``build_prefill_example`` is run through the CoreML
    prefill package; its logits and cache outputs seed a greedy decode loop
    on the CoreML decode package. The native model is run on the same sample
    for the same number of steps, and a JSON report comparing the two token
    streams is written to ``args.output`` and printed.

    Args:
        args: Parsed CLI namespace providing ``dtype``, ``model_id``,
            ``prefill_package``, ``decode_package``, ``max_cache_length``,
            ``steps`` and ``output``.

    Raises:
        ValueError: If a decode step starts with ``cache_len`` at or beyond
            ``args.max_cache_length``.
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample, prefill_example = build_prefill_example(model, processor)

    prefill_path = str(args.prefill_package.expanduser().resolve())
    decode_path = str(args.decode_package.expanduser().resolve())
    prefill_ml = ct.models.MLModel(prefill_path, compute_units=ct.ComputeUnit.CPU_ONLY)
    decode_ml = ct.models.MLModel(decode_path, compute_units=ct.ComputeUnit.CPU_ONLY)

    # Prefill stage on CoreML.
    prefill_specs = prefill_input_specs(prefill_example)
    prefill_feed = {
        spec.name: tensor.detach().cpu().numpy().astype(np.float32)
        for spec, tensor in zip(prefill_specs, prefill_example, strict=True)
    }
    prefill_outputs = prefill_ml.predict(prefill_feed)
    prefill_logits = prefill_outputs["logits"]
    current_token = torch.tensor([[int(prefill_logits[:, -1, :].argmax(axis=-1)[0])]], dtype=torch.long)
    full_keys = [torch.from_numpy(prefill_outputs[f"full_key_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    full_values = [torch.from_numpy(prefill_outputs[f"full_value_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    conv_states = [torch.from_numpy(prefill_outputs[f"conv_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)]
    recurrent_states = [
        torch.from_numpy(prefill_outputs[f"recurrent_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)
    ]

    # Native prefill provides the reference first token and cache.
    with torch.no_grad():
        native_prefill = model(**sample, use_cache=True, return_dict=True)
    native_token = native_prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
    native_cache = native_prefill.past_key_values

    prompt_tokens = int(sample["input_ids"].shape[1])
    cache_len = prompt_tokens
    generated_coreml = [int(current_token.item())]
    generated_native = [int(native_token.item())]
    for _ in range(args.steps):
        # Fail before building inputs or calling CoreML for a step whose
        # cache slot does not exist.
        if cache_len >= args.max_cache_length:
            raise ValueError(f"cache_len={cache_len} reached max_cache_length={args.max_cache_length}")

        # Host-side work: embedding lookup, rotary tables and attention mask.
        # NOTE(review): the decode position is taken to be the number of cached
        # tokens; confirm this matches the position ids used at prefill.
        with torch.no_grad():
            inputs_embeds = model.model.language_model.embed_tokens(current_token)
            position_ids = torch.full((3, 1, 1), cache_len, dtype=torch.long)
            cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
        attention_mask = decode_attention_mask(cache_len, args.max_cache_length)
        example = (inputs_embeds, cos, sin, attention_mask) + tuple(full_keys + full_values + conv_states + recurrent_states)
        specs = decode_input_specs(example)
        feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)}
        decode_outputs = decode_ml.predict(feed)
        logits = decode_outputs["logits"]
        next_token = int(logits[:, -1, :].argmax(axis=-1)[0])

        # The host writes the new key/value into slot ``cache_len`` of each
        # full-attention cache; conv/recurrent states are replaced wholesale.
        for i in range(SuryaCoreMLDecodeStepFlat.full_layers):
            full_keys[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_key_{i}"])
            full_values[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_value_{i}"])
        conv_states = [torch.from_numpy(decode_outputs[f"new_conv_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
        recurrent_states = [
            torch.from_numpy(decode_outputs[f"new_recurrent_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)
        ]
        cache_len += 1
        current_token = torch.tensor([[next_token]], dtype=torch.long)
        generated_coreml.append(next_token)

        # Advance the native model by one greedy step from its own last token.
        with torch.no_grad():
            native = model(input_ids=native_token, past_key_values=native_cache, use_cache=True, return_dict=True)
        native_token = native.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
        native_cache = native.past_key_values
        generated_native.append(int(native_token.item()))

    # Length of the leading run of identical tokens; agreements after the
    # first divergence are not part of the common prefix.
    prefix_match = 0
    for coreml_id, native_id in zip(generated_coreml, generated_native, strict=True):
        if coreml_id != native_id:
            break
        prefix_match += 1

    report = {
        "prefill_package": prefill_path,
        "decode_package": decode_path,
        "prompt_tokens": prompt_tokens,
        "steps": args.steps,
        "prefill_coreml_first_token": generated_coreml[0],
        "prefill_native_first_token": generated_native[0],
        "coreml_tokens": generated_coreml,
        "native_tokens": generated_native,
        "prefix_match_tokens": prefix_match,
        "all_tokens_match": generated_coreml == generated_native,
        "coreml_text": processor.decode(generated_coreml, skip_special_tokens=False),
        "native_text": processor.decode(generated_native, skip_special_tokens=False),
        "host_responsibilities": [
            "tokenization",
            "token embedding lookup for generated tokens",
            "rotary position embedding generation",
            "full-attention KV cache insertion",
        ],
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)
def main() -> None:
    """Parse the command line and run the selected export or validation command.

    Every subcommand takes one or more required ``Path`` options followed by
    ``--max-cache-length``; the multi-step smoke commands additionally take
    ``--steps``. The global ``--model-id`` and ``--dtype`` options apply to
    all subcommands.
    """
    cli = argparse.ArgumentParser(description="Build and validate a faithful Surya CoreML OCR runtime scaffold.")
    cli.add_argument("--model-id", default="datalab-to/surya-ocr-2")
    cli.add_argument("--dtype", choices=["float32", "float16"], default="float32")
    commands = cli.add_subparsers(dest="command", required=True)

    # (subcommand, required Path options in declaration order, takes --steps)
    layouts = (
        ("check-decode-parity", ("--output",), False),
        ("check-prefill-parity", ("--output",), False),
        ("export-prefill", ("--output-dir",), False),
        ("smoke-prefill", ("--package", "--output"), False),
        ("export-decode-step", ("--output-dir",), False),
        ("smoke-decode-step", ("--package", "--output"), False),
        ("iterative-decode-smoke", ("--package", "--output"), True),
        ("combined-runtime-smoke", ("--prefill-package", "--decode-package", "--output"), True),
        (
            "vision-combined-runtime-smoke",
            ("--vision-package", "--prefill-package", "--decode-package", "--output"),
            True,
        ),
    )
    for command, path_flags, takes_steps in layouts:
        subparser = commands.add_parser(command)
        for flag in path_flags:
            subparser.add_argument(flag, type=Path, required=True)
        subparser.add_argument("--max-cache-length", type=int, default=512)
        if takes_steps:
            subparser.add_argument("--steps", type=int, default=8)

    args = cli.parse_args()

    # Built after parsing so that argument errors exit before any handler is
    # looked up.
    handlers = {
        "check-decode-parity": parity_check,
        "check-prefill-parity": check_prefill_parity,
        "export-prefill": export_prefill,
        "smoke-prefill": smoke_prefill,
        "export-decode-step": export_decode_step,
        "smoke-decode-step": smoke_decode_step,
        "iterative-decode-smoke": iterative_decode_smoke,
        "combined-runtime-smoke": combined_runtime_smoke,
        "vision-combined-runtime-smoke": vision_combined_runtime_smoke,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args)
# Script entry point: run the CLI only when this file is executed directly.
if __name__ == "__main__":
    main()