| |
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import coremltools as ct |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from PIL import Image, ImageDraw |
| from qwen_vl_utils import process_vision_info |
| from transformers import AutoConfig, AutoProcessor, Qwen3_5ForConditionalGeneration |
|
|
|
|
# Text instruction placed next to the synthetic image in the sample chat message.
PROMPT = "OCR this image to HTML."
|
|
|
|
def load_model(model_id: str, dtype: torch.dtype):
    """Load the checkpoint on CPU in eval mode with eager attention selected.

    Args:
        model_id: Identifier or path passed to ``from_pretrained``.
        dtype: Torch dtype the weights are loaded in.

    Returns:
        The ``Qwen3_5ForConditionalGeneration`` instance in eval mode.
    """
    eager = "eager"

    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    config._attn_implementation = eager
    # Nested text/vision configs carry their own attention setting.
    for sub_name in ("text_config", "vision_config"):
        if hasattr(config, sub_name):
            getattr(config, sub_name)._attn_implementation = eager

    model = Qwen3_5ForConditionalGeneration.from_pretrained(
        model_id,
        config=config,
        torch_dtype=dtype,
        device_map="cpu",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation=eager,
    )
    model = model.eval()

    # Re-apply the setting on the config object attached to the loaded model.
    loaded_config = model.config
    loaded_config._attn_implementation = eager
    if hasattr(loaded_config, "text_config"):
        loaded_config.text_config._attn_implementation = eager
    return model
|
|
|
|
def build_sample(processor):
    """Build processor inputs for one synthetic 512x512 image plus ``PROMPT``.

    The image is a white canvas with two lines of black text drawn on it.

    Returns:
        The processor output with ``return_tensors="pt"``.
    """
    canvas = Image.new("RGB", (512, 512), "white")
    ImageDraw.Draw(canvas).text((40, 80), "Invoice 123\nTotal $42.00", fill="black")

    user_content = [
        {"type": "image", "image": canvas},
        {"type": "text", "text": PROMPT},
    ]
    messages = [{"role": "user", "content": user_content}]

    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    return processor(text=[chat_text], images=image_inputs, videos=video_inputs, return_tensors="pt")
|
|
|
|
class SuryaCoreMLDecodeStep(torch.nn.Module):
    """Single-token decode step over a hybrid full-attention / linear-attention stack.

    All reshapes hard-code batch size 1 and one query position, and shape values
    are converted with ``int(...)`` so they are plain Python integers.
    The wrapper reuses the sub-modules of ``language_model`` (projections, norms,
    MLPs) and only re-implements the attention / recurrent mixing around them.
    """

    def __init__(self, language_model, lm_head):
        super().__init__()
        self.language_model = language_model
        self.lm_head = lm_head

    @staticmethod
    def rotate_half_static(x, rotary_dim):
        """Return ``cat(-second_half, first_half)`` of the first ``rotary_dim`` channels."""
        half = rotary_dim // 2
        front = x.narrow(-1, 0, half)
        back = x.narrow(-1, half, half)
        return torch.cat((-back, front), dim=-1)

    def apply_rotary_static(self, q, k, cos, sin):
        """Apply rotary embedding to the leading ``cos.shape[-1]`` channels of q and k.

        Channels beyond the rotary width are passed through unchanged.
        """
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        rotary_dim = int(cos.shape[-1])

        def _rotate(states):
            rot_part = states.narrow(-1, 0, rotary_dim)
            pass_part = states.narrow(-1, rotary_dim, int(states.shape[-1]) - rotary_dim)
            rotated = (rot_part * cos) + (self.rotate_half_static(rot_part, rotary_dim) * sin)
            return torch.cat((rotated, pass_part), dim=-1)

        return _rotate(q), _rotate(k)

    @staticmethod
    def repeat_kv_static(hidden_states, n_rep):
        """Repeat each key/value head ``n_rep`` times along the head axis."""
        if n_rep == 1:
            return hidden_states
        batch, kv_heads, seq_len, head_dim = (int(dim) for dim in hidden_states.shape)
        expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
        return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)

    def _project_single_token(self, attn, hidden_states, cos, sin):
        """Project one token to rotary-embedded query/key, value and the output gate."""
        # q_proj emits query and gate interleaved per head; split them on the last axis.
        fused = attn.q_proj(hidden_states).view(1, 1, -1, attn.head_dim * 2)
        query_states, gate = torch.chunk(fused, 2, dim=-1)
        gate = gate.reshape(1, 1, -1)
        query_states = attn.q_norm(query_states.view(1, 1, -1, attn.head_dim)).transpose(1, 2)
        key_states = attn.k_norm(attn.k_proj(hidden_states).view(1, 1, -1, attn.head_dim)).transpose(1, 2)
        value_states = attn.v_proj(hidden_states).view(1, 1, -1, attn.head_dim).transpose(1, 2)
        query_states, key_states = self.apply_rotary_static(query_states, key_states, cos, sin)
        return query_states, key_states, value_states, gate

    def _attend_single_token(self, attn, query_states, all_key, all_value, gate, attention_mask=None):
        """Softmax attention of one query over ``all_key``/``all_value`` with sigmoid gating."""
        key_for_attn = self.repeat_kv_static(all_key, attn.num_key_value_groups)
        value_for_attn = self.repeat_kv_static(all_value, attn.num_key_value_groups)

        scores = torch.matmul(query_states, key_for_attn.transpose(2, 3)) * attn.scaling
        if attention_mask is not None:
            scores = scores + attention_mask.to(scores.dtype)
        # Softmax is evaluated in float32, then cast back to the query dtype.
        probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
        context = torch.matmul(probs, value_for_attn)
        context = context.transpose(1, 2).contiguous().reshape(1, 1, -1)
        context = context * torch.sigmoid(gate)
        return attn.o_proj(context)

    def full_attention_decode(self, attn, hidden_states, cos, sin, past_key, past_value):
        """Decode one token against a growing cache.

        Returns:
            ``(attn_output, all_key, all_value)`` where the caches are the past
            tensors with the new token appended along dim 2.
        """
        query_states, key_states, value_states, gate = self._project_single_token(attn, hidden_states, cos, sin)
        all_key = torch.cat((past_key, key_states), dim=2)
        all_value = torch.cat((past_value, value_states), dim=2)
        output = self._attend_single_token(attn, query_states, all_key, all_value, gate)
        return output, all_key, all_value

    def full_attention_decode_fixed(self, attn, hidden_states, cos, sin, past_key, past_value, attention_mask):
        """Decode one token against a fixed-size cache using an additive mask.

        The new key/value is appended after ``past_key``/``past_value`` for the
        attention computation only.

        Returns:
            ``(attn_output, key_states, value_states)`` where the last two hold
            just the new token's key/value, not the concatenated cache.
        """
        query_states, key_states, value_states, gate = self._project_single_token(attn, hidden_states, cos, sin)
        all_key = torch.cat((past_key, key_states), dim=2)
        all_value = torch.cat((past_value, value_states), dim=2)
        output = self._attend_single_token(attn, query_states, all_key, all_value, gate, attention_mask)
        return output, key_states, value_states

    @staticmethod
    def l2norm_static(x):
        """L2-normalise the last axis with a fixed ``1e-6`` added under the square root."""
        squared_sum = (x * x).sum(dim=-1, keepdim=True)
        return x * torch.rsqrt(squared_sum + 1e-6)

    def recurrent_gated_delta_decode(self, query, key, value, g, beta, recurrent_state):
        """Advance the gated delta-rule recurrence by one token (float32 math).

        Returns:
            ``(core_attn_out, recurrent_state)``; the output is cast back to the
            input query dtype, the state stays float32.
        """
        initial_dtype = query.dtype
        query = self.l2norm_static(query)
        key = self.l2norm_static(key)

        scale = 1 / (query.shape[-1] ** 0.5)
        query = query.transpose(1, 2).contiguous().float() * scale
        key = key.transpose(1, 2).contiguous().float()
        value = value.transpose(1, 2).contiguous().float()
        beta = beta.transpose(1, 2).contiguous().float()
        g = g.transpose(1, 2).contiguous().float()

        # Only position 0 is consumed: this path handles exactly one token.
        q_t = query[:, :, 0]
        k_t = key[:, :, 0]
        v_t = value[:, :, 0]
        decay = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, 0].unsqueeze(-1)

        state = recurrent_state.float() * decay
        kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
        delta = (v_t - kv_mem) * beta_t
        state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
        core_attn_out = core_attn_out.unsqueeze(2).transpose(1, 2).contiguous().to(initial_dtype)
        return core_attn_out, state

    def gated_delta_decode(self, linear_attn, hidden_states, conv_state, recurrent_state):
        """Run one linear-attention layer for a single token.

        Returns:
            ``(layer_output, new_conv_state, new_recurrent_state)``.
        """
        mixed_qkv_raw = linear_attn.in_proj_qkv(hidden_states).transpose(1, 2)
        # Append the new column to the rolling convolution window.
        conv_input = torch.cat((conv_state, mixed_qkv_raw), dim=-1)
        conv_weight = linear_attn.conv1d.weight
        convolved = F.conv1d(
            conv_input.to(conv_weight.dtype),
            conv_weight,
            linear_attn.conv1d.bias,
            padding=0,
            groups=int(conv_input.shape[1]),
        )
        # Keep only the newest output position.
        mixed_qkv = F.silu(convolved[:, :, -1:]).to(mixed_qkv_raw.dtype)
        new_conv_state = conv_input[:, :, -linear_attn.conv_kernel_size :]
        mixed_qkv = mixed_qkv.transpose(1, 2)

        z = linear_attn.in_proj_z(hidden_states).reshape(1, 1, -1, linear_attn.head_v_dim)
        b = linear_attn.in_proj_b(hidden_states)
        a = linear_attn.in_proj_a(hidden_states)

        split_sizes = [linear_attn.key_dim, linear_attn.key_dim, linear_attn.value_dim]
        query, key, value = torch.split(mixed_qkv, split_sizes, dim=-1)
        query = query.reshape(1, 1, -1, linear_attn.head_k_dim)
        key = key.reshape(1, 1, -1, linear_attn.head_k_dim)
        value = value.reshape(1, 1, -1, linear_attn.head_v_dim)
        beta = b.sigmoid()
        g = -linear_attn.A_log.float().exp() * F.softplus(a.float() + linear_attn.dt_bias)

        core_attn_out, new_recurrent_state = self.recurrent_gated_delta_decode(
            query, key, value, g, beta, recurrent_state
        )
        core_attn_out = core_attn_out.reshape(-1, linear_attn.head_v_dim)
        z = z.reshape(-1, linear_attn.head_v_dim)
        normed = linear_attn.norm(core_attn_out, z).reshape(1, 1, -1)
        return linear_attn.out_proj(normed), new_conv_state, new_recurrent_state

    def decode_lists(self, inputs_embeds, cos, sin, full_keys, full_values, conv_states, recurrent_states):
        """Run every layer for one token using growing full-attention caches.

        The state lists are indexed per layer type, in layer order.

        Returns:
            ``(logits, next_full_keys, next_full_values, next_conv_states,
            next_recurrent_states)``.
        """
        next_full_keys, next_full_values = [], []
        next_conv_states, next_recurrent_states = [], []
        full_cursor = 0
        linear_cursor = 0

        hidden_states = inputs_embeds
        for layer in self.language_model.layers:
            normed = layer.input_layernorm(hidden_states)
            if layer.layer_type == "linear_attention":
                mixed, conv_out, recurrent_out = self.gated_delta_decode(
                    layer.linear_attn,
                    normed,
                    conv_states[linear_cursor],
                    recurrent_states[linear_cursor],
                )
                next_conv_states.append(conv_out)
                next_recurrent_states.append(recurrent_out)
                linear_cursor += 1
            else:
                mixed, key_out, value_out = self.full_attention_decode(
                    layer.self_attn,
                    normed,
                    cos,
                    sin,
                    full_keys[full_cursor],
                    full_values[full_cursor],
                )
                next_full_keys.append(key_out)
                next_full_values.append(value_out)
                full_cursor += 1
            hidden_states = hidden_states + mixed

            mlp_out = layer.mlp(layer.post_attention_layernorm(hidden_states))
            hidden_states = hidden_states + mlp_out

        hidden_states = self.language_model.norm(hidden_states)
        logits = self.lm_head(hidden_states[:, -1:, :])
        return logits, next_full_keys, next_full_values, next_conv_states, next_recurrent_states
