#!/usr/bin/env python3
"""Export raw float16 tensors and JSON metadata derived from a Surya OCR model.

The script loads a helper module (``--runtime-script``) that is expected to
provide ``load_model``, ``build_sample`` and ``PROMPT``, runs one sample through
the PyTorch model, and writes the following into ``--output-dir``:

* ``*_fp16.bin`` -- headerless little-endian float16 dumps (prefill text
  embeddings with image slots zeroed, prefill/decode rotary cos/sin tables,
  the token embedding matrix and the sample's pixel values);
* ``surya_native_constants.json`` -- prompt, token ids, masks, tensor shapes
  and the file names above;
* ``id_to_token.json`` -- a list mapping token id to token string.
"""
import argparse
import importlib.util
import json
from pathlib import Path

import numpy as np
import torch

# NOTE: ``transformers.AutoProcessor`` is imported inside ``main`` (its only
# user) so that ``--help`` and the helper functions below do not pay for, or
# depend on, importing ``transformers``.


def load_runtime_module(path: Path):
    """Import the Python file at *path* and return the resulting module object.

    The module is executed under the fixed name
    ``surya_coreml_runtime_export``; it is not registered in ``sys.modules``.

    Raises:
        ImportError: if no import spec/loader can be built for *path*
            (for example, the file does not have a Python source suffix).
        OSError: propagated from the loader if the file cannot be read.
    """
    spec = importlib.util.spec_from_file_location("surya_coreml_runtime_export", path)
    # spec_from_file_location returns None when it cannot pick a loader; fail
    # with a clear message instead of an AttributeError further down.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load runtime module from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def write_fp16(path: Path, tensor: torch.Tensor) -> None:
    """Write *tensor* to *path* as a headerless, C-ordered float16 dump.

    The tensor is detached, moved to the CPU and cast to float16. Bytes are
    always little-endian regardless of the host, matching the ``dtype`` note
    stored in ``surya_native_constants.json``. Parent directories are created
    as needed.
    """
    array = tensor.detach().cpu().to(torch.float16).numpy()
    # "<f2" pins the on-disk byte order; on little-endian hosts this is the
    # native float16 layout, so no conversion is performed for contiguous data.
    array = np.ascontiguousarray(array, dtype="<f2")
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(array.tobytes(order="C"))


def _image_token_indices(image_mask: torch.Tensor) -> torch.Tensor:
    """Return the 1-D ``long`` tensor of sequence positions flagged in *image_mask*.

    Accepts a mask shaped ``(seq,)``, ``(1, seq)`` or ``(1, seq, k)``. For the
    3-D form a position counts as flagged when any element along the last axis
    is set, so a mask broadcast across the hidden dimension and a ``(1, seq, 1)``
    mask give the same answer.

    Raises:
        ValueError: for any other rank, or a batch dimension other than 1.
    """
    mask = image_mask
    if mask.dim() == 3:
        mask = mask.any(dim=-1)
    if mask.dim() == 2:
        if mask.shape[0] != 1:
            raise ValueError(
                f"Expected a batch of exactly one sample, got mask shape {tuple(image_mask.shape)}"
            )
        mask = mask[0]
    if mask.dim() != 1:
        raise ValueError(f"Unsupported image mask shape {tuple(image_mask.shape)}")
    return mask.nonzero().flatten().to(torch.long)


def _decode_rotary_tables(rotary_emb, embed_dim, dtype, device, max_cache_length):
    """Evaluate *rotary_emb* once per position and concatenate along dim 1.

    For each ``pos`` in ``range(max_cache_length)`` the callable is invoked
    with a zero ``(1, 1, embed_dim)`` tensor and ``(3, 1, 1)`` position ids all
    equal to ``pos``. Returns the ``(cos, sin)`` pair of concatenated results.
    """
    one_embed = torch.zeros((1, 1, embed_dim), dtype=dtype, device=device)
    cos_steps = []
    sin_steps = []
    for pos in range(max_cache_length):
        step_position_ids = torch.full((3, 1, 1), pos, dtype=torch.long, device=device)
        cos, sin = rotary_emb(one_embed, step_position_ids)
        cos_steps.append(cos)
        sin_steps.append(sin)
    return torch.cat(cos_steps, dim=1), torch.cat(sin_steps, dim=1)


def _id_to_token_table(vocab):
    """Invert a ``token -> id`` mapping into a list indexed by id.

    Ids that no token maps to are left as empty strings; an empty vocabulary
    yields an empty list.
    """
    table = [""] * (max(vocab.values(), default=-1) + 1)
    for token, idx in vocab.items():
        table[idx] = token
    return table


