#!/usr/bin/env python3
"""Export and validate Core ML prefill/decode graphs for a Qwen3.5-based Surya OCR model.

The module wraps the Hugging Face language model in trace-friendly ``torch.nn.Module``
subclasses (a single-token decode step and a fixed-length prefill), converts them with
coremltools, and compares their outputs against the native ``transformers`` forward pass.
"""
import argparse
import json
from pathlib import Path

import coremltools as ct
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from qwen_vl_utils import process_vision_info
from transformers import AutoConfig, AutoProcessor, Qwen3_5ForConditionalGeneration

PROMPT = "OCR this image to HTML."


def load_model(model_id: str, dtype: torch.dtype):
    """Load the model on CPU in eval mode with eager attention forced everywhere.

    Eager attention is set on the top-level, text, and vision configs (before loading)
    and again on the loaded model's configs, so the traced graph contains plain matmuls.
    """
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    config._attn_implementation = "eager"
    if hasattr(config, "text_config"):
        config.text_config._attn_implementation = "eager"
    if hasattr(config, "vision_config"):
        config.vision_config._attn_implementation = "eager"
    model = Qwen3_5ForConditionalGeneration.from_pretrained(
        model_id,
        config=config,
        torch_dtype=dtype,
        device_map="cpu",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation="eager",
    ).eval()
    model.config._attn_implementation = "eager"
    if hasattr(model.config, "text_config"):
        model.config.text_config._attn_implementation = "eager"
    return model


def build_sample(processor):
    """Render a small synthetic invoice image and return processor inputs for PROMPT."""
    image = Image.new("RGB", (512, 512), "white")
    draw = ImageDraw.Draw(image)
    draw.text((40, 80), "Invoice 123\nTotal $42.00", fill="black")
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": PROMPT}]}]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    return processor(text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt")


def _validate_layer_counts(language_model, full_layers, linear_layers):
    """Raise ValueError if the model's layer mix differs from a flat wrapper's hard-coded counts.

    The flat wrappers slice a positional ``*states`` tuple using fixed counts, so a model with
    a different number of full/linear attention layers would otherwise be mis-sliced silently.
    """
    layer_types = [layer.layer_type for layer in language_model.layers]
    actual_linear = sum(1 for layer_type in layer_types if layer_type == "linear_attention")
    actual_full = len(layer_types) - actual_linear
    if (actual_full, actual_linear) != (full_layers, linear_layers):
        raise ValueError(
            f"model has {actual_full} full-attention and {actual_linear} linear-attention layers; "
            f"wrapper expects {full_layers} and {linear_layers}"
        )