|
|
|
|
class SuryaCoreMLDecodeStepFlat(SuryaCoreMLDecodeStep):
    """Decode step with a flat positional signature and fixed-size full-attention caches.

    ``forward`` takes the per-layer states as positional tensors in the order
    full keys, full values, conv states, recurrent states, and returns one flat
    tuple in the same order prefixed by the logits.
    """

    # Number of full-attention / linear-attention layers expected in the state list.
    full_layers = 6
    linear_layers = 18

    def forward(self, inputs_embeds, cos, sin, attention_mask, *states):
        """Unpack the flat state list, run one decode step, and re-flatten the outputs.

        Raises:
            ValueError: if the number of state tensors is not
                ``2 * (full_layers + linear_layers)``.
        """
        expected_states = 2 * (self.full_layers + self.linear_layers)
        if len(states) != expected_states:
            # Fail fast: a wrong count would otherwise be sliced silently and
            # surface later as an unrelated indexing or shape error.
            raise ValueError(f"expected {expected_states} state tensors, got {len(states)}")

        keys_end = self.full_layers
        values_end = keys_end + self.full_layers
        conv_end = values_end + self.linear_layers
        recurrent_end = conv_end + self.linear_layers

        full_keys = list(states[:keys_end])
        full_values = list(states[keys_end:values_end])
        conv_states = list(states[values_end:conv_end])
        recurrent_states = list(states[conv_end:recurrent_end])

        logits, next_full_keys, next_full_values, next_conv_states, next_recurrent_states = self.decode_lists_fixed(
            inputs_embeds,
            cos,
            sin,
            full_keys,
            full_values,
            attention_mask,
            conv_states,
            recurrent_states,
        )
        return tuple([logits] + next_full_keys + next_full_values + next_conv_states + next_recurrent_states)

    def decode_lists_fixed(self, inputs_embeds, cos, sin, full_keys, full_values, attention_mask, conv_states, recurrent_states):
        """Run every layer for one token against fixed-size, masked full-attention caches.

        Full-attention layers go through ``full_attention_decode_fixed``, so the
        returned key/value entries hold only the new token's key/value rather
        than the whole cache.

        Returns:
            ``(logits, next_full_keys, next_full_values, next_conv_states,
            next_recurrent_states)``.
        """
        hidden_states = inputs_embeds
        next_full_keys = []
        next_full_values = []
        next_conv_states = []
        next_recurrent_states = []
        full_idx = 0
        linear_idx = 0
        for layer in self.language_model.layers:
            residual = hidden_states
            hidden_states = layer.input_layernorm(hidden_states)
            if layer.layer_type == "linear_attention":
                hidden_states, new_conv, new_recurrent = self.gated_delta_decode(
                    layer.linear_attn,
                    hidden_states,
                    conv_states[linear_idx],
                    recurrent_states[linear_idx],
                )
                next_conv_states.append(new_conv)
                next_recurrent_states.append(new_recurrent)
                linear_idx += 1
            else:
                hidden_states, new_key, new_value = self.full_attention_decode_fixed(
                    layer.self_attn,
                    hidden_states,
                    cos,
                    sin,
                    full_keys[full_idx],
                    full_values[full_idx],
                    attention_mask,
                )
                next_full_keys.append(new_key)
                next_full_values.append(new_value)
                full_idx += 1
            hidden_states = residual + hidden_states

            # Feed-forward sub-block with its own residual connection.
            residual = hidden_states
            hidden_states = layer.post_attention_layernorm(hidden_states)
            hidden_states = layer.mlp(hidden_states)
            hidden_states = residual + hidden_states

        hidden_states = self.language_model.norm(hidden_states)
        logits = self.lm_head(hidden_states[:, -1:, :])
        return logits, next_full_keys, next_full_values, next_conv_states, next_recurrent_states
