surya-ocr-2-coreml-runtime / scripts /materialize_surya_native_assets.py
Reza2kn's picture
Add native Swift CoreML runtime and assets
d82e1f6 verified
Raw
History Blame Contribute Delete
5.98 kB
#!/usr/bin/env python3
import argparse
import importlib.util
import json
import sys
from pathlib import Path

import numpy as np
import torch
from transformers import AutoProcessor
def load_runtime_module(path: Path):
    """Import the Python source file at *path* as a standalone module and return it.

    The module is registered in ``sys.modules`` (under the fixed name
    ``surya_coreml_runtime_export``) *before* it is executed, following the
    ``importlib`` "import a source file directly" recipe. Without that, code in
    the loaded file that resolves its own module through ``__module__`` (for
    example ``dataclasses`` combined with postponed annotations) fails.

    Raises:
        ImportError: if no import spec/loader can be created for *path*
            (e.g. the file does not have a Python source suffix).
        Any error raised while reading or executing the file (such as
        ``FileNotFoundError``) propagates; the partially initialised module is
        removed from ``sys.modules`` first.
    """
    module_name = "surya_coreml_runtime_export"
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for runtime script {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(module_name, None)
        raise
    return module
def write_fp16(path: Path, tensor: torch.Tensor) -> None:
    """Write *tensor* to *path* as raw little-endian float16 values in C (row-major) order.

    No header or shape information is stored; the reader must know the shape.
    The tensor is detached and moved to CPU first, so tensors that require
    grad or live on another device are accepted. Parent directories are
    created as needed, and an existing file at *path* is overwritten.
    """
    array = tensor.detach().cpu().to(torch.float16).numpy()
    # ``tobytes`` emits the array's *native* byte order. Pin little-endian so
    # the on-disk format is the same on every host (a no-op on little-endian
    # machines).
    array = array.astype("<f2", copy=False)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(array.tobytes(order="C"))
def _image_token_indices(image_mask: torch.Tensor) -> torch.Tensor:
    """Return the sequence positions flagged in a batch-1 image placeholder mask.

    The mask may be shaped ``(1, seq)``, ``(1, seq, 1)`` or broadcast to the
    embedding width ``(1, seq, hidden)``. Any trailing dimensions are collapsed
    with ``any`` so exactly one index is produced per flagged token position.
    """
    per_token = image_mask[0]
    if per_token.dim() > 1:
        per_token = per_token.reshape(per_token.shape[0], -1).any(dim=-1)
    return per_token.nonzero().flatten().to(torch.long)


def _id_to_token_table(vocab: dict[str, int]) -> list[str]:
    """Invert a token->id vocabulary into a dense list indexed by token id.

    Ids that no token maps to are filled with "" so list index == token id.
    An empty vocabulary yields an empty list.
    """
    if not vocab:
        return []
    table = [""] * (max(vocab.values()) + 1)
    for token, idx in vocab.items():
        table[idx] = token
    return table


def main() -> None:
    """Export the host-side constant tensors used by the native CoreML runtime.

    Loads the model through the helper module given by ``--runtime-script``
    (which must provide ``load_model``, ``build_sample`` and ``PROMPT``). It
    then runs that module's sample through the embedding, vision and rotary
    stages in PyTorch and writes the results into ``--output-dir`` as:

    * raw fp16 ``.bin`` files (see ``write_fp16``),
    * ``surya_native_constants.json`` (token ids, shapes, file names),
    * ``id_to_token.json`` (dense id -> token string table).

    Finally it prints the output directory and its file listing as JSON.
    """
    parser = argparse.ArgumentParser(description="Materialize native runtime assets for Surya OCR.")
    parser.add_argument("--model-id", default="datalab-to/surya-ocr-2", help="Hugging Face model id to export.")
    parser.add_argument(
        "--runtime-script",
        type=Path,
        default=Path(".context/scripts/export_surya_coreml_runtime.py"),
        help="Python file providing load_model(), build_sample() and PROMPT.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("artifacts/coreml/surya-ocr-2-coreml-8bit/native_assets"),
        help="Directory that receives the .bin and .json assets.",
    )
    parser.add_argument(
        "--max-cache-length",
        type=int,
        default=512,
        help="Number of decode positions to precompute rotary cos/sin tables for.",
    )
    args = parser.parse_args()
    if args.max_cache_length <= 0:
        # Otherwise torch.cat on an empty list fails later with an opaque error.
        parser.error("--max-cache-length must be a positive integer")

    rt = load_runtime_module(args.runtime_script)
    output_dir = args.output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = rt.load_model(args.model_id, torch.float32)
    sample = rt.build_sample(processor)
    rotary_emb = model.model.language_model.rotary_emb

    with torch.no_grad():
        input_ids = sample["input_ids"].to(torch.long)
        attention_mask = sample["attention_mask"].to(torch.long)
        mm_token_type_ids = sample["mm_token_type_ids"].to(torch.long)
        # Feed pixels in the model's own parameter dtype.
        pixel_values = sample["pixel_values"].to(next(model.parameters()).dtype)
        image_grid_thw = sample["image_grid_thw"].to(torch.long)

        text_embeds = model.model.get_input_embeddings()(input_ids)
        image_outputs = model.model.get_image_features(pixel_values, image_grid_thw, return_dict=True)
        image_embeds = torch.cat(image_outputs.pooler_output, dim=0).to(text_embeds.dtype)
        image_mask, _ = model.model.get_placeholder_mask(
            input_ids, inputs_embeds=text_embeds, image_features=image_embeds
        )
        image_token_indices = _image_token_indices(image_mask)
        if image_token_indices.numel() != image_embeds.shape[0]:
            raise ValueError(
                f"sample has {image_token_indices.numel()} image placeholder tokens "
                f"but the vision tower produced {image_embeds.shape[0]} embeddings"
            )

        # Text embeddings with the image slots zeroed out. NOTE(review):
        # presumably the native runtime writes its own vision output into
        # these rows; confirm against the Swift side.
        prefill_base = text_embeds.clone()
        prefill_base[:, image_token_indices, :] = 0

        # Reference merged sequence; used below only to derive position ids
        # and the prefill rotary tables.
        merged_embeds = text_embeds.clone()
        merged_embeds[:, image_token_indices, :] = image_embeds.reshape(1, image_embeds.shape[0], image_embeds.shape[1])

        position_ids = model.model.compute_3d_position_ids(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=None,
            inputs_embeds=merged_embeds,
            attention_mask=attention_mask,
            past_key_values=None,
            mm_token_type_ids=mm_token_type_ids,
        )
        prefill_cos, prefill_sin = rotary_emb(merged_embeds, position_ids)

        # Decode-time tables: one rotary call per cache slot. Each call uses
        # position ids of shape (3, 1, 1) with all three leading entries set
        # to ``pos`` (presumably the multimodal RoPE axes, which coincide for
        # pure text tokens; TODO confirm). Results are stacked along dim 1.
        one_embed = torch.zeros((1, 1, merged_embeds.shape[-1]), dtype=merged_embeds.dtype)
        step_cos, step_sin = [], []
        for pos in range(args.max_cache_length):
            step_position_ids = torch.full((3, 1, 1), pos, dtype=torch.long)
            cos, sin = rotary_emb(one_embed, step_position_ids)
            step_cos.append(cos)
            step_sin.append(sin)
        step_cos = torch.cat(step_cos, dim=1)
        step_sin = torch.cat(step_sin, dim=1)

        token_embedding = model.model.get_input_embeddings().weight.detach().clone()

    write_fp16(output_dir / "prefill_text_embeds_base_fp16.bin", prefill_base)
    write_fp16(output_dir / "prefill_cos_fp16.bin", prefill_cos)
    write_fp16(output_dir / "prefill_sin_fp16.bin", prefill_sin)
    write_fp16(output_dir / "decode_cos_fp16.bin", step_cos)
    write_fp16(output_dir / "decode_sin_fp16.bin", step_sin)
    write_fp16(output_dir / "token_embedding_fp16.bin", token_embedding)
    write_fp16(output_dir / "canary_pixel_values_fp16.bin", pixel_values)

    constants = {
        "model_id": args.model_id,
        "prompt": rt.PROMPT,
        "input_ids": input_ids.squeeze(0).tolist(),
        "attention_mask": attention_mask.squeeze(0).tolist(),
        "mm_token_type_ids": mm_token_type_ids.squeeze(0).tolist(),
        "image_grid_thw": image_grid_thw.squeeze(0).tolist(),
        "image_token_indices": image_token_indices.tolist(),
        "shapes": {
            "prefill_text_embeds_base": list(prefill_base.shape),
            "prefill_cos": list(prefill_cos.shape),
            "prefill_sin": list(prefill_sin.shape),
            "decode_cos": list(step_cos.shape),
            "decode_sin": list(step_sin.shape),
            "token_embedding": list(token_embedding.shape),
            "pixel_values": list(pixel_values.shape),
            "image_embeds": list(image_embeds.shape),
        },
        "dtype": "float16 little-endian raw binaries",
        "files": {
            "prefill_text_embeds_base": "prefill_text_embeds_base_fp16.bin",
            "prefill_cos": "prefill_cos_fp16.bin",
            "prefill_sin": "prefill_sin_fp16.bin",
            "decode_cos": "decode_cos_fp16.bin",
            "decode_sin": "decode_sin_fp16.bin",
            "token_embedding": "token_embedding_fp16.bin",
            "canary_pixel_values": "canary_pixel_values_fp16.bin",
        },
    }
    (output_dir / "surya_native_constants.json").write_text(json.dumps(constants, indent=2) + "\n", encoding="utf-8")

    id_to_token = _id_to_token_table(processor.tokenizer.get_vocab())
    (output_dir / "id_to_token.json").write_text(json.dumps(id_to_token, ensure_ascii=False) + "\n", encoding="utf-8")

    print(json.dumps({"output_dir": str(output_dir), "files": sorted(p.name for p in output_dir.iterdir())}, indent=2))


if __name__ == "__main__":
    main()