class SuryaCoreMLDecodeStep(torch.nn.Module):
    """Trace-friendly single-token decode step over the model's hybrid layer stack.

    Layers whose ``layer_type`` is ``"linear_attention"`` run a gated delta-rule recurrence with
    a conv state and a recurrent state; every other layer runs gated softmax attention over a KV
    cache. Shapes are read with ``int(...)`` so tracing bakes them in as constants.
    """

    def __init__(self, language_model, lm_head):
        super().__init__()
        self.language_model = language_model
        self.lm_head = lm_head

    @staticmethod
    def rotate_half_static(x, rotary_dim):
        """Return ``cat(-x2, x1)`` where x1/x2 are the two halves of the first ``rotary_dim`` features."""
        half_dim = rotary_dim // 2
        x1 = x.narrow(-1, 0, half_dim)
        x2 = x.narrow(-1, half_dim, half_dim)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_static(self, q, k, cos, sin):
        """Apply rotary embeddings to the leading ``cos.shape[-1]`` features of q and k.

        The remaining features pass through unchanged (partial rotary). ``cos``/``sin`` get a
        head axis inserted at dim 1 so they broadcast over ``(batch, heads, seq, dim)`` tensors.
        """
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        rotary_dim = int(cos.shape[-1])
        q_rot = q.narrow(-1, 0, rotary_dim)
        q_pass = q.narrow(-1, rotary_dim, int(q.shape[-1]) - rotary_dim)
        k_rot = k.narrow(-1, 0, rotary_dim)
        k_pass = k.narrow(-1, rotary_dim, int(k.shape[-1]) - rotary_dim)
        q_embed = (q_rot * cos) + (self.rotate_half_static(q_rot, rotary_dim) * sin)
        k_embed = (k_rot * cos) + (self.rotate_half_static(k_rot, rotary_dim) * sin)
        return torch.cat((q_embed, q_pass), dim=-1), torch.cat((k_embed, k_pass), dim=-1)

    @staticmethod
    def repeat_kv_static(hidden_states, n_rep):
        """Repeat each KV head ``n_rep`` times along dim 1 (grouped-query attention expansion)."""
        if n_rep == 1:
            return hidden_states
        batch = int(hidden_states.shape[0])
        num_key_value_heads = int(hidden_states.shape[1])
        seq_len = int(hidden_states.shape[2])
        head_dim = int(hidden_states.shape[3])
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)

    def _decode_attention_core(self, attn, hidden_states, cos, sin, past_key, past_value, attention_mask=None):
        """Run gated attention for one new token against ``past_key``/``past_value``.

        Returns ``(output, key_states, value_states, all_key, all_value)``: the new token's
        rotated key/value projections, and those projections appended to the past cache on dim 2.
        When ``attention_mask`` is given it is added to the scaled scores before the softmax.
        """
        # q_proj emits query and output gate interleaved per head; split them apart.
        query_states, gate = torch.chunk(attn.q_proj(hidden_states).view(1, 1, -1, attn.head_dim * 2), 2, dim=-1)
        gate = gate.reshape(1, 1, -1)
        query_states = attn.q_norm(query_states.view(1, 1, -1, attn.head_dim)).transpose(1, 2)
        key_states = attn.k_norm(attn.k_proj(hidden_states).view(1, 1, -1, attn.head_dim)).transpose(1, 2)
        value_states = attn.v_proj(hidden_states).view(1, 1, -1, attn.head_dim).transpose(1, 2)
        query_states, key_states = self.apply_rotary_static(query_states, key_states, cos, sin)
        all_key = torch.cat((past_key, key_states), dim=2)
        all_value = torch.cat((past_value, value_states), dim=2)
        key_for_attn = self.repeat_kv_static(all_key, attn.num_key_value_groups)
        value_for_attn = self.repeat_kv_static(all_value, attn.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_for_attn.transpose(2, 3)) * attn.scaling
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask.to(attn_weights.dtype)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_for_attn)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(1, 1, -1)
        attn_output = attn_output * torch.sigmoid(gate)
        return attn.o_proj(attn_output), key_states, value_states, all_key, all_value

    def full_attention_decode(self, attn, hidden_states, cos, sin, past_key, past_value):
        """Growing-cache decode: attend to the whole past and return the concatenated cache."""
        output, _, _, all_key, all_value = self._decode_attention_core(
            attn, hidden_states, cos, sin, past_key, past_value
        )
        return output, all_key, all_value

    def full_attention_decode_fixed(self, attn, hidden_states, cos, sin, past_key, past_value, attention_mask):
        """Fixed-cache decode: mask unused cache slots and return only the new token's key/value.

        The caller is responsible for writing the returned key/value into its padded cache.
        """
        output, key_states, value_states, _, _ = self._decode_attention_core(
            attn, hidden_states, cos, sin, past_key, past_value, attention_mask
        )
        return output, key_states, value_states

    @staticmethod
    def l2norm_static(x):
        """L2-normalise over the last dim with a 1e-6 epsilon inside the rsqrt."""
        return x * torch.rsqrt((x * x).sum(dim=-1, keepdim=True) + 1e-6)

    def recurrent_gated_delta_decode(self, query, key, value, g, beta, recurrent_state):
        """One step of the gated delta-rule recurrence.

        The state is decayed by ``exp(g)``, corrected toward ``value`` along ``key`` with rate
        ``beta``, then read out with the (scaled, L2-normalised) query. Returns the output in the
        input dtype and the updated state in float32.
        """
        initial_dtype = query.dtype
        query = self.l2norm_static(query)
        key = self.l2norm_static(key)
        query = query.transpose(1, 2).contiguous().float() * (1 / (query.shape[-1] ** 0.5))
        key = key.transpose(1, 2).contiguous().float()
        value = value.transpose(1, 2).contiguous().float()
        beta = beta.transpose(1, 2).contiguous().float()
        g = g.transpose(1, 2).contiguous().float()
        q_t = query[:, :, 0]
        k_t = key[:, :, 0]
        v_t = value[:, :, 0]
        g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, 0].unsqueeze(-1)
        recurrent_state = recurrent_state.float() * g_t
        kv_mem = (recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
        delta = (v_t - kv_mem) * beta_t
        recurrent_state = recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        core_attn_out = (recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
        return core_attn_out.unsqueeze(2).transpose(1, 2).contiguous().to(initial_dtype), recurrent_state

    def gated_delta_decode(self, linear_attn, hidden_states, conv_state, recurrent_state):
        """Single-token linear-attention layer: causal conv update, delta-rule step, gated norm.

        Returns ``(output, new_conv_state, new_recurrent_state)``; the new conv state keeps the
        last ``conv_kernel_size`` raw inputs.
        NOTE(review): q/k are not repeated across value heads, so this assumes the key and
        value head counts are equal for this checkpoint — confirm.
        """
        mixed_qkv_raw = linear_attn.in_proj_qkv(hidden_states).transpose(1, 2)
        conv_input = torch.cat((conv_state, mixed_qkv_raw), dim=-1)
        # Depthwise conv without padding over [state, new]; only the last output is the new token.
        mixed_qkv = F.conv1d(
            conv_input.to(linear_attn.conv1d.weight.dtype),
            linear_attn.conv1d.weight,
            linear_attn.conv1d.bias,
            padding=0,
            groups=int(conv_input.shape[1]),
        )
        mixed_qkv = F.silu(mixed_qkv[:, :, -1:]).to(mixed_qkv_raw.dtype)
        new_conv_state = conv_input[:, :, -linear_attn.conv_kernel_size :]
        mixed_qkv = mixed_qkv.transpose(1, 2)
        z = linear_attn.in_proj_z(hidden_states).reshape(1, 1, -1, linear_attn.head_v_dim)
        b = linear_attn.in_proj_b(hidden_states)
        a = linear_attn.in_proj_a(hidden_states)
        query, key, value = torch.split(
            mixed_qkv,
            [linear_attn.key_dim, linear_attn.key_dim, linear_attn.value_dim],
            dim=-1,
        )
        query = query.reshape(1, 1, -1, linear_attn.head_k_dim)
        key = key.reshape(1, 1, -1, linear_attn.head_k_dim)
        value = value.reshape(1, 1, -1, linear_attn.head_v_dim)
        beta = b.sigmoid()
        g = -linear_attn.A_log.float().exp() * F.softplus(a.float() + linear_attn.dt_bias)
        core_attn_out, new_recurrent_state = self.recurrent_gated_delta_decode(
            query, key, value, g, beta, recurrent_state
        )
        core_attn_out = core_attn_out.reshape(-1, linear_attn.head_v_dim)
        z = z.reshape(-1, linear_attn.head_v_dim)
        normed = linear_attn.norm(core_attn_out, z)
        normed = normed.reshape(1, 1, -1)
        return linear_attn.out_proj(normed), new_conv_state, new_recurrent_state

    def _decode_layers(self, inputs_embeds, cos, sin, full_keys, full_values, conv_states, recurrent_states,
                       attention_mask=None):
        """Run every decoder layer for one token.

        With ``attention_mask=None`` full-attention layers use the growing cache and return the
        concatenated cache; otherwise they use the masked fixed cache and return only the new
        key/value. Returns ``(logits, next_full_keys, next_full_values, next_conv, next_recurrent)``.
        """
        hidden_states = inputs_embeds
        next_full_keys = []
        next_full_values = []
        next_conv_states = []
        next_recurrent_states = []
        full_idx = 0
        linear_idx = 0
        for layer in self.language_model.layers:
            residual = hidden_states
            hidden_states = layer.input_layernorm(hidden_states)
            if layer.layer_type == "linear_attention":
                hidden_states, new_conv, new_recurrent = self.gated_delta_decode(
                    layer.linear_attn,
                    hidden_states,
                    conv_states[linear_idx],
                    recurrent_states[linear_idx],
                )
                next_conv_states.append(new_conv)
                next_recurrent_states.append(new_recurrent)
                linear_idx += 1
            else:
                if attention_mask is None:
                    hidden_states, new_key, new_value = self.full_attention_decode(
                        layer.self_attn,
                        hidden_states,
                        cos,
                        sin,
                        full_keys[full_idx],
                        full_values[full_idx],
                    )
                else:
                    hidden_states, new_key, new_value = self.full_attention_decode_fixed(
                        layer.self_attn,
                        hidden_states,
                        cos,
                        sin,
                        full_keys[full_idx],
                        full_values[full_idx],
                        attention_mask,
                    )
                next_full_keys.append(new_key)
                next_full_values.append(new_value)
                full_idx += 1
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = layer.post_attention_layernorm(hidden_states)
            hidden_states = layer.mlp(hidden_states)
            hidden_states = residual + hidden_states
        hidden_states = self.language_model.norm(hidden_states)
        logits = self.lm_head(hidden_states[:, -1:, :])
        return logits, next_full_keys, next_full_values, next_conv_states, next_recurrent_states

    def decode_lists(self, inputs_embeds, cos, sin, full_keys, full_values, conv_states, recurrent_states):
        """Growing-cache decode step; full-attention caches come back with one extra position."""
        return self._decode_layers(inputs_embeds, cos, sin, full_keys, full_values, conv_states, recurrent_states)

    def decode_lists_fixed(self, inputs_embeds, cos, sin, full_keys, full_values, attention_mask, conv_states,
                           recurrent_states):
        """Fixed-cache decode step against padded caches; returns only new full-attention key/values.

        Defined on the base class so both ``SuryaCoreMLDecodeStep`` (used by ``parity_check``)
        and ``SuryaCoreMLDecodeStepFlat`` can call it.
        """
        return self._decode_layers(
            inputs_embeds,
            cos,
            sin,
            full_keys,
            full_values,
            conv_states,
            recurrent_states,
            attention_mask=attention_mask,
        )


class SuryaCoreMLDecodeStepFlat(SuryaCoreMLDecodeStep):
    """Fixed-cache decode step with a flat positional state signature for tracing/Core ML.

    ``*states`` is ``full_keys + full_values + conv_states + recurrent_states`` in layer order;
    outputs are ``(logits, *new_full_keys, *new_full_values, *new_conv, *new_recurrent)``.
    """

    full_layers = 6
    linear_layers = 18

    def __init__(self, language_model, lm_head):
        super().__init__(language_model, lm_head)
        _validate_layer_counts(language_model, self.full_layers, self.linear_layers)

    def forward(self, inputs_embeds, cos, sin, attention_mask, *states):
        full_keys = list(states[: self.full_layers])
        full_values = list(states[self.full_layers : self.full_layers * 2])
        linear_offset = self.full_layers * 2
        conv_states = list(states[linear_offset : linear_offset + self.linear_layers])
        recurrent_states = list(states[linear_offset + self.linear_layers : linear_offset + self.linear_layers * 2])
        logits, next_full_keys, next_full_values, next_conv_states, next_recurrent_states = self.decode_lists_fixed(
            inputs_embeds,
            cos,
            sin,
            full_keys,
            full_values,
            attention_mask,
            conv_states,
            recurrent_states,
        )
        return tuple([logits] + next_full_keys + next_full_values + next_conv_states + next_recurrent_states)


class SuryaCoreMLPrefillFlat(SuryaCoreMLDecodeStep):
    """Fixed-length prompt prefill that emits logits plus padded caches and linear-attention states.

    Outputs are ``(logits, *full_keys, *full_values, *conv_states, *recurrent_states)``, with
    full-attention caches zero-padded to ``max_cache_length`` positions.
    """

    full_layers = 6
    linear_layers = 18

    def __init__(self, language_model, lm_head, seq_len, max_cache_length):
        super().__init__(language_model, lm_head)
        _validate_layer_counts(language_model, self.full_layers, self.linear_layers)
        if seq_len > max_cache_length:
            raise ValueError(f"seq_len={seq_len} exceeds max_cache_length={max_cache_length}")
        self.seq_len = seq_len
        self.max_cache_length = max_cache_length
        # Additive causal mask: float32-min strictly above the diagonal, 0 elsewhere.
        mask = torch.full((1, 1, seq_len, seq_len), torch.finfo(torch.float32).min)
        self.register_buffer("causal_mask", torch.triu(mask, diagonal=1))

    def full_attention_prefill(self, attn, hidden_states, cos, sin):
        """Causal gated attention over the whole prompt; returns output and padded key/value caches."""
        query_states, gate = torch.chunk(
            attn.q_proj(hidden_states).view(1, self.seq_len, -1, attn.head_dim * 2),
            2,
            dim=-1,
        )
        gate = gate.reshape(1, self.seq_len, -1)
        query_states = attn.q_norm(query_states.view(1, self.seq_len, -1, attn.head_dim)).transpose(1, 2)
        key_states = attn.k_norm(attn.k_proj(hidden_states).view(1, self.seq_len, -1, attn.head_dim)).transpose(1, 2)
        value_states = attn.v_proj(hidden_states).view(1, self.seq_len, -1, attn.head_dim).transpose(1, 2)
        query_states, key_states = self.apply_rotary_static(query_states, key_states, cos, sin)
        key_for_attn = self.repeat_kv_static(key_states, attn.num_key_value_groups)
        value_for_attn = self.repeat_kv_static(value_states, attn.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_for_attn.transpose(2, 3)) * attn.scaling
        attn_weights = attn_weights + self.causal_mask.to(attn_weights.dtype)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_for_attn)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(1, self.seq_len, -1)
        attn_output = attn_output * torch.sigmoid(gate)
        key_pad = torch.zeros(
            1,
            int(key_states.shape[1]),
            self.max_cache_length,
            int(key_states.shape[3]),
            dtype=key_states.dtype,
            device=key_states.device,
        )
        value_pad = torch.zeros(
            1,
            int(value_states.shape[1]),
            self.max_cache_length,
            int(value_states.shape[3]),
            dtype=value_states.dtype,
            device=value_states.device,
        )
        key_pad[:, :, : self.seq_len, :] = key_states
        value_pad[:, :, : self.seq_len, :] = value_states
        return attn.o_proj(attn_output), key_pad, value_pad

    def chunk_gated_delta_rule_prefill(self, query, key, value, g, beta, chunk_size=64):
        """Chunked gated delta rule over the prompt; returns per-token output and the final state.

        Head count and key/value head dims are read from the tensors rather than hard-coded.
        The in-chunk triangular solve is built with ``torch.cat`` (no in-place row writes) so
        the loop traces to a static graph.
        """
        sequence_length = self.seq_len
        pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
        total_sequence_length = sequence_length + pad_size
        num_chunks = total_sequence_length // chunk_size
        initial_dtype = query.dtype
        query = self.l2norm_static(query)
        key = self.l2norm_static(key)
        query = query.transpose(1, 2).contiguous().float()
        key = key.transpose(1, 2).contiguous().float()
        value = value.transpose(1, 2).contiguous().float()
        beta = beta.transpose(1, 2).contiguous().float()
        g = g.transpose(1, 2).contiguous().float()
        batch_size, num_heads, _, k_head_dim = (int(dim) for dim in key.shape)
        v_head_dim = int(value.shape[-1])
        query = F.pad(query, (0, 0, 0, pad_size))
        key = F.pad(key, (0, 0, 0, pad_size))
        value = F.pad(value, (0, 0, 0, pad_size))
        beta = F.pad(beta, (0, pad_size))
        g = F.pad(g, (0, pad_size))
        query = query * (1 / (k_head_dim**0.5))
        v_beta = value * beta.unsqueeze(-1)
        k_beta = key * beta.unsqueeze(-1)
        query = query.reshape(batch_size, num_heads, num_chunks, chunk_size, k_head_dim)
        key = key.reshape(batch_size, num_heads, num_chunks, chunk_size, k_head_dim)
        value = value.reshape(batch_size, num_heads, num_chunks, chunk_size, v_head_dim)
        k_beta = k_beta.reshape(batch_size, num_heads, num_chunks, chunk_size, k_head_dim)
        v_beta = v_beta.reshape(batch_size, num_heads, num_chunks, chunk_size, v_head_dim)
        g = g.reshape(batch_size, num_heads, num_chunks, chunk_size)
        tri0 = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
        tri1 = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
        g_cum = g.cumsum(dim=-1)
        decay_mask = ((g_cum.unsqueeze(-1) - g_cum.unsqueeze(-2)).tril().exp().float()).tril()
        attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(tri0, 0)
        for i in range(1, chunk_size):
            row = attn[..., i, :i].clone()
            sub = attn[..., :i, :i].clone()
            update = row + (row.unsqueeze(-1) * sub).sum(-2)
            new_row = torch.cat((update, attn[..., i, i:]), dim=-1)
            attn = torch.cat((attn[..., :i, :], new_row.unsqueeze(-2), attn[..., i + 1 :, :]), dim=-2)
        attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=query.device)
        value = attn @ v_beta
        k_cumdecay = attn @ (k_beta * g_cum.exp().unsqueeze(-1))
        last_recurrent_state = torch.zeros(
            batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device
        )
        outs = []
        for i in range(num_chunks):
            q_i = query[:, :, i]
            k_i = key[:, :, i]
            v_i = value[:, :, i]
            attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill(tri1, 0)
            v_prime = k_cumdecay[:, :, i] @ last_recurrent_state
            v_new = v_i - v_prime
            attn_inter = (q_i * g_cum[:, :, i, :, None].exp()) @ last_recurrent_state
            outs.append(attn_inter + attn_i @ v_new)
            last_recurrent_state = (
                last_recurrent_state * g_cum[:, :, i, -1, None, None].exp()
                + (k_i * (g_cum[:, :, i, -1, None] - g_cum[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
            )
        core_attn_out = torch.stack(outs, dim=2).reshape(batch_size, num_heads, total_sequence_length, v_head_dim)
        core_attn_out = core_attn_out[:, :, :sequence_length].transpose(1, 2).contiguous().to(initial_dtype)
        return core_attn_out, last_recurrent_state

    def gated_delta_prefill(self, linear_attn, hidden_states):
        """Linear-attention layer over the prompt; returns output, conv state, and recurrent state.

        The conv state is the last ``conv_kernel_size`` raw projections, left-padded with zeros
        when the prompt is shorter than the kernel so the decode step always receives a
        full-length state.
        NOTE(review): q/k are not repeated across value heads, so this assumes the key and
        value head counts are equal for this checkpoint — confirm.
        """
        mixed_qkv_raw = linear_attn.in_proj_qkv(hidden_states).transpose(1, 2)
        z = linear_attn.in_proj_z(hidden_states).reshape(1, self.seq_len, -1, linear_attn.head_v_dim)
        b = linear_attn.in_proj_b(hidden_states)
        a = linear_attn.in_proj_a(hidden_states)
        # The conv module pads both sides; keep the first seq_len outputs for a causal conv.
        mixed_qkv = F.silu(linear_attn.conv1d(mixed_qkv_raw)[:, :, : self.seq_len]).transpose(1, 2)
        query, key, value = torch.split(
            mixed_qkv, [linear_attn.key_dim, linear_attn.key_dim, linear_attn.value_dim], dim=-1
        )
        query = query.reshape(1, self.seq_len, -1, linear_attn.head_k_dim)
        key = key.reshape(1, self.seq_len, -1, linear_attn.head_k_dim)
        value = value.reshape(1, self.seq_len, -1, linear_attn.head_v_dim)
        beta = b.sigmoid()
        g = -linear_attn.A_log.float().exp() * F.softplus(a.float() + linear_attn.dt_bias)
        core_attn_out, recurrent_state = self.chunk_gated_delta_rule_prefill(query, key, value, g, beta)
        core_attn_out = core_attn_out.reshape(-1, linear_attn.head_v_dim)
        z = z.reshape(-1, linear_attn.head_v_dim)
        normed = linear_attn.norm(core_attn_out, z).reshape(1, self.seq_len, -1)
        kernel_size = int(linear_attn.conv_kernel_size)
        if self.seq_len >= kernel_size:
            conv_state = mixed_qkv_raw[:, :, -kernel_size:]
        else:
            conv_state = F.pad(mixed_qkv_raw, (kernel_size - self.seq_len, 0))
        return linear_attn.out_proj(normed), conv_state, recurrent_state

    def forward(self, inputs_embeds, cos, sin):
        hidden_states = inputs_embeds
        full_keys = []
        full_values = []
        conv_states = []
        recurrent_states = []
        for layer in self.language_model.layers:
            residual = hidden_states
            hidden_states = layer.input_layernorm(hidden_states)
            if layer.layer_type == "linear_attention":
                hidden_states, conv_state, recurrent_state = self.gated_delta_prefill(layer.linear_attn, hidden_states)
                conv_states.append(conv_state)
                recurrent_states.append(recurrent_state)
            else:
                hidden_states, key_state, value_state = self.full_attention_prefill(
                    layer.self_attn, hidden_states, cos, sin
                )
                full_keys.append(key_state)
                full_values.append(value_state)
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = layer.post_attention_layernorm(hidden_states)
            hidden_states = layer.mlp(hidden_states)
            hidden_states = residual + hidden_states
        hidden_states = self.language_model.norm(hidden_states)
        logits = self.lm_head(hidden_states[:, -1:, :])
        return tuple([logits] + full_keys + full_values + conv_states + recurrent_states)


def extract_cache_lists(cache):
    """Split a native hybrid cache into cloned full-attention and linear-attention state lists.

    Cache layers exposing ``keys`` are treated as full attention; all others must expose
    ``conv_states`` and ``recurrent_states``.
    """
    full_keys = []
    full_values = []
    conv_states = []
    recurrent_states = []
    for layer in cache.layers:
        if hasattr(layer, "keys"):
            full_keys.append(layer.keys.detach().clone())
            full_values.append(layer.values.detach().clone())
        else:
            conv_states.append(layer.conv_states.detach().clone())
            recurrent_states.append(layer.recurrent_states.detach().clone())
    return full_keys, full_values, conv_states, recurrent_states


def pad_full_caches(full_keys, full_values, max_cache_length):
    """Zero-pad each key/value cache along dim 2 to ``max_cache_length`` positions.

    Raises:
        ValueError: if any cache is already longer than ``max_cache_length``.
    """
    padded_keys = []
    padded_values = []
    for key, value in zip(full_keys, full_values, strict=True):
        if key.shape[2] > max_cache_length:
            raise ValueError(f"full attention cache length {key.shape[2]} exceeds max_cache_length={max_cache_length}")
        # new_zeros keeps both dtype and device of the source cache.
        key_pad = key.new_zeros(key.shape[0], key.shape[1], max_cache_length, key.shape[3])
        value_pad = value.new_zeros(value.shape[0], value.shape[1], max_cache_length, value.shape[3])
        key_pad[:, :, : key.shape[2], :] = key
        value_pad[:, :, : value.shape[2], :] = value
        padded_keys.append(key_pad)
        padded_values.append(value_pad)
    return padded_keys, padded_values


def decode_attention_mask(cache_length, max_cache_length, dtype=torch.float32):
    """Build the additive mask for a fixed-cache decode step.

    The mask has ``max_cache_length + 1`` columns: the first ``cache_length`` (filled cache
    slots) and the last one (the token being decoded) are 0, and the rest are ``finfo(dtype).min``.

    Raises:
        ValueError: if ``cache_length`` is negative or larger than ``max_cache_length``.
    """
    if not 0 <= cache_length <= max_cache_length:
        raise ValueError(f"cache_length={cache_length} must be within [0, max_cache_length={max_cache_length}]")
    mask = torch.full((1, 1, 1, max_cache_length + 1), torch.finfo(dtype).min, dtype=dtype)
    mask[:, :, :, :cache_length] = 0
    mask[:, :, :, -1:] = 0
    return mask


def parity_check(args) -> None:
    """Compare the custom growing-cache and fixed-cache decode steps against one native decode step."""
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample = build_sample(processor)
    with torch.no_grad():
        prefill = model(**sample, use_cache=True, return_dict=True)
    first_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
    full_keys, full_values, conv_states, recurrent_states = extract_cache_lists(prefill.past_key_values)
    with torch.no_grad():
        native = model(input_ids=first_token, past_key_values=prefill.past_key_values, use_cache=True, return_dict=True)
    step = SuryaCoreMLDecodeStep(model.model.language_model, model.lm_head).eval()
    with torch.no_grad():
        inputs_embeds = model.model.language_model.embed_tokens(first_token)
        # NOTE(review): prefill positions come from 3-D M-RoPE (compute_3d_position_ids), where
        # image tokens may not advance the position one-per-token; using the raw prompt length
        # here may not match the native decode position — confirm against the logits diff.
        position_ids = torch.full((3, 1, 1), int(sample["input_ids"].shape[1]), dtype=torch.long)
        cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
    with torch.no_grad():
        logits, next_full_keys, next_full_values, next_conv_states, next_recurrent_states = step.decode_lists(
            inputs_embeds,
            cos,
            sin,
            full_keys,
            full_values,
            conv_states,
            recurrent_states,
        )
    padded_full_keys, padded_full_values = pad_full_caches(full_keys, full_values, args.max_cache_length)
    attention_mask = decode_attention_mask(int(sample["input_ids"].shape[1]), args.max_cache_length)
    with torch.no_grad():
        fixed_logits, *_ = step.decode_lists_fixed(
            inputs_embeds,
            cos,
            sin,
            padded_full_keys,
            padded_full_values,
            attention_mask,
            conv_states,
            recurrent_states,
        )
    report = {
        "prompt_tokens": int(sample["input_ids"].shape[1]),
        "first_token": int(first_token.item()),
        "native_next_argmax": int(native.logits[:, -1, :].argmax(dim=-1).item()),
        "custom_next_argmax": int(logits[:, -1, :].argmax(dim=-1).item()),
        "logits_max_abs_diff": float((native.logits[:, -1:, :] - logits).abs().max().item()),
        "logits_mean_abs_diff": float((native.logits[:, -1:, :] - logits).abs().mean().item()),
        "fixed_cache_length": args.max_cache_length,
        "fixed_native_max_abs_diff": float((native.logits[:, -1:, :] - fixed_logits).abs().max().item()),
        "fixed_native_mean_abs_diff": float((native.logits[:, -1:, :] - fixed_logits).abs().mean().item()),
        "fixed_growing_max_abs_diff": float((logits - fixed_logits).abs().max().item()),
        "fixed_custom_next_argmax": int(fixed_logits[:, -1, :].argmax(dim=-1).item()),
        "full_layers": len(next_full_keys),
        "linear_layers": len(next_conv_states),
        "next_full_key_shape": list(next_full_keys[0].shape),
        "next_conv_shape": list(next_conv_states[0].shape),
        "next_recurrent_shape": list(next_recurrent_states[0].shape),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)


def build_decode_examples(model, processor, max_cache_length):
    """Return ``(sample, example_inputs)`` for tracing ``SuryaCoreMLDecodeStepFlat``.

    The example is the first decode step after a native prefill, with full-attention caches
    padded to ``max_cache_length``.
    """
    sample = build_sample(processor)
    with torch.no_grad():
        prefill = model(**sample, use_cache=True, return_dict=True)
        first_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
        full_keys, full_values, conv_states, recurrent_states = extract_cache_lists(prefill.past_key_values)
        padded_full_keys, padded_full_values = pad_full_caches(full_keys, full_values, max_cache_length)
        inputs_embeds = model.model.language_model.embed_tokens(first_token)
        position_ids = torch.full((3, 1, 1), int(sample["input_ids"].shape[1]), dtype=torch.long)
        cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
    attention_mask = decode_attention_mask(int(sample["input_ids"].shape[1]), max_cache_length)
    state_examples = tuple(padded_full_keys + padded_full_values + conv_states + recurrent_states)
    return sample, (inputs_embeds, cos, sin, attention_mask) + state_examples


def decode_input_specs(example):
    """Core ML float32 input specs for the flat decode step, named in positional order."""
    names = ["inputs_embeds", "cos", "sin", "attention_mask"]
    names += [f"full_key_{i}" for i in range(SuryaCoreMLDecodeStepFlat.full_layers)]
    names += [f"full_value_{i}" for i in range(SuryaCoreMLDecodeStepFlat.full_layers)]
    names += [f"conv_state_{i}" for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
    names += [f"recurrent_state_{i}" for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
    return [ct.TensorType(name=name, shape=tuple(t.shape), dtype=np.float32) for name, t in zip(names, example, strict=True)]


def decode_output_specs():
    """Core ML output specs for the flat decode step, named in positional order."""
    names = ["logits"]
    names += [f"new_full_key_{i}" for i in range(SuryaCoreMLDecodeStepFlat.full_layers)]
    names += [f"new_full_value_{i}" for i in range(SuryaCoreMLDecodeStepFlat.full_layers)]
    names += [f"new_conv_state_{i}" for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
    names += [f"new_recurrent_state_{i}" for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
    return [ct.TensorType(name=name) for name in names]


def build_prefill_example(model, processor):
    """Return ``(sample, (inputs_embeds, cos, sin))`` using the PyTorch vision tower for image embeds."""
    sample = build_sample(processor)
    with torch.no_grad():
        image_outputs = model.model.get_image_features(
            sample["pixel_values"].to(next(model.parameters()).dtype),
            sample["image_grid_thw"],
            return_dict=True,
        )
        image_embeds = torch.cat(image_outputs.pooler_output, dim=0)
    return build_prefill_example_from_image_embeds(model, sample, image_embeds)


def build_prefill_example_from_image_embeds(model, sample, image_embeds):
    """Scatter ``image_embeds`` into the prompt embeddings and compute 3-D rotary cos/sin."""
    input_ids = sample["input_ids"]
    attention_mask = sample["attention_mask"]
    mm_token_type_ids = sample["mm_token_type_ids"]
    image_grid_thw = sample["image_grid_thw"]
    with torch.no_grad():
        inputs_embeds = model.model.get_input_embeddings()(input_ids)
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        image_mask, _ = model.model.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        position_ids = model.model.compute_3d_position_ids(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=None,
            mm_token_type_ids=mm_token_type_ids,
        )
        cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
    return sample, (inputs_embeds, cos, sin)


def prefill_input_specs(example):
    """Core ML float32 input specs for the prefill graph."""
    names = ["inputs_embeds", "cos", "sin"]
    return [ct.TensorType(name=name, shape=tuple(t.shape), dtype=np.float32) for name, t in zip(names, example, strict=True)]


def prefill_output_specs():
    """Core ML output specs for the prefill graph, named in positional order."""
    names = ["logits"]
    names += [f"full_key_{i}" for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    names += [f"full_value_{i}" for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    names += [f"conv_state_{i}" for i in range(SuryaCoreMLPrefillFlat.linear_layers)]
    names += [f"recurrent_state_{i}" for i in range(SuryaCoreMLPrefillFlat.linear_layers)]
    return [ct.TensorType(name=name) for name in names]


def check_prefill_parity(args) -> None:
    """Compare the custom prefill wrapper's logits and first-layer states against the native prefill."""
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample, example = build_prefill_example(model, processor)
    wrapper = SuryaCoreMLPrefillFlat(
        model.model.language_model,
        model.lm_head,
        int(example[0].shape[1]),
        args.max_cache_length,
    ).eval()
    with torch.no_grad():
        native = model(**sample, use_cache=True, return_dict=True)
        custom_outputs = wrapper(*example)
    native_full_keys, native_full_values, native_conv, native_recurrent = extract_cache_lists(native.past_key_values)
    native_full_keys, native_full_values = pad_full_caches(native_full_keys, native_full_values, args.max_cache_length)
    # Output layout: logits, then full keys, full values, conv states, recurrent states.
    full = SuryaCoreMLPrefillFlat.full_layers
    linear = SuryaCoreMLPrefillFlat.linear_layers
    conv_start = 1 + 2 * full
    recurrent_start = conv_start + linear
    custom_logits = custom_outputs[0]
    custom_full_keys = list(custom_outputs[1 : 1 + full])
    custom_full_values = list(custom_outputs[1 + full : conv_start])
    custom_conv = list(custom_outputs[conv_start:recurrent_start])
    custom_recurrent = list(custom_outputs[recurrent_start : recurrent_start + linear])
    report = {
        "prompt_tokens": int(sample["input_ids"].shape[1]),
        "max_cache_length": args.max_cache_length,
        "native_argmax": int(native.logits[:, -1, :].argmax(dim=-1).item()),
        "custom_argmax": int(custom_logits[:, -1, :].argmax(dim=-1).item()),
        "logits_max_abs_diff": float((native.logits[:, -1:, :] - custom_logits).abs().max().item()),
        "logits_mean_abs_diff": float((native.logits[:, -1:, :] - custom_logits).abs().mean().item()),
        "full_key0_max_abs_diff": float((native_full_keys[0] - custom_full_keys[0]).abs().max().item()),
        "full_value0_max_abs_diff": float((native_full_values[0] - custom_full_values[0]).abs().max().item()),
        "conv0_max_abs_diff": float((native_conv[0] - custom_conv[0]).abs().max().item()),
        "recurrent0_max_abs_diff": float((native_recurrent[0] - custom_recurrent[0]).abs().max().item()),
        "full_layers": len(custom_full_keys),
        "linear_layers": len(custom_conv),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)


def export_prefill(args) -> None:
    """Trace the prefill wrapper, convert it to an FP16 Core ML package, and write a JSON manifest."""
    output_dir = args.output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample, example = build_prefill_example(model, processor)
    seq_len = int(example[0].shape[1])
    wrapper = SuryaCoreMLPrefillFlat(
        model.model.language_model,
        model.lm_head,
        seq_len,
        args.max_cache_length,
    ).eval()
    with torch.no_grad():
        traced = torch.jit.trace(wrapper, example, strict=False, check_trace=False)
    stem = f"surya_prefill_fp16_seq{seq_len}_cache{args.max_cache_length}"
    package_path = output_dir / f"{stem}.mlpackage"
    mlmodel = ct.convert(
        traced,
        convert_to="mlprogram",
        minimum_deployment_target=ct.target.macOS14,
        compute_precision=ct.precision.FLOAT16,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        skip_model_load=True,
        inputs=prefill_input_specs(example),
        outputs=prefill_output_specs(),
    )
    mlmodel.save(str(package_path))
    manifest = {
        "model_id": args.model_id,
        "mode": "prefill",
        "source_dtype": "bf16",
        # The torch dtype the wrapper was actually loaded and traced in (from --dtype).
        "trace_dtype": args.dtype,
        "coreml_compute_precision": "fp16",
        "seq_len": seq_len,
        "max_cache_length": args.max_cache_length,
        "package": str(package_path),
        "inputs": [spec.name for spec in prefill_input_specs(example)],
        "outputs": [spec.name for spec in prefill_output_specs()],
    }
    (output_dir / f"{stem}.json").write_text(
        json.dumps(manifest, indent=2) + "\n",
        encoding="utf-8",
    )
    print(json.dumps(manifest, indent=2), flush=True)


def smoke_prefill(args) -> None:
    """Run a converted prefill package on CPU and compare its logits against the torch wrapper."""
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample, example = build_prefill_example(model, processor)
    wrapper = SuryaCoreMLPrefillFlat(
        model.model.language_model, model.lm_head, int(example[0].shape[1]), args.max_cache_length
    ).eval()
    with torch.no_grad():
        torch_outputs = wrapper(*example)
    mlmodel = ct.models.MLModel(str(args.package.expanduser().resolve()), compute_units=ct.ComputeUnit.CPU_ONLY)
    specs = prefill_input_specs(example)
    feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)}
    coreml_outputs = mlmodel.predict(feed)
    torch_logits = torch_outputs[0].detach().cpu().numpy()
    logits = coreml_outputs["logits"]
    report = {
        "package": str(args.package.expanduser().resolve()),
        "prompt_tokens": int(sample["input_ids"].shape[1]),
        "torch_argmax": int(torch_logits[:, -1, :].argmax(axis=-1)[0]),
        "coreml_argmax": int(logits[:, -1, :].argmax(axis=-1)[0]),
        "logits_max_abs_diff": float(np.max(np.abs(torch_logits - logits))),
        "logits_mean_abs_diff": float(np.mean(np.abs(torch_logits - logits))),
        "outputs_seen": sorted(coreml_outputs.keys()),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)


def vision_combined_runtime_smoke(args) -> None:
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample = build_sample(processor)
    vision_ml = ct.models.MLModel(str(args.vision_package.expanduser().resolve()), compute_units=ct.ComputeUnit.CPU_ONLY)
    prefill_ml = ct.models.MLModel(str(args.prefill_package.expanduser().resolve()), compute_units=ct.ComputeUnit.CPU_ONLY)
    decode_ml = ct.models.MLModel(str(args.decode_package.expanduser().resolve()), compute_units=ct.ComputeUnit.CPU_ONLY)
    pixel_values = sample["pixel_values"].detach().cpu().numpy().astype(np.float32)
    vision_outputs = vision_ml.predict({"pixel_values": pixel_values})
    image_embeds_np = vision_outputs["image_embeds"]
    image_embeds = torch.from_numpy(image_embeds_np)
    prefill_example_sample, prefill_example = build_prefill_example_from_image_embeds(model, sample, image_embeds)
    with torch.no_grad():
        torch_visual = model.model.visual(
            sample["pixel_values"].to(next(model.parameters()).dtype),
            sample["image_grid_thw"],
        ).pooler_output.detach().cpu().numpy()
    prefill_specs = prefill_input_specs(prefill_example)
    prefill_feed = {
        spec.name: tensor.detach().cpu().numpy().astype(np.float32)
        for spec, tensor in zip(prefill_specs, prefill_example, strict=True)
    }
    prefill_outputs = prefill_ml.predict(prefill_feed)
    prefill_logits = prefill_outputs["logits"]
    current_token = torch.tensor([[int(prefill_logits[:, -1, :].argmax(axis=-1)[0])]], dtype=torch.long)
    full_keys = [torch.from_numpy(prefill_outputs[f"full_key_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    full_values = [torch.from_numpy(prefill_outputs[f"full_value_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    conv_states = [torch.from_numpy(prefill_outputs[f"conv_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)]
    recurrent_states = [
        torch.from_numpy(prefill_outputs[f"recurrent_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)
    ]
    with torch.no_grad():
        native_prefill = model(**sample, use_cache=True, return_dict=True)
    native_token = native_prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
    native_cache = native_prefill.past_key_values
    cache_len = int(prefill_example_sample["input_ids"].shape[1])
    generated_coreml = [int(current_token.item())]
    generated_native = [int(native_token.item())]
    for _ in range(args.steps):
        with torch.no_grad():
            inputs_embeds = model.model.language_model.embed_tokens(current_token)
            position_ids = torch.full((3, 1, 1), cache_len, dtype=torch.long)
            cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
        attention_mask = decode_attention_mask(cache_len, args.max_cache_length)
        example = (inputs_embeds, cos, sin, attention_mask) + tuple(full_keys + full_values + conv_states + recurrent_states)
        specs = decode_input_specs(example)
        feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)}
        decode_outputs = decode_ml.predict(feed)
        logits = decode_outputs["logits"]
        next_token = int(logits[:, -1, :].argmax(axis=-1)[0])
        if cache_len >= args.max_cache_length:
            raise ValueError(f"cache_len={cache_len} reached max_cache_length={args.max_cache_length}")
        for i in range(SuryaCoreMLDecodeStepFlat.full_layers):
            full_keys[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_key_{i}"])
            full_values[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_value_{i}"])
        conv_states = [torch.from_numpy(decode_outputs[f"new_conv_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
        recurrent_states = [
            torch.from_numpy(decode_outputs[f"new_recurrent_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)
        ]
        cache_len += 1
        current_token = torch.tensor([[next_token]], dtype=torch.long)
        generated_coreml.append(next_token)
        with torch.no_grad():
            native = model(input_ids=native_token, past_key_values=native_cache, use_cache=True, return_dict=True)
        native_token = native.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
        native_cache = native.past_key_values
        generated_native.append(int(native_token.item()))
    report = {
        "vision_package": str(args.vision_package.expanduser().resolve()),
        "prefill_package": str(args.prefill_package.expanduser().resolve()),
        "decode_package": str(args.decode_package.expanduser().resolve()),
        "prompt_tokens": int(prefill_example_sample["input_ids"].shape[1]),
        "steps": args.steps,
        "vision_shape": list(image_embeds_np.shape),
        "vision_max_abs_diff_vs_torch": float(np.max(np.abs(torch_visual - image_embeds_np))),
        "vision_mean_abs_diff_vs_torch": float(np.mean(np.abs(torch_visual - image_embeds_np))),
        "prefill_coreml_first_token": generated_coreml[0],
        "prefill_native_first_token": generated_native[0],
        "coreml_tokens": generated_coreml,
        "native_tokens": generated_native,
        # NOTE(review): this counts position-wise matches, not the length of the common prefix.
        "prefix_match_tokens": sum(1 for a, b in zip(generated_coreml, generated_native, strict=True) if a == b),
        "all_tokens_match": generated_coreml == generated_native,
        "coreml_text": processor.decode(generated_coreml, skip_special_tokens=False),
        "native_text": processor.decode(generated_native, skip_special_tokens=False),
        "host_responsibilities": [
            "tokenization",
            "initial text token embedding lookup",
            "image placeholder insertion",
            "rotary position embedding generation",
            "generated-token embedding lookup",
            "full-attention KV cache insertion",
        ],
    }
args.output.parent.mkdir(parents=True, exist_ok=True) args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8") print(json.dumps(report, indent=2), flush=True) def export_decode_step(args) -> None: output_dir = args.output_dir.expanduser().resolve() output_dir.mkdir(parents=True, exist_ok=True) dtype = torch.float32 if args.dtype == "float32" else torch.float16 processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True) model = load_model(args.model_id, dtype) sample, example = build_decode_examples(model, processor, args.max_cache_length) wrapper = SuryaCoreMLDecodeStepFlat(model.model.language_model, model.lm_head).eval() with torch.no_grad(): traced = torch.jit.trace(wrapper, example, strict=False, check_trace=False) package_path = output_dir / f"surya_decode_step_fp16_cache{args.max_cache_length}.mlpackage" mlmodel = ct.convert( traced, convert_to="mlprogram", minimum_deployment_target=ct.target.macOS14, compute_precision=ct.precision.FLOAT16, compute_units=ct.ComputeUnit.CPU_ONLY, skip_model_load=True, inputs=decode_input_specs(example), outputs=decode_output_specs(), ) mlmodel.save(str(package_path)) manifest = { "model_id": args.model_id, "mode": "decode_step", "source_dtype": "bf16", "coreml_compute_precision": "fp16", "max_cache_length": args.max_cache_length, "prompt_tokens_for_trace": int(sample["input_ids"].shape[1]), "package": str(package_path), "inputs": [spec.name for spec in decode_input_specs(example)], "outputs": [spec.name for spec in decode_output_specs()], "state_contract": { "full_attention_layers": SuryaCoreMLDecodeStepFlat.full_layers, "linear_attention_layers": SuryaCoreMLDecodeStepFlat.linear_layers, "host_updates_full_kv_cache": True, "host_updates_token_position_and_attention_mask": True, }, } (output_dir / f"surya_decode_step_fp16_cache{args.max_cache_length}.json").write_text( json.dumps(manifest, indent=2) + "\n", encoding="utf-8", ) print(json.dumps(manifest, indent=2), flush=True) def 
smoke_decode_step(args) -> None: dtype = torch.float32 if args.dtype == "float32" else torch.float16 processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True) model = load_model(args.model_id, dtype) sample, example = build_decode_examples(model, processor, args.max_cache_length) wrapper = SuryaCoreMLDecodeStepFlat(model.model.language_model, model.lm_head).eval() with torch.no_grad(): torch_outputs = wrapper(*example) mlmodel = ct.models.MLModel(str(args.package.expanduser().resolve()), compute_units=ct.ComputeUnit.CPU_ONLY) specs = decode_input_specs(example) feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)} coreml_outputs = mlmodel.predict(feed) logits = coreml_outputs["logits"] torch_logits = torch_outputs[0].detach().cpu().numpy() report = { "package": str(args.package.expanduser().resolve()), "prompt_tokens": int(sample["input_ids"].shape[1]), "torch_argmax": int(torch_logits[:, -1, :].argmax(axis=-1)[0]), "coreml_argmax": int(logits[:, -1, :].argmax(axis=-1)[0]), "logits_max_abs_diff": float(np.max(np.abs(torch_logits - logits))), "logits_mean_abs_diff": float(np.mean(np.abs(torch_logits - logits))), "outputs_seen": sorted(coreml_outputs.keys()), } args.output.parent.mkdir(parents=True, exist_ok=True) args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8") print(json.dumps(report, indent=2), flush=True) def iterative_decode_smoke(args) -> None: dtype = torch.float32 if args.dtype == "float32" else torch.float16 processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True) model = load_model(args.model_id, dtype) sample = build_sample(processor) with torch.no_grad(): prefill = model(**sample, use_cache=True, return_dict=True) current_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone() full_keys, full_values, conv_states, recurrent_states = extract_cache_lists(prefill.past_key_values) full_keys, full_values = 
pad_full_caches(full_keys, full_values, args.max_cache_length) mlmodel = ct.models.MLModel(str(args.package.expanduser().resolve()), compute_units=ct.ComputeUnit.CPU_ONLY) generated_coreml = [int(current_token.item())] generated_native = [int(current_token.item())] cache_len = int(sample["input_ids"].shape[1]) native_cache = prefill.past_key_values native_token = current_token.clone() for _ in range(args.steps): with torch.no_grad(): inputs_embeds = model.model.language_model.embed_tokens(current_token) position_ids = torch.full((3, 1, 1), cache_len, dtype=torch.long) cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids) attention_mask = decode_attention_mask(cache_len, args.max_cache_length) example = (inputs_embeds, cos, sin, attention_mask) + tuple(full_keys + full_values + conv_states + recurrent_states) specs = decode_input_specs(example) feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)} coreml_outputs = mlmodel.predict(feed) logits = coreml_outputs["logits"] next_token = int(logits[:, -1, :].argmax(axis=-1)[0]) if cache_len >= args.max_cache_length: raise ValueError(f"cache_len={cache_len} reached max_cache_length={args.max_cache_length}") for i in range(SuryaCoreMLDecodeStepFlat.full_layers): full_keys[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(coreml_outputs[f"new_full_key_{i}"]) full_values[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(coreml_outputs[f"new_full_value_{i}"]) conv_states = [torch.from_numpy(coreml_outputs[f"new_conv_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)] recurrent_states = [ torch.from_numpy(coreml_outputs[f"new_recurrent_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers) ] cache_len += 1 current_token = torch.tensor([[next_token]], dtype=torch.long) generated_coreml.append(next_token) with torch.no_grad(): native = model(input_ids=native_token, 
past_key_values=native_cache, use_cache=True, return_dict=True) native_token = native.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone() native_cache = native.past_key_values generated_native.append(int(native_token.item())) report = { "package": str(args.package.expanduser().resolve()), "prompt_tokens": int(sample["input_ids"].shape[1]), "steps": args.steps, "coreml_tokens": generated_coreml, "native_tokens": generated_native, "prefix_match_tokens": sum(1 for a, b in zip(generated_coreml, generated_native, strict=True) if a == b), "all_tokens_match": generated_coreml == generated_native, "coreml_text": processor.decode(generated_coreml, skip_special_tokens=False), "native_text": processor.decode(generated_native, skip_special_tokens=False), } args.output.parent.mkdir(parents=True, exist_ok=True) args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8") print(json.dumps(report, indent=2), flush=True) def combined_runtime_smoke(args) -> None: dtype = torch.float32 if args.dtype == "float32" else torch.float16 processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True) model = load_model(args.model_id, dtype) sample, prefill_example = build_prefill_example(model, processor) prefill_ml = ct.models.MLModel(str(args.prefill_package.expanduser().resolve()), compute_units=ct.ComputeUnit.CPU_ONLY) decode_ml = ct.models.MLModel(str(args.decode_package.expanduser().resolve()), compute_units=ct.ComputeUnit.CPU_ONLY) prefill_specs = prefill_input_specs(prefill_example) prefill_feed = { spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(prefill_specs, prefill_example, strict=True) } prefill_outputs = prefill_ml.predict(prefill_feed) prefill_logits = prefill_outputs["logits"] current_token = torch.tensor([[int(prefill_logits[:, -1, :].argmax(axis=-1)[0])]], dtype=torch.long) full_keys = [torch.from_numpy(prefill_outputs[f"full_key_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)] full_values = 
[torch.from_numpy(prefill_outputs[f"full_value_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)] conv_states = [torch.from_numpy(prefill_outputs[f"conv_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)] recurrent_states = [ torch.from_numpy(prefill_outputs[f"recurrent_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers) ] with torch.no_grad(): native_prefill = model(**sample, use_cache=True, return_dict=True) native_token = native_prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone() native_cache = native_prefill.past_key_values cache_len = int(sample["input_ids"].shape[1]) generated_coreml = [int(current_token.item())] generated_native = [int(native_token.item())] for _ in range(args.steps): with torch.no_grad(): inputs_embeds = model.model.language_model.embed_tokens(current_token) position_ids = torch.full((3, 1, 1), cache_len, dtype=torch.long) cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids) attention_mask = decode_attention_mask(cache_len, args.max_cache_length) example = (inputs_embeds, cos, sin, attention_mask) + tuple(full_keys + full_values + conv_states + recurrent_states) specs = decode_input_specs(example) feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)} decode_outputs = decode_ml.predict(feed) logits = decode_outputs["logits"] next_token = int(logits[:, -1, :].argmax(axis=-1)[0]) if cache_len >= args.max_cache_length: raise ValueError(f"cache_len={cache_len} reached max_cache_length={args.max_cache_length}") for i in range(SuryaCoreMLDecodeStepFlat.full_layers): full_keys[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_key_{i}"]) full_values[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_value_{i}"]) conv_states = [torch.from_numpy(decode_outputs[f"new_conv_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)] 
recurrent_states = [ torch.from_numpy(decode_outputs[f"new_recurrent_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers) ] cache_len += 1 current_token = torch.tensor([[next_token]], dtype=torch.long) generated_coreml.append(next_token) with torch.no_grad(): native = model(input_ids=native_token, past_key_values=native_cache, use_cache=True, return_dict=True) native_token = native.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone() native_cache = native.past_key_values generated_native.append(int(native_token.item())) report = { "prefill_package": str(args.prefill_package.expanduser().resolve()), "decode_package": str(args.decode_package.expanduser().resolve()), "prompt_tokens": int(sample["input_ids"].shape[1]), "steps": args.steps, "prefill_coreml_first_token": generated_coreml[0], "prefill_native_first_token": generated_native[0], "coreml_tokens": generated_coreml, "native_tokens": generated_native, "prefix_match_tokens": sum(1 for a, b in zip(generated_coreml, generated_native, strict=True) if a == b), "all_tokens_match": generated_coreml == generated_native, "coreml_text": processor.decode(generated_coreml, skip_special_tokens=False), "native_text": processor.decode(generated_native, skip_special_tokens=False), "host_responsibilities": [ "tokenization", "token embedding lookup for generated tokens", "rotary position embedding generation", "full-attention KV cache insertion", ], } args.output.parent.mkdir(parents=True, exist_ok=True) args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8") print(json.dumps(report, indent=2), flush=True) def main() -> None: parser = argparse.ArgumentParser(description="Build and validate a faithful Surya CoreML OCR runtime scaffold.") parser.add_argument("--model-id", default="datalab-to/surya-ocr-2") parser.add_argument("--dtype", choices=["float32", "float16"], default="float32") sub = parser.add_subparsers(dest="command", required=True) check = sub.add_parser("check-decode-parity") 
check.add_argument("--output", type=Path, required=True) check.add_argument("--max-cache-length", type=int, default=512) prefill_check = sub.add_parser("check-prefill-parity") prefill_check.add_argument("--output", type=Path, required=True) prefill_check.add_argument("--max-cache-length", type=int, default=512) prefill_export = sub.add_parser("export-prefill") prefill_export.add_argument("--output-dir", type=Path, required=True) prefill_export.add_argument("--max-cache-length", type=int, default=512) prefill_smoke = sub.add_parser("smoke-prefill") prefill_smoke.add_argument("--package", type=Path, required=True) prefill_smoke.add_argument("--output", type=Path, required=True) prefill_smoke.add_argument("--max-cache-length", type=int, default=512) export = sub.add_parser("export-decode-step") export.add_argument("--output-dir", type=Path, required=True) export.add_argument("--max-cache-length", type=int, default=512) smoke = sub.add_parser("smoke-decode-step") smoke.add_argument("--package", type=Path, required=True) smoke.add_argument("--output", type=Path, required=True) smoke.add_argument("--max-cache-length", type=int, default=512) loop = sub.add_parser("iterative-decode-smoke") loop.add_argument("--package", type=Path, required=True) loop.add_argument("--output", type=Path, required=True) loop.add_argument("--max-cache-length", type=int, default=512) loop.add_argument("--steps", type=int, default=8) combined = sub.add_parser("combined-runtime-smoke") combined.add_argument("--prefill-package", type=Path, required=True) combined.add_argument("--decode-package", type=Path, required=True) combined.add_argument("--output", type=Path, required=True) combined.add_argument("--max-cache-length", type=int, default=512) combined.add_argument("--steps", type=int, default=8) vision_combined = sub.add_parser("vision-combined-runtime-smoke") vision_combined.add_argument("--vision-package", type=Path, required=True) vision_combined.add_argument("--prefill-package", type=Path, 
required=True) vision_combined.add_argument("--decode-package", type=Path, required=True) vision_combined.add_argument("--output", type=Path, required=True) vision_combined.add_argument("--max-cache-length", type=int, default=512) vision_combined.add_argument("--steps", type=int, default=8) args = parser.parse_args() if args.command == "check-decode-parity": parity_check(args) elif args.command == "check-prefill-parity": check_prefill_parity(args) elif args.command == "export-prefill": export_prefill(args) elif args.command == "smoke-prefill": smoke_prefill(args) elif args.command == "export-decode-step": export_decode_step(args) elif args.command == "smoke-decode-step": smoke_decode_step(args) elif args.command == "iterative-decode-smoke": iterative_decode_smoke(args) elif args.command == "combined-runtime-smoke": combined_runtime_smoke(args) elif args.command == "vision-combined-runtime-smoke": vision_combined_runtime_smoke(args) if __name__ == "__main__": main()