|
|
|
|
class SuryaCoreMLPrefillFlat(SuryaCoreMLDecodeStep):
    """Fixed-length prefill over the whole prompt, returning logits plus all layer states.

    The sequence length is fixed at construction time and every reshape assumes
    batch size 1. Full-attention caches are returned zero-padded to
    ``max_cache_length`` along the sequence axis.
    """

    # Number of full-attention / linear-attention layers (used by the spec helpers).
    full_layers = 6
    linear_layers = 18

    def __init__(self, language_model, lm_head, seq_len, max_cache_length):
        """Store the static sizes and pre-build the causal mask.

        Raises:
            ValueError: if ``seq_len`` is larger than ``max_cache_length``; the
                prompt's keys/values could not be written into the padded cache.
        """
        if seq_len > max_cache_length:
            raise ValueError(f"seq_len={seq_len} exceeds max_cache_length={max_cache_length}")
        super().__init__(language_model, lm_head)
        self.seq_len = seq_len
        self.max_cache_length = max_cache_length
        # Additive mask: float32 minimum strictly above the diagonal, zero elsewhere.
        mask = torch.full((1, 1, seq_len, seq_len), torch.finfo(torch.float32).min)
        self.register_buffer("causal_mask", torch.triu(mask, diagonal=1))

    def full_attention_prefill(self, attn, hidden_states, cos, sin):
        """Causal gated attention over the full prompt.

        Returns:
            ``(attn_output, key_pad, value_pad)`` where the padded tensors hold
            the prompt's keys/values in their first ``seq_len`` positions and
            zeros afterwards.
        """
        # q_proj emits query and gate together; split them on the last axis.
        query_states, gate = torch.chunk(
            attn.q_proj(hidden_states).view(1, self.seq_len, -1, attn.head_dim * 2),
            2,
            dim=-1,
        )
        gate = gate.reshape(1, self.seq_len, -1)

        query_states = attn.q_norm(query_states.view(1, self.seq_len, -1, attn.head_dim)).transpose(1, 2)
        key_states = attn.k_norm(attn.k_proj(hidden_states).view(1, self.seq_len, -1, attn.head_dim)).transpose(1, 2)
        value_states = attn.v_proj(hidden_states).view(1, self.seq_len, -1, attn.head_dim).transpose(1, 2)

        query_states, key_states = self.apply_rotary_static(query_states, key_states, cos, sin)
        key_for_attn = self.repeat_kv_static(key_states, attn.num_key_value_groups)
        value_for_attn = self.repeat_kv_static(value_states, attn.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_for_attn.transpose(2, 3)) * attn.scaling
        attn_weights = attn_weights + self.causal_mask.to(attn_weights.dtype)
        # Softmax is evaluated in float32, then cast back to the query dtype.
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_for_attn)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(1, self.seq_len, -1)
        attn_output = attn_output * torch.sigmoid(gate)

        # The cache stores the un-repeated key/value heads, padded to the fixed length.
        key_pad = torch.zeros(
            1,
            int(key_states.shape[1]),
            self.max_cache_length,
            int(key_states.shape[3]),
            dtype=key_states.dtype,
            device=key_states.device,
        )
        value_pad = torch.zeros(
            1,
            int(value_states.shape[1]),
            self.max_cache_length,
            int(value_states.shape[3]),
            dtype=value_states.dtype,
            device=value_states.device,
        )
        key_pad[:, :, : self.seq_len, :] = key_states
        value_pad[:, :, : self.seq_len, :] = value_states
        return attn.o_proj(attn_output), key_pad, value_pad

    def chunk_gated_delta_rule_prefill(self, query, key, value, g, beta):
        """Chunked gated delta rule over the whole prompt (float32 math, chunk size 64).

        The head count and head dimensions are read from the input shapes instead
        of being hard-coded. Query/key and value are reshaped with the same head
        count, so they are assumed to share it.

        Returns:
            ``(core_attn_out, last_recurrent_state)``; the output is truncated to
            ``seq_len`` positions and cast back to the input query dtype.
        """
        chunk_size = 64
        sequence_length = self.seq_len
        # Pad the sequence up to a whole number of chunks.
        pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
        total_sequence_length = sequence_length + pad_size
        num_chunks = total_sequence_length // chunk_size

        initial_dtype = query.dtype
        query = self.l2norm_static(query)
        key = self.l2norm_static(key)
        query = query.transpose(1, 2).contiguous().float()
        key = key.transpose(1, 2).contiguous().float()
        value = value.transpose(1, 2).contiguous().float()
        beta = beta.transpose(1, 2).contiguous().float()
        g = g.transpose(1, 2).contiguous().float()

        # After the transpose the layout is (1, heads, seq, dim); int() keeps the
        # sizes as plain Python integers, like the other shape reads in this file.
        num_heads = int(query.shape[1])
        head_k_dim = int(query.shape[-1])
        head_v_dim = int(value.shape[-1])

        query = F.pad(query, (0, 0, 0, pad_size))
        key = F.pad(key, (0, 0, 0, pad_size))
        value = F.pad(value, (0, 0, 0, pad_size))
        beta = F.pad(beta, (0, pad_size))
        g = F.pad(g, (0, pad_size))

        query = query * (1 / (head_k_dim**0.5))
        v_beta = value * beta.unsqueeze(-1)
        k_beta = key * beta.unsqueeze(-1)
        query = query.reshape(1, num_heads, num_chunks, chunk_size, head_k_dim)
        key = key.reshape(1, num_heads, num_chunks, chunk_size, head_k_dim)
        value = value.reshape(1, num_heads, num_chunks, chunk_size, head_v_dim)
        k_beta = k_beta.reshape(1, num_heads, num_chunks, chunk_size, head_k_dim)
        v_beta = v_beta.reshape(1, num_heads, num_chunks, chunk_size, head_v_dim)
        g = g.reshape(1, num_heads, num_chunks, chunk_size)

        tri0 = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
        tri1 = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
        g_cum = g.cumsum(dim=-1)
        decay_mask = ((g_cum.unsqueeze(-1) - g_cum.unsqueeze(-2)).tril().exp().float()).tril()
        attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(tri0, 0)
        # Row-by-row update of the strictly lower-triangular block; each row is
        # rebuilt with torch.cat instead of being written in place.
        for i in range(1, chunk_size):
            row = attn[..., i, :i].clone()
            sub = attn[..., :i, :i].clone()
            update = row + (row.unsqueeze(-1) * sub).sum(-2)
            new_row = torch.cat((update, attn[..., i, i:]), dim=-1)
            attn = torch.cat((attn[..., :i, :], new_row.unsqueeze(-2), attn[..., i + 1 :, :]), dim=-2)
        attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=query.device)
        value = attn @ v_beta
        k_cumdecay = attn @ (k_beta * g_cum.exp().unsqueeze(-1))

        # State is (1, heads, key_dim, value_dim) and starts from zero.
        last_recurrent_state = torch.zeros(
            1, num_heads, head_k_dim, head_v_dim, dtype=value.dtype, device=value.device
        )
        outs = []
        for i in range(num_chunks):
            q_i = query[:, :, i]
            k_i = key[:, :, i]
            v_i = value[:, :, i]
            attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill(tri1, 0)
            v_prime = k_cumdecay[:, :, i] @ last_recurrent_state
            v_new = v_i - v_prime
            attn_inter = (q_i * g_cum[:, :, i, :, None].exp()) @ last_recurrent_state
            outs.append(attn_inter + attn_i @ v_new)
            last_recurrent_state = (
                last_recurrent_state * g_cum[:, :, i, -1, None, None].exp()
                + (k_i * (g_cum[:, :, i, -1, None] - g_cum[:, :, i]).exp()[..., None]).transpose(-1, -2)
                @ v_new
            )

        core_attn_out = torch.stack(outs, dim=2).reshape(1, num_heads, total_sequence_length, head_v_dim)
        # Drop the chunk padding before returning.
        core_attn_out = core_attn_out[:, :, :sequence_length].transpose(1, 2).contiguous().to(initial_dtype)
        return core_attn_out, last_recurrent_state

    def gated_delta_prefill(self, linear_attn, hidden_states):
        """Run one linear-attention layer over the whole prompt.

        Returns:
            ``(layer_output, conv_state, recurrent_state)``. ``conv_state`` is
            always ``conv_kernel_size`` wide: the trailing columns of the raw
            projection, zero-padded on the left when the prompt is shorter than
            the kernel.
        """
        mixed_qkv_raw = linear_attn.in_proj_qkv(hidden_states).transpose(1, 2)
        z = linear_attn.in_proj_z(hidden_states).reshape(1, self.seq_len, -1, linear_attn.head_v_dim)
        b = linear_attn.in_proj_b(hidden_states)
        a = linear_attn.in_proj_a(hidden_states)

        # Keep only the first seq_len positions of the convolution output.
        mixed_qkv = F.silu(linear_attn.conv1d(mixed_qkv_raw)[:, :, : self.seq_len]).transpose(1, 2)
        query, key, value = torch.split(mixed_qkv, [linear_attn.key_dim, linear_attn.key_dim, linear_attn.value_dim], dim=-1)
        query = query.reshape(1, self.seq_len, -1, linear_attn.head_k_dim)
        key = key.reshape(1, self.seq_len, -1, linear_attn.head_k_dim)
        value = value.reshape(1, self.seq_len, -1, linear_attn.head_v_dim)
        beta = b.sigmoid()
        g = -linear_attn.A_log.float().exp() * F.softplus(a.float() + linear_attn.dt_bias)

        core_attn_out, recurrent_state = self.chunk_gated_delta_rule_prefill(query, key, value, g, beta)
        core_attn_out = core_attn_out.reshape(-1, linear_attn.head_v_dim)
        z = z.reshape(-1, linear_attn.head_v_dim)
        normed = linear_attn.norm(core_attn_out, z).reshape(1, self.seq_len, -1)

        kernel_size = linear_attn.conv_kernel_size
        if self.seq_len >= kernel_size:
            conv_state = mixed_qkv_raw[:, :, -kernel_size:]
        else:
            # The decode step concatenates this state with one new column and
            # keeps the last kernel_size columns, so it must already be full width.
            conv_state = F.pad(mixed_qkv_raw, (kernel_size - self.seq_len, 0))
        return linear_attn.out_proj(normed), conv_state, recurrent_state

    def forward(self, inputs_embeds, cos, sin):
        """Run all layers over the prompt.

        Returns:
            A flat tuple: logits for the last position, then full keys, full
            values, conv states and recurrent states, each group in layer order.
        """
        hidden_states = inputs_embeds
        full_keys = []
        full_values = []
        conv_states = []
        recurrent_states = []
        for layer in self.language_model.layers:
            residual = hidden_states
            hidden_states = layer.input_layernorm(hidden_states)
            if layer.layer_type == "linear_attention":
                hidden_states, conv_state, recurrent_state = self.gated_delta_prefill(layer.linear_attn, hidden_states)
                conv_states.append(conv_state)
                recurrent_states.append(recurrent_state)
            else:
                hidden_states, key_state, value_state = self.full_attention_prefill(layer.self_attn, hidden_states, cos, sin)
                full_keys.append(key_state)
                full_values.append(value_state)
            hidden_states = residual + hidden_states

            # Feed-forward sub-block with its own residual connection.
            residual = hidden_states
            hidden_states = layer.post_attention_layernorm(hidden_states)
            hidden_states = layer.mlp(hidden_states)
            hidden_states = residual + hidden_states

        hidden_states = self.language_model.norm(hidden_states)
        logits = self.lm_head(hidden_states[:, -1:, :])
        return tuple([logits] + full_keys + full_values + conv_states + recurrent_states)