def main() -> None:
    """Parse CLI arguments, compute the tensors and write all output files."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", default="datalab-to/surya-ocr-2")
    parser.add_argument(
        "--runtime-script",
        type=Path,
        default=Path(".context/scripts/export_surya_coreml_runtime.py"),
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("artifacts/coreml/surya-ocr-2-coreml-8bit/native_assets"),
    )
    parser.add_argument("--max-cache-length", type=int, default=512)
    args = parser.parse_args()
    # An empty decode table cannot be concatenated; reject it up front.
    if args.max_cache_length < 1:
        parser.error("--max-cache-length must be a positive integer")

    # Deferred: only needed once we actually run the export.
    from transformers import AutoProcessor

    rt = load_runtime_module(args.runtime_script)
    output_dir = args.output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = rt.load_model(args.model_id, torch.float32)
    sample = rt.build_sample(processor)

    with torch.no_grad():
        input_ids = sample["input_ids"].to(torch.long)
        attention_mask = sample["attention_mask"].to(torch.long)
        mm_token_type_ids = sample["mm_token_type_ids"].to(torch.long)
        # Pixel values follow the dtype of the model's parameters.
        pixel_values = sample["pixel_values"].to(next(model.parameters()).dtype)
        image_grid_thw = sample["image_grid_thw"].to(torch.long)

        text_embeds = model.model.get_input_embeddings()(input_ids)
        image_outputs = model.model.get_image_features(pixel_values, image_grid_thw, return_dict=True)
        image_embeds = torch.cat(image_outputs.pooler_output, dim=0).to(text_embeds.dtype)
        image_mask, _ = model.model.get_placeholder_mask(
            input_ids,
            inputs_embeds=text_embeds,
            image_features=image_embeds,
        )
        image_token_indices = _image_token_indices(image_mask)

        # Text-only embeddings with the image placeholder rows zeroed out.
        prefill_base = text_embeds.clone()
        prefill_base[:, image_token_indices, :] = 0

        # Full embeddings with image features scattered into the placeholders;
        # used only to derive position ids and the prefill rotary tables.
        merged_embeds = text_embeds.clone()
        merged_embeds[:, image_token_indices, :] = image_embeds.reshape(
            1, image_embeds.shape[0], image_embeds.shape[1]
        )

        position_ids = model.model.compute_3d_position_ids(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=None,
            inputs_embeds=merged_embeds,
            attention_mask=attention_mask,
            past_key_values=None,
            mm_token_type_ids=mm_token_type_ids,
        )
        rotary_emb = model.model.language_model.rotary_emb
        prefill_cos, prefill_sin = rotary_emb(merged_embeds, position_ids)

        step_cos, step_sin = _decode_rotary_tables(
            rotary_emb,
            merged_embeds.shape[-1],
            merged_embeds.dtype,
            merged_embeds.device,
            args.max_cache_length,
        )

        token_embedding = model.model.get_input_embeddings().weight.detach().clone()

    write_fp16(output_dir / "prefill_text_embeds_base_fp16.bin", prefill_base)
    write_fp16(output_dir / "prefill_cos_fp16.bin", prefill_cos)
    write_fp16(output_dir / "prefill_sin_fp16.bin", prefill_sin)
    write_fp16(output_dir / "decode_cos_fp16.bin", step_cos)
    write_fp16(output_dir / "decode_sin_fp16.bin", step_sin)
    write_fp16(output_dir / "token_embedding_fp16.bin", token_embedding)
    write_fp16(output_dir / "canary_pixel_values_fp16.bin", pixel_values)

    constants = {
        "model_id": args.model_id,
        "prompt": rt.PROMPT,
        "input_ids": input_ids.squeeze(0).tolist(),
        "attention_mask": attention_mask.squeeze(0).tolist(),
        "mm_token_type_ids": mm_token_type_ids.squeeze(0).tolist(),
        "image_grid_thw": image_grid_thw.squeeze(0).tolist(),
        "image_token_indices": image_token_indices.tolist(),
        "shapes": {
            "prefill_text_embeds_base": list(prefill_base.shape),
            "prefill_cos": list(prefill_cos.shape),
            "prefill_sin": list(prefill_sin.shape),
            "decode_cos": list(step_cos.shape),
            "decode_sin": list(step_sin.shape),
            "token_embedding": list(token_embedding.shape),
            "pixel_values": list(pixel_values.shape),
            "image_embeds": list(image_embeds.shape),
        },
        "dtype": "float16 little-endian raw binaries",
        "files": {
            "prefill_text_embeds_base": "prefill_text_embeds_base_fp16.bin",
            "prefill_cos": "prefill_cos_fp16.bin",
            "prefill_sin": "prefill_sin_fp16.bin",
            "decode_cos": "decode_cos_fp16.bin",
            "decode_sin": "decode_sin_fp16.bin",
            "token_embedding": "token_embedding_fp16.bin",
            "canary_pixel_values": "canary_pixel_values_fp16.bin",
        },
    }
    (output_dir / "surya_native_constants.json").write_text(
        json.dumps(constants, indent=2) + "\n", encoding="utf-8"
    )

    id_to_token = _id_to_token_table(processor.tokenizer.get_vocab())
    (output_dir / "id_to_token.json").write_text(
        json.dumps(id_to_token, ensure_ascii=False) + "\n", encoding="utf-8"
    )

    # Summarize what now exists in the output directory.
    print(
        json.dumps(
            {"output_dir": str(output_dir), "files": sorted(p.name for p in output_dir.iterdir())},
            indent=2,
        )
    )


if __name__ == "__main__":
    main()