|
|
|
|
def extract_cache_lists(cache):
    """Split a per-layer cache into four lists of detached tensor copies.

    Layers exposing a ``keys`` attribute contribute to the key/value lists;
    every other layer contributes its ``conv_states`` and ``recurrent_states``.

    Returns:
        ``(full_keys, full_values, conv_states, recurrent_states)``.
    """

    def _snapshot(tensor):
        return tensor.detach().clone()

    full_keys, full_values = [], []
    conv_states, recurrent_states = [], []
    for layer in cache.layers:
        if not hasattr(layer, "keys"):
            conv_states.append(_snapshot(layer.conv_states))
            recurrent_states.append(_snapshot(layer.recurrent_states))
            continue
        full_keys.append(_snapshot(layer.keys))
        full_values.append(_snapshot(layer.values))
    return full_keys, full_values, conv_states, recurrent_states
|
|
|
|
| def pad_full_caches(full_keys, full_values, max_cache_length): |
| padded_keys = [] |
| padded_values = [] |
| for key, value in zip(full_keys, full_values, strict=True): |
| if key.shape[2] > max_cache_length: |
| raise ValueError(f"full attention cache length {key.shape[2]} exceeds max_cache_length={max_cache_length}") |
| key_pad = torch.zeros(key.shape[0], key.shape[1], max_cache_length, key.shape[3], dtype=key.dtype) |
| value_pad = torch.zeros(value.shape[0], value.shape[1], max_cache_length, value.shape[3], dtype=value.dtype) |
| key_pad[:, :, : key.shape[2], :] = key |
| value_pad[:, :, : value.shape[2], :] = value |
| padded_keys.append(key_pad) |
| padded_values.append(value_pad) |
| return padded_keys, padded_values |
|
|
|
|
def decode_attention_mask(cache_length, max_cache_length, dtype=torch.float32):
    """Build the additive mask for one decode step over a fixed-size cache.

    The mask has ``max_cache_length + 1`` key positions: the first
    ``cache_length`` slots and the final slot are set to 0, every other slot to
    the most negative finite value of ``dtype``.

    Args:
        cache_length: Number of leading cache slots to leave unmasked.
        max_cache_length: Fixed cache capacity.
        dtype: Floating dtype of the returned mask.

    Returns:
        Tensor of shape ``(1, 1, 1, max_cache_length + 1)``.

    Raises:
        ValueError: if ``cache_length`` is negative or exceeds
            ``max_cache_length``. A negative value would otherwise be read as a
            from-the-end slice and unmask unintended slots.
    """
    if not 0 <= cache_length <= max_cache_length:
        raise ValueError(f"cache_length={cache_length} must be within [0, max_cache_length={max_cache_length}]")
    mask = torch.full((1, 1, 1, max_cache_length + 1), torch.finfo(dtype).min, dtype=dtype)
    mask[:, :, :, :cache_length] = 0
    # The final slot stays visible alongside the filled prefix.
    mask[:, :, :, -1:] = 0
    return mask
|
|
|
|
def parity_check(args) -> None:
    """Compare one native decode step with the custom growing- and fixed-cache steps.

    Runs a native prefill on the synthetic sample, decodes the arg-max token
    natively, then decodes the same token through ``decode_lists`` (growing
    cache) and ``decode_lists_fixed`` (padded cache plus mask). The JSON report
    is written to ``args.output`` and printed.

    Reads ``args.dtype``, ``args.model_id``, ``args.max_cache_length`` and
    ``args.output``.
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample = build_sample(processor)
    prompt_tokens = int(sample["input_ids"].shape[1])

    with torch.no_grad():
        prefill = model(**sample, use_cache=True, return_dict=True)
    first_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
    # Snapshot the cache before the native decode step below consumes it.
    full_keys, full_values, conv_states, recurrent_states = extract_cache_lists(prefill.past_key_values)
    with torch.no_grad():
        native = model(input_ids=first_token, past_key_values=prefill.past_key_values, use_cache=True, return_dict=True)

    # The Flat subclass is required: decode_lists_fixed is defined only there,
    # while decode_lists is inherited from SuryaCoreMLDecodeStep.
    step = SuryaCoreMLDecodeStepFlat(model.model.language_model, model.lm_head).eval()
    inputs_embeds = model.model.language_model.embed_tokens(first_token)
    # NOTE(review): the prompt length is used as the rotary position on all three
    # axes; confirm this matches the position the native model assigns to the
    # first generated token when the prompt contains an image.
    position_ids = torch.full((3, 1, 1), prompt_tokens, dtype=torch.long)
    cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)

    with torch.no_grad():
        logits, next_full_keys, next_full_values, next_conv_states, next_recurrent_states = step.decode_lists(
            inputs_embeds,
            cos,
            sin,
            full_keys,
            full_values,
            conv_states,
            recurrent_states,
        )
    padded_full_keys, padded_full_values = pad_full_caches(full_keys, full_values, args.max_cache_length)
    attention_mask = decode_attention_mask(prompt_tokens, args.max_cache_length)
    with torch.no_grad():
        fixed_logits, fixed_next_keys, fixed_next_values, fixed_next_conv, fixed_next_recurrent = step.decode_lists_fixed(
            inputs_embeds,
            cos,
            sin,
            padded_full_keys,
            padded_full_values,
            attention_mask,
            conv_states,
            recurrent_states,
        )

    native_last = native.logits[:, -1:, :]
    report = {
        "prompt_tokens": prompt_tokens,
        "first_token": int(first_token.item()),
        "native_next_argmax": int(native.logits[:, -1, :].argmax(dim=-1).item()),
        "custom_next_argmax": int(logits[:, -1, :].argmax(dim=-1).item()),
        "logits_max_abs_diff": float((native_last - logits).abs().max().item()),
        "logits_mean_abs_diff": float((native_last - logits).abs().mean().item()),
        "fixed_cache_length": args.max_cache_length,
        "fixed_native_max_abs_diff": float((native_last - fixed_logits).abs().max().item()),
        "fixed_native_mean_abs_diff": float((native_last - fixed_logits).abs().mean().item()),
        "fixed_growing_max_abs_diff": float((logits - fixed_logits).abs().max().item()),
        "fixed_custom_next_argmax": int(fixed_logits[:, -1, :].argmax(dim=-1).item()),
        "full_layers": len(next_full_keys),
        "linear_layers": len(next_conv_states),
        "next_full_key_shape": list(next_full_keys[0].shape),
        "next_conv_shape": list(next_conv_states[0].shape),
        "next_recurrent_shape": list(next_recurrent_states[0].shape),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)
|
|
|
|
def build_decode_examples(model, processor, max_cache_length):
    """Produce example inputs for the flat decode step from a native prefill.

    Returns:
        ``(sample, example)`` where ``example`` is the tuple
        ``(inputs_embeds, cos, sin, attention_mask, *states)`` with states
        ordered as padded full keys, padded full values, conv states and
        recurrent states.
    """
    sample = build_sample(processor)
    with torch.no_grad():
        prefill = model(**sample, use_cache=True, return_dict=True)
    first_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()

    keys, values, conv_states, recurrent_states = extract_cache_lists(prefill.past_key_values)
    padded_keys, padded_values = pad_full_caches(keys, values, max_cache_length)

    language_model = model.model.language_model
    inputs_embeds = language_model.embed_tokens(first_token)
    # The prompt length serves as the position of the token being decoded.
    position_ids = torch.full((3, 1, 1), int(sample["input_ids"].shape[1]), dtype=torch.long)
    cos, sin = language_model.rotary_emb(inputs_embeds, position_ids)
    attention_mask = decode_attention_mask(int(sample["input_ids"].shape[1]), max_cache_length)

    leading = (inputs_embeds, cos, sin, attention_mask)
    states = tuple(padded_keys + padded_values + conv_states + recurrent_states)
    return sample, leading + states
|
|
|
|
def decode_input_specs(example):
    """Return named float32 ``ct.TensorType`` input specs for the decode model.

    Names follow the positional order of ``SuryaCoreMLDecodeStepFlat.forward``;
    each shape is taken from the matching tensor in ``example``.
    """
    step_cls = SuryaCoreMLDecodeStepFlat
    names = ["inputs_embeds", "cos", "sin", "attention_mask"]
    groups = (
        ("full_key", step_cls.full_layers),
        ("full_value", step_cls.full_layers),
        ("conv_state", step_cls.linear_layers),
        ("recurrent_state", step_cls.linear_layers),
    )
    for prefix, count in groups:
        names.extend(f"{prefix}_{i}" for i in range(count))

    specs = []
    for name, tensor in zip(names, example, strict=True):
        specs.append(ct.TensorType(name=name, shape=tuple(tensor.shape), dtype=np.float32))
    return specs
|
|
|
|
def decode_output_specs():
    """Return name-only ``ct.TensorType`` output specs for the decode model.

    The order is logits, then new full keys, new full values, new conv states
    and new recurrent states.
    """
    step_cls = SuryaCoreMLDecodeStepFlat
    names = ["logits"]
    groups = (
        ("new_full_key", step_cls.full_layers),
        ("new_full_value", step_cls.full_layers),
        ("new_conv_state", step_cls.linear_layers),
        ("new_recurrent_state", step_cls.linear_layers),
    )
    for prefix, count in groups:
        names.extend(f"{prefix}_{i}" for i in range(count))
    return [ct.TensorType(name=output_name) for output_name in names]
|
|
|
|
def build_prefill_example(model, processor):
    """Encode the synthetic sample image and build the prefill example inputs.

    Returns:
        The ``(sample, (inputs_embeds, cos, sin))`` pair produced by
        ``build_prefill_example_from_image_embeds``.
    """
    sample = build_sample(processor)
    # Pixel values are cast to the dtype of the model's first parameter.
    model_dtype = next(model.parameters()).dtype
    with torch.no_grad():
        image_outputs = model.model.get_image_features(
            sample["pixel_values"].to(model_dtype),
            sample["image_grid_thw"],
            return_dict=True,
        )
    image_embeds = torch.cat(image_outputs.pooler_output, dim=0)
    return build_prefill_example_from_image_embeds(model, sample, image_embeds)
|
|
|
|
def build_prefill_example_from_image_embeds(model, sample, image_embeds):
    """Merge image embeddings into the token embeddings and compute rotary tables.

    Args:
        model: Loaded model exposing the helper methods on ``model.model``.
        sample: Processor output for the prompt.
        image_embeds: Image features to scatter into the placeholder positions.

    Returns:
        ``(sample, (inputs_embeds, cos, sin))``.
    """
    input_ids = sample["input_ids"]
    attention_mask = sample["attention_mask"]
    mm_token_type_ids = sample["mm_token_type_ids"]
    image_grid_thw = sample["image_grid_thw"]
    core = model.model

    with torch.no_grad():
        token_embeds = core.get_input_embeddings()(input_ids)
        image_embeds = image_embeds.to(token_embeds.device, token_embeds.dtype)
        image_mask, _ = core.get_placeholder_mask(
            input_ids,
            inputs_embeds=token_embeds,
            image_features=image_embeds,
        )
        # Overwrite the masked positions with the image features.
        inputs_embeds = token_embeds.masked_scatter(image_mask, image_embeds)
        position_ids = core.compute_3d_position_ids(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=None,
            mm_token_type_ids=mm_token_type_ids,
        )
        cos, sin = core.language_model.rotary_emb(inputs_embeds, position_ids)
    return sample, (inputs_embeds, cos, sin)
|
|
|
|
def prefill_input_specs(example):
    """Return named float32 ``ct.TensorType`` input specs for the prefill model.

    ``example`` must hold exactly the three tensors ``inputs_embeds``, ``cos``
    and ``sin``; each spec takes its shape from the matching tensor.
    """
    input_names = ["inputs_embeds", "cos", "sin"]
    specs = []
    for input_name, tensor in zip(input_names, example, strict=True):
        specs.append(ct.TensorType(name=input_name, shape=tuple(tensor.shape), dtype=np.float32))
    return specs
|
|
|
|
def prefill_output_specs():
    """Return name-only ``ct.TensorType`` output specs for the prefill model.

    The order is logits, then full keys, full values, conv states and recurrent
    states.
    """
    prefill_cls = SuryaCoreMLPrefillFlat
    names = ["logits"]
    groups = (
        ("full_key", prefill_cls.full_layers),
        ("full_value", prefill_cls.full_layers),
        ("conv_state", prefill_cls.linear_layers),
        ("recurrent_state", prefill_cls.linear_layers),
    )
    for prefix, count in groups:
        names.extend(f"{prefix}_{i}" for i in range(count))
    return [ct.TensorType(name=output_name) for output_name in names]
|
|
|
|
def check_prefill_parity(args) -> None:
    """Compare the custom prefill wrapper against the native model on the sample.

    Reports arg-max agreement, logit differences and the first-layer differences
    of each cache type. The JSON report is written to ``args.output`` and
    printed.

    Reads ``args.dtype``, ``args.model_id``, ``args.max_cache_length`` and
    ``args.output``.
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample, example = build_prefill_example(model, processor)
    wrapper = SuryaCoreMLPrefillFlat(
        model.model.language_model,
        model.lm_head,
        int(example[0].shape[1]),
        args.max_cache_length,
    ).eval()
    with torch.no_grad():
        native = model(**sample, use_cache=True, return_dict=True)
        custom_outputs = wrapper(*example)
    native_full_keys, native_full_values, native_conv, native_recurrent = extract_cache_lists(native.past_key_values)
    # Pad the native caches so they are comparable with the wrapper's padded output.
    native_full_keys, native_full_values = pad_full_caches(native_full_keys, native_full_values, args.max_cache_length)

    # Slice boundaries follow the wrapper's flat output layout and its layer
    # counts, rather than repeating literal indices.
    keys_end = 1 + SuryaCoreMLPrefillFlat.full_layers
    values_end = keys_end + SuryaCoreMLPrefillFlat.full_layers
    conv_end = values_end + SuryaCoreMLPrefillFlat.linear_layers
    recurrent_end = conv_end + SuryaCoreMLPrefillFlat.linear_layers

    custom_logits = custom_outputs[0]
    custom_full_keys = list(custom_outputs[1:keys_end])
    custom_full_values = list(custom_outputs[keys_end:values_end])
    custom_conv = list(custom_outputs[values_end:conv_end])
    custom_recurrent = list(custom_outputs[conv_end:recurrent_end])

    native_last = native.logits[:, -1:, :]
    report = {
        "prompt_tokens": int(sample["input_ids"].shape[1]),
        "max_cache_length": args.max_cache_length,
        "native_argmax": int(native.logits[:, -1, :].argmax(dim=-1).item()),
        "custom_argmax": int(custom_logits[:, -1, :].argmax(dim=-1).item()),
        "logits_max_abs_diff": float((native_last - custom_logits).abs().max().item()),
        "logits_mean_abs_diff": float((native_last - custom_logits).abs().mean().item()),
        "full_key0_max_abs_diff": float((native_full_keys[0] - custom_full_keys[0]).abs().max().item()),
        "full_value0_max_abs_diff": float((native_full_values[0] - custom_full_values[0]).abs().max().item()),
        "conv0_max_abs_diff": float((native_conv[0] - custom_conv[0]).abs().max().item()),
        "recurrent0_max_abs_diff": float((native_recurrent[0] - custom_recurrent[0]).abs().max().item()),
        "full_layers": len(custom_full_keys),
        "linear_layers": len(custom_conv),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)
|
|
|
|
def export_prefill(args) -> None:
    """Trace the prefill wrapper, convert it to a Core ML package, and write a manifest.

    The ``.mlpackage`` and a JSON manifest with the same stem are written into
    ``args.output_dir``; the manifest is also printed.

    Reads ``args.output_dir``, ``args.dtype``, ``args.model_id`` and
    ``args.max_cache_length``.
    """
    output_dir = args.output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    _, example = build_prefill_example(model, processor)

    # The prompt length is baked into the wrapper and into the artifact names.
    seq_len = int(example[0].shape[1])
    wrapper = SuryaCoreMLPrefillFlat(
        model.model.language_model,
        model.lm_head,
        seq_len,
        args.max_cache_length,
    ).eval()
    with torch.no_grad():
        traced = torch.jit.trace(wrapper, example, strict=False, check_trace=False)

    stem = f"surya_prefill_fp16_seq{seq_len}_cache{args.max_cache_length}"
    package_path = output_dir / f"{stem}.mlpackage"
    mlmodel = ct.convert(
        traced,
        convert_to="mlprogram",
        minimum_deployment_target=ct.target.macOS14,
        compute_precision=ct.precision.FLOAT16,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        skip_model_load=True,
        inputs=prefill_input_specs(example),
        outputs=prefill_output_specs(),
    )
    mlmodel.save(str(package_path))
    manifest = {
        "model_id": args.model_id,
        "mode": "prefill",
        # NOTE(review): fixed label, independent of the dtype chosen above;
        # confirm it is meant to describe the checkpoint, not the traced model.
        "source_dtype": "bf16",
        # Dtype the weights were actually loaded and traced in.
        "trace_dtype": str(dtype).removeprefix("torch."),
        "coreml_compute_precision": "fp16",
        "seq_len": seq_len,
        "max_cache_length": args.max_cache_length,
        "package": str(package_path),
        "inputs": [spec.name for spec in prefill_input_specs(example)],
        "outputs": [spec.name for spec in prefill_output_specs()],
    }
    manifest_text = json.dumps(manifest, indent=2)
    (output_dir / f"{stem}.json").write_text(manifest_text + "\n", encoding="utf-8")
    print(manifest_text, flush=True)
|
|
|
|
def smoke_prefill(args) -> None:
    """Compare the CoreML prefill package's logits with the PyTorch prefill wrapper.

    Runs the flat prefill wrapper in PyTorch and the package at ``args.package``
    on the same example inputs, then writes a JSON report (argmax of both,
    max/mean absolute logit difference, output names seen) to ``args.output``
    and prints it.
    """
    torch_dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, torch_dtype)
    sample, example = build_prefill_example(model, processor)

    seq_len = int(example[0].shape[1])
    reference_module = SuryaCoreMLPrefillFlat(
        model.model.language_model,
        model.lm_head,
        seq_len,
        args.max_cache_length,
    ).eval()
    with torch.no_grad():
        reference_outputs = reference_module(*example)

    package_path = args.package.expanduser().resolve()
    mlmodel = ct.models.MLModel(str(package_path), compute_units=ct.ComputeUnit.CPU_ONLY)

    # Every input is handed to CoreML as float32, keyed by its spec name.
    feed = {}
    for spec, tensor in zip(prefill_input_specs(example), example, strict=True):
        feed[spec.name] = tensor.detach().cpu().numpy().astype(np.float32)
    coreml_outputs = mlmodel.predict(feed)

    reference_logits = reference_outputs[0].detach().cpu().numpy()
    coreml_logits = coreml_outputs["logits"]
    abs_error = np.abs(reference_logits - coreml_logits)

    report = {
        "package": str(package_path),
        "prompt_tokens": int(sample["input_ids"].shape[1]),
        "torch_argmax": int(reference_logits[:, -1, :].argmax(axis=-1)[0]),
        "coreml_argmax": int(coreml_logits[:, -1, :].argmax(axis=-1)[0]),
        "logits_max_abs_diff": float(np.max(abs_error)),
        "logits_mean_abs_diff": float(np.mean(abs_error)),
        "outputs_seen": sorted(coreml_outputs.keys()),
    }
    rendered = json.dumps(report, indent=2)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(rendered + "\n", encoding="utf-8")
    print(rendered, flush=True)
|
|
|
|
def vision_combined_runtime_smoke(args) -> None:
    """Run the CoreML vision -> prefill -> decode chain and compare it with the PyTorch model.

    The vision package maps ``pixel_values`` to ``image_embeds``, which are
    passed to ``build_prefill_example_from_image_embeds`` to build the prefill
    inputs. The prefill package yields the first token and the caches; the
    decode package is then stepped ``args.steps`` times, with the host writing
    each new full-attention key/value slice into the cache. The PyTorch model
    is prefilled on the same sample and stepped alongside on its own tokens.
    A JSON report is written to ``args.output`` and printed.

    Args:
        args: Parsed CLI namespace; reads ``dtype``, ``model_id``,
            ``vision_package``, ``prefill_package``, ``decode_package``,
            ``max_cache_length``, ``steps`` and ``output``.

    Raises:
        ValueError: If a decode step is about to start while
            ``cache_len >= args.max_cache_length``.
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample = build_sample(processor)

    vision_path = args.vision_package.expanduser().resolve()
    prefill_path = args.prefill_package.expanduser().resolve()
    decode_path = args.decode_package.expanduser().resolve()
    vision_ml = ct.models.MLModel(str(vision_path), compute_units=ct.ComputeUnit.CPU_ONLY)
    prefill_ml = ct.models.MLModel(str(prefill_path), compute_units=ct.ComputeUnit.CPU_ONLY)
    decode_ml = ct.models.MLModel(str(decode_path), compute_units=ct.ComputeUnit.CPU_ONLY)

    # Vision stage: CoreML produces the image embeddings used by the prefill inputs.
    pixel_values = sample["pixel_values"].detach().cpu().numpy().astype(np.float32)
    vision_outputs = vision_ml.predict({"pixel_values": pixel_values})
    image_embeds_np = vision_outputs["image_embeds"]
    image_embeds = torch.from_numpy(image_embeds_np)
    prefill_example_sample, prefill_example = build_prefill_example_from_image_embeds(model, sample, image_embeds)

    # PyTorch vision reference, used only for the vision diff metrics in the report.
    with torch.no_grad():
        torch_visual = model.model.visual(
            sample["pixel_values"].to(next(model.parameters()).dtype),
            sample["image_grid_thw"],
        ).pooler_output.detach().cpu().numpy()

    # Prefill stage: first token plus the initial caches come from CoreML.
    prefill_specs = prefill_input_specs(prefill_example)
    prefill_feed = {
        spec.name: tensor.detach().cpu().numpy().astype(np.float32)
        for spec, tensor in zip(prefill_specs, prefill_example, strict=True)
    }
    prefill_outputs = prefill_ml.predict(prefill_feed)
    prefill_logits = prefill_outputs["logits"]
    current_token = torch.tensor([[int(prefill_logits[:, -1, :].argmax(axis=-1)[0])]], dtype=torch.long)

    full_keys = [torch.from_numpy(prefill_outputs[f"full_key_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    full_values = [torch.from_numpy(prefill_outputs[f"full_value_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    conv_states = [torch.from_numpy(prefill_outputs[f"conv_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)]
    recurrent_states = [
        torch.from_numpy(prefill_outputs[f"recurrent_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)
    ]

    # Native reference prefill on the untouched processor sample.
    with torch.no_grad():
        native_prefill = model(**sample, use_cache=True, return_dict=True)
    native_token = native_prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
    native_cache = native_prefill.past_key_values

    prompt_tokens = int(prefill_example_sample["input_ids"].shape[1])
    cache_len = prompt_tokens
    generated_coreml = [int(current_token.item())]
    generated_native = [int(native_token.item())]

    for _ in range(args.steps):
        # Fail before doing any work for a position that has no cache slot left.
        if cache_len >= args.max_cache_length:
            raise ValueError(f"cache_len={cache_len} reached max_cache_length={args.max_cache_length}")

        with torch.no_grad():
            inputs_embeds = model.model.language_model.embed_tokens(current_token)
            position_ids = torch.full((3, 1, 1), cache_len, dtype=torch.long)
            cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
        attention_mask = decode_attention_mask(cache_len, args.max_cache_length)
        example = (inputs_embeds, cos, sin, attention_mask) + tuple(full_keys + full_values + conv_states + recurrent_states)
        specs = decode_input_specs(example)
        feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)}
        decode_outputs = decode_ml.predict(feed)
        logits = decode_outputs["logits"]
        next_token = int(logits[:, -1, :].argmax(axis=-1)[0])

        # Host-side cache maintenance: write the new key/value slice at cache_len
        # and replace the conv/recurrent states wholesale.
        for i in range(SuryaCoreMLDecodeStepFlat.full_layers):
            full_keys[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_key_{i}"])
            full_values[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_value_{i}"])
        conv_states = [torch.from_numpy(decode_outputs[f"new_conv_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
        recurrent_states = [
            torch.from_numpy(decode_outputs[f"new_recurrent_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)
        ]
        cache_len += 1
        current_token = torch.tensor([[next_token]], dtype=torch.long)
        generated_coreml.append(next_token)

        # The native path advances on its own previous token and cache.
        with torch.no_grad():
            native = model(input_ids=native_token, past_key_values=native_cache, use_cache=True, return_dict=True)
        native_token = native.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
        native_cache = native.past_key_values
        generated_native.append(int(native_token.item()))

    # Length of the leading run of identical tokens: index of the first
    # mismatch, or the full length when the sequences never diverge.
    token_pairs = zip(generated_coreml, generated_native, strict=True)
    prefix_match_tokens = next(
        (index for index, (coreml_tok, native_tok) in enumerate(token_pairs) if coreml_tok != native_tok),
        len(generated_coreml),
    )

    report = {
        "vision_package": str(vision_path),
        "prefill_package": str(prefill_path),
        "decode_package": str(decode_path),
        "prompt_tokens": prompt_tokens,
        "steps": args.steps,
        "vision_shape": list(image_embeds_np.shape),
        "vision_max_abs_diff_vs_torch": float(np.max(np.abs(torch_visual - image_embeds_np))),
        "vision_mean_abs_diff_vs_torch": float(np.mean(np.abs(torch_visual - image_embeds_np))),
        "prefill_coreml_first_token": generated_coreml[0],
        "prefill_native_first_token": generated_native[0],
        "coreml_tokens": generated_coreml,
        "native_tokens": generated_native,
        "prefix_match_tokens": prefix_match_tokens,
        "all_tokens_match": generated_coreml == generated_native,
        "coreml_text": processor.decode(generated_coreml, skip_special_tokens=False),
        "native_text": processor.decode(generated_native, skip_special_tokens=False),
        "host_responsibilities": [
            "tokenization",
            "initial text token embedding lookup",
            "image placeholder insertion",
            "rotary position embedding generation",
            "generated-token embedding lookup",
            "full-attention KV cache insertion",
        ],
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)
|
|
|
|
def export_decode_step(args) -> None:
    """Trace the flat decode-step wrapper, convert it to a CoreML package, and write a manifest.

    Reads ``output_dir``, ``dtype``, ``model_id`` and ``max_cache_length`` from
    ``args``. The ``.mlpackage`` and a sibling ``.json`` manifest (including the
    state contract the host has to honour) are written into the resolved output
    directory, and the manifest is echoed to stdout.
    """
    export_root = args.output_dir.expanduser().resolve()
    export_root.mkdir(parents=True, exist_ok=True)

    torch_dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, torch_dtype)
    sample, example = build_decode_examples(model, processor, args.max_cache_length)

    step_module = SuryaCoreMLDecodeStepFlat(model.model.language_model, model.lm_head).eval()
    with torch.no_grad():
        traced = torch.jit.trace(step_module, example, strict=False, check_trace=False)

    artifact_stem = f"surya_decode_step_fp16_cache{args.max_cache_length}"
    package_path = export_root / f"{artifact_stem}.mlpackage"

    input_specs = decode_input_specs(example)
    output_specs = decode_output_specs()
    input_names = [spec.name for spec in input_specs]
    output_names = [spec.name for spec in output_specs]

    mlmodel = ct.convert(
        traced,
        inputs=input_specs,
        outputs=output_specs,
        convert_to="mlprogram",
        minimum_deployment_target=ct.target.macOS14,
        compute_precision=ct.precision.FLOAT16,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        skip_model_load=True,
    )
    mlmodel.save(str(package_path))

    state_contract = {
        "full_attention_layers": SuryaCoreMLDecodeStepFlat.full_layers,
        "linear_attention_layers": SuryaCoreMLDecodeStepFlat.linear_layers,
        "host_updates_full_kv_cache": True,
        "host_updates_token_position_and_attention_mask": True,
    }
    manifest = {
        "model_id": args.model_id,
        "mode": "decode_step",
        # NOTE(review): hard-coded, while the weights above are loaded as
        # float32/float16 per args.dtype — confirm this is meant to describe
        # the published checkpoint rather than the traced tensors.
        "source_dtype": "bf16",
        "coreml_compute_precision": "fp16",
        "max_cache_length": args.max_cache_length,
        "prompt_tokens_for_trace": int(sample["input_ids"].shape[1]),
        "package": str(package_path),
        "inputs": input_names,
        "outputs": output_names,
        "state_contract": state_contract,
    }
    rendered = json.dumps(manifest, indent=2)
    (export_root / f"{artifact_stem}.json").write_text(rendered + "\n", encoding="utf-8")
    print(rendered, flush=True)
|
|
|
|
def smoke_decode_step(args) -> None:
    """Compare one CoreML decode step's logits with the PyTorch decode-step wrapper.

    Runs the flat decode-step wrapper in PyTorch and the package at
    ``args.package`` on the same example inputs, then writes a JSON report
    (argmax of both, max/mean absolute logit difference, output names seen)
    to ``args.output`` and prints it.
    """
    torch_dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, torch_dtype)
    sample, example = build_decode_examples(model, processor, args.max_cache_length)

    reference_module = SuryaCoreMLDecodeStepFlat(model.model.language_model, model.lm_head).eval()
    with torch.no_grad():
        reference_outputs = reference_module(*example)

    package_path = args.package.expanduser().resolve()
    mlmodel = ct.models.MLModel(str(package_path), compute_units=ct.ComputeUnit.CPU_ONLY)

    # Every input is handed to CoreML as float32, keyed by its spec name.
    feed = {}
    for spec, tensor in zip(decode_input_specs(example), example, strict=True):
        feed[spec.name] = tensor.detach().cpu().numpy().astype(np.float32)
    coreml_outputs = mlmodel.predict(feed)

    coreml_logits = coreml_outputs["logits"]
    reference_logits = reference_outputs[0].detach().cpu().numpy()
    abs_error = np.abs(reference_logits - coreml_logits)

    report = {
        "package": str(package_path),
        "prompt_tokens": int(sample["input_ids"].shape[1]),
        "torch_argmax": int(reference_logits[:, -1, :].argmax(axis=-1)[0]),
        "coreml_argmax": int(coreml_logits[:, -1, :].argmax(axis=-1)[0]),
        "logits_max_abs_diff": float(np.max(abs_error)),
        "logits_mean_abs_diff": float(np.mean(abs_error)),
        "outputs_seen": sorted(coreml_outputs.keys()),
    }
    rendered = json.dumps(report, indent=2)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(rendered + "\n", encoding="utf-8")
    print(rendered, flush=True)
|
|
|
|
def iterative_decode_smoke(args) -> None:
    """Greedy-decode with the CoreML decode-step package and compare against the PyTorch model.

    Prefill is done by the PyTorch model; its caches are split with
    ``extract_cache_lists`` and the full-attention ones padded with
    ``pad_full_caches``. The CoreML package at ``args.package`` is then stepped
    ``args.steps`` times, with the host writing each new key/value slice into
    the cache, while the PyTorch model is stepped alongside on its own tokens.
    A JSON report is written to ``args.output`` and printed.

    Args:
        args: Parsed CLI namespace; reads ``dtype``, ``model_id``, ``package``,
            ``max_cache_length``, ``steps`` and ``output``.

    Raises:
        ValueError: If a decode step is about to start while
            ``cache_len >= args.max_cache_length``.
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample = build_sample(processor)
    with torch.no_grad():
        prefill = model(**sample, use_cache=True, return_dict=True)
    current_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
    full_keys, full_values, conv_states, recurrent_states = extract_cache_lists(prefill.past_key_values)
    full_keys, full_values = pad_full_caches(full_keys, full_values, args.max_cache_length)

    package_path = args.package.expanduser().resolve()
    mlmodel = ct.models.MLModel(str(package_path), compute_units=ct.ComputeUnit.CPU_ONLY)
    # Both paths start from the same prefill token.
    generated_coreml = [int(current_token.item())]
    generated_native = [int(current_token.item())]

    prompt_tokens = int(sample["input_ids"].shape[1])
    cache_len = prompt_tokens
    native_cache = prefill.past_key_values
    native_token = current_token.clone()

    for _ in range(args.steps):
        # Fail before doing any work for a position that has no cache slot left.
        if cache_len >= args.max_cache_length:
            raise ValueError(f"cache_len={cache_len} reached max_cache_length={args.max_cache_length}")

        with torch.no_grad():
            inputs_embeds = model.model.language_model.embed_tokens(current_token)
            position_ids = torch.full((3, 1, 1), cache_len, dtype=torch.long)
            cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
        attention_mask = decode_attention_mask(cache_len, args.max_cache_length)
        example = (inputs_embeds, cos, sin, attention_mask) + tuple(full_keys + full_values + conv_states + recurrent_states)
        specs = decode_input_specs(example)
        feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)}
        coreml_outputs = mlmodel.predict(feed)
        logits = coreml_outputs["logits"]
        next_token = int(logits[:, -1, :].argmax(axis=-1)[0])

        # Host-side cache maintenance: write the new key/value slice at cache_len
        # and replace the conv/recurrent states wholesale.
        for i in range(SuryaCoreMLDecodeStepFlat.full_layers):
            full_keys[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(coreml_outputs[f"new_full_key_{i}"])
            full_values[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(coreml_outputs[f"new_full_value_{i}"])
        conv_states = [torch.from_numpy(coreml_outputs[f"new_conv_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
        recurrent_states = [
            torch.from_numpy(coreml_outputs[f"new_recurrent_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)
        ]
        cache_len += 1
        current_token = torch.tensor([[next_token]], dtype=torch.long)
        generated_coreml.append(next_token)

        # The native path advances on its own previous token and cache.
        with torch.no_grad():
            native = model(input_ids=native_token, past_key_values=native_cache, use_cache=True, return_dict=True)
        native_token = native.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
        native_cache = native.past_key_values
        generated_native.append(int(native_token.item()))

    # Length of the leading run of identical tokens: index of the first
    # mismatch, or the full length when the sequences never diverge.
    token_pairs = zip(generated_coreml, generated_native, strict=True)
    prefix_match_tokens = next(
        (index for index, (coreml_tok, native_tok) in enumerate(token_pairs) if coreml_tok != native_tok),
        len(generated_coreml),
    )

    report = {
        "package": str(package_path),
        "prompt_tokens": prompt_tokens,
        "steps": args.steps,
        "coreml_tokens": generated_coreml,
        "native_tokens": generated_native,
        "prefix_match_tokens": prefix_match_tokens,
        "all_tokens_match": generated_coreml == generated_native,
        "coreml_text": processor.decode(generated_coreml, skip_special_tokens=False),
        "native_text": processor.decode(generated_native, skip_special_tokens=False),
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)
|
|
|
|
def combined_runtime_smoke(args) -> None:
    """Run the CoreML prefill and decode packages end to end and compare with the PyTorch model.

    The prefill package yields the first token and the caches; the decode
    package is then stepped ``args.steps`` times, with the host writing each
    new full-attention key/value slice into the cache. The PyTorch model is
    prefilled on the same sample and stepped alongside on its own tokens.
    A JSON report is written to ``args.output`` and printed.

    Args:
        args: Parsed CLI namespace; reads ``dtype``, ``model_id``,
            ``prefill_package``, ``decode_package``, ``max_cache_length``,
            ``steps`` and ``output``.

    Raises:
        ValueError: If a decode step is about to start while
            ``cache_len >= args.max_cache_length``.
    """
    dtype = torch.float32 if args.dtype == "float32" else torch.float16
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = load_model(args.model_id, dtype)
    sample, prefill_example = build_prefill_example(model, processor)

    prefill_path = args.prefill_package.expanduser().resolve()
    decode_path = args.decode_package.expanduser().resolve()
    prefill_ml = ct.models.MLModel(str(prefill_path), compute_units=ct.ComputeUnit.CPU_ONLY)
    decode_ml = ct.models.MLModel(str(decode_path), compute_units=ct.ComputeUnit.CPU_ONLY)

    # Prefill stage: first token plus the initial caches come from CoreML.
    prefill_specs = prefill_input_specs(prefill_example)
    prefill_feed = {
        spec.name: tensor.detach().cpu().numpy().astype(np.float32)
        for spec, tensor in zip(prefill_specs, prefill_example, strict=True)
    }
    prefill_outputs = prefill_ml.predict(prefill_feed)
    prefill_logits = prefill_outputs["logits"]
    current_token = torch.tensor([[int(prefill_logits[:, -1, :].argmax(axis=-1)[0])]], dtype=torch.long)

    full_keys = [torch.from_numpy(prefill_outputs[f"full_key_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    full_values = [torch.from_numpy(prefill_outputs[f"full_value_{i}"]) for i in range(SuryaCoreMLPrefillFlat.full_layers)]
    conv_states = [torch.from_numpy(prefill_outputs[f"conv_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)]
    recurrent_states = [
        torch.from_numpy(prefill_outputs[f"recurrent_state_{i}"]) for i in range(SuryaCoreMLPrefillFlat.linear_layers)
    ]

    # Native reference prefill on the same processor sample.
    with torch.no_grad():
        native_prefill = model(**sample, use_cache=True, return_dict=True)
    native_token = native_prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
    native_cache = native_prefill.past_key_values

    prompt_tokens = int(sample["input_ids"].shape[1])
    cache_len = prompt_tokens
    generated_coreml = [int(current_token.item())]
    generated_native = [int(native_token.item())]

    for _ in range(args.steps):
        # Fail before doing any work for a position that has no cache slot left.
        if cache_len >= args.max_cache_length:
            raise ValueError(f"cache_len={cache_len} reached max_cache_length={args.max_cache_length}")

        with torch.no_grad():
            inputs_embeds = model.model.language_model.embed_tokens(current_token)
            position_ids = torch.full((3, 1, 1), cache_len, dtype=torch.long)
            cos, sin = model.model.language_model.rotary_emb(inputs_embeds, position_ids)
        attention_mask = decode_attention_mask(cache_len, args.max_cache_length)
        example = (inputs_embeds, cos, sin, attention_mask) + tuple(full_keys + full_values + conv_states + recurrent_states)
        specs = decode_input_specs(example)
        feed = {spec.name: tensor.detach().cpu().numpy().astype(np.float32) for spec, tensor in zip(specs, example, strict=True)}
        decode_outputs = decode_ml.predict(feed)
        logits = decode_outputs["logits"]
        next_token = int(logits[:, -1, :].argmax(axis=-1)[0])

        # Host-side cache maintenance: write the new key/value slice at cache_len
        # and replace the conv/recurrent states wholesale.
        for i in range(SuryaCoreMLDecodeStepFlat.full_layers):
            full_keys[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_key_{i}"])
            full_values[i][:, :, cache_len : cache_len + 1, :] = torch.from_numpy(decode_outputs[f"new_full_value_{i}"])
        conv_states = [torch.from_numpy(decode_outputs[f"new_conv_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)]
        recurrent_states = [
            torch.from_numpy(decode_outputs[f"new_recurrent_state_{i}"]) for i in range(SuryaCoreMLDecodeStepFlat.linear_layers)
        ]
        cache_len += 1
        current_token = torch.tensor([[next_token]], dtype=torch.long)
        generated_coreml.append(next_token)

        # The native path advances on its own previous token and cache.
        with torch.no_grad():
            native = model(input_ids=native_token, past_key_values=native_cache, use_cache=True, return_dict=True)
        native_token = native.logits[:, -1, :].argmax(dim=-1, keepdim=True).clone()
        native_cache = native.past_key_values
        generated_native.append(int(native_token.item()))

    # Length of the leading run of identical tokens: index of the first
    # mismatch, or the full length when the sequences never diverge.
    token_pairs = zip(generated_coreml, generated_native, strict=True)
    prefix_match_tokens = next(
        (index for index, (coreml_tok, native_tok) in enumerate(token_pairs) if coreml_tok != native_tok),
        len(generated_coreml),
    )

    report = {
        "prefill_package": str(prefill_path),
        "decode_package": str(decode_path),
        "prompt_tokens": prompt_tokens,
        "steps": args.steps,
        "prefill_coreml_first_token": generated_coreml[0],
        "prefill_native_first_token": generated_native[0],
        "coreml_tokens": generated_coreml,
        "native_tokens": generated_native,
        "prefix_match_tokens": prefix_match_tokens,
        "all_tokens_match": generated_coreml == generated_native,
        "coreml_text": processor.decode(generated_coreml, skip_special_tokens=False),
        "native_text": processor.decode(generated_native, skip_special_tokens=False),
        "host_responsibilities": [
            "tokenization",
            "token embedding lookup for generated tokens",
            "rotary position embedding generation",
            "full-attention KV cache insertion",
        ],
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(report, indent=2), flush=True)
|
|
|
|
| def main() -> None: |
| parser = argparse.ArgumentParser(description="Build and validate a faithful Surya CoreML OCR runtime scaffold.") |
| parser.add_argument("--model-id", default="datalab-to/surya-ocr-2") |
| parser.add_argument("--dtype", choices=["float32", "float16"], default="float32") |
| sub = parser.add_subparsers(dest="command", required=True) |
| check = sub.add_parser("check-decode-parity") |
| check.add_argument("--output", type=Path, required=True) |
| check.add_argument("--max-cache-length", type=int, default=512) |
| prefill_check = sub.add_parser("check-prefill-parity") |
| prefill_check.add_argument("--output", type=Path, required=True) |
| prefill_check.add_argument("--max-cache-length", type=int, default=512) |
| prefill_export = sub.add_parser("export-prefill") |
| prefill_export.add_argument("--output-dir", type=Path, required=True) |
| prefill_export.add_argument("--max-cache-length", type=int, default=512) |
| prefill_smoke = sub.add_parser("smoke-prefill") |
| prefill_smoke.add_argument("--package", type=Path, required=True) |
| prefill_smoke.add_argument("--output", type=Path, required=True) |
| prefill_smoke.add_argument("--max-cache-length", type=int, default=512) |
| export = sub.add_parser("export-decode-step") |
| export.add_argument("--output-dir", type=Path, required=True) |
| export.add_argument("--max-cache-length", type=int, default=512) |
| smoke = sub.add_parser("smoke-decode-step") |
| smoke.add_argument("--package", type=Path, required=True) |
| smoke.add_argument("--output", type=Path, required=True) |
| smoke.add_argument("--max-cache-length", type=int, default=512) |
| loop = sub.add_parser("iterative-decode-smoke") |
| loop.add_argument("--package", type=Path, required=True) |
| loop.add_argument("--output", type=Path, required=True) |
| loop.add_argument("--max-cache-length", type=int, default=512) |
| loop.add_argument("--steps", type=int, default=8) |
| combined = sub.add_parser("combined-runtime-smoke") |
| combined.add_argument("--prefill-package", type=Path, required=True) |
| combined.add_argument("--decode-package", type=Path, required=True) |
| combined.add_argument("--output", type=Path, required=True) |
| combined.add_argument("--max-cache-length", type=int, default=512) |
| combined.add_argument("--steps", type=int, default=8) |
| vision_combined = sub.add_parser("vision-combined-runtime-smoke") |
| vision_combined.add_argument("--vision-package", type=Path, required=True) |
| vision_combined.add_argument("--prefill-package", type=Path, required=True) |
| vision_combined.add_argument("--decode-package", type=Path, required=True) |
| vision_combined.add_argument("--output", type=Path, required=True) |
| vision_combined.add_argument("--max-cache-length", type=int, default=512) |
| vision_combined.add_argument("--steps", type=int, default=8) |
| args = parser.parse_args() |
| if args.command == "check-decode-parity": |
| parity_check(args) |
| elif args.command == "check-prefill-parity": |
| check_prefill_parity(args) |
| elif args.command == "export-prefill": |
| export_prefill(args) |
| elif args.command == "smoke-prefill": |
| smoke_prefill(args) |
| elif args.command == "export-decode-step": |
| export_decode_step(args) |
| elif args.command == "smoke-decode-step": |
| smoke_decode_step(args) |
| elif args.command == "iterative-decode-smoke": |
| iterative_decode_smoke(args) |
| elif args.command == "combined-runtime-smoke": |
| combined_runtime_smoke(args) |
| elif args.command == "vision-combined-runtime-smoke": |
| vision_combined_runtime_smoke(args) |
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# so importing the module does not trigger argument parsing.
if __name__ == "__main__":
    main()
|
|