| |
| import argparse |
| import importlib.util |
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from transformers import AutoProcessor |
|
|
|
|
def load_runtime_module(path: Path):
    """Execute the Python file at *path* and return it as a module object.

    The module is created under the fixed name
    ``surya_coreml_runtime_export``; this function does not register it in
    ``sys.modules``.

    Args:
        path: Location of the Python source file to execute.

    Returns:
        The freshly executed module.

    Raises:
        ImportError: If no import spec (or no loader) can be built for
            *path*, e.g. because its suffix is not a recognised Python one.
    """
    spec = importlib.util.spec_from_file_location("surya_coreml_runtime_export", path)
    # spec_from_file_location returns None when no loader matches the file;
    # fail with a clear message instead of an AttributeError further down.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for runtime script: {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
|
def write_fp16(path: Path, tensor: torch.Tensor) -> None:
    """Write *tensor* to *path* as a raw little-endian float16 blob.

    The tensor is detached, moved to the CPU and cast to ``float16``; its
    elements are then written in C (row-major) order with no header. Missing
    parent directories of *path* are created.

    Args:
        path: Destination file; overwritten if it already exists.
        tensor: Tensor whose values are serialised.
    """
    array = tensor.detach().cpu().to(torch.float16).numpy()
    # Pin the byte order explicitly so the file layout does not depend on the
    # host's endianness (this is a no-op on little-endian machines).
    array = array.astype("<f2", copy=False)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(array.tobytes(order="C"))
|
|
|
|
def main() -> None:
    """Export precomputed tensors and metadata for one sample into a directory.

    Command-line options:
        --model-id: identifier passed to ``AutoProcessor.from_pretrained`` and
            to the runtime script's ``load_model``.
        --runtime-script: Python file providing ``load_model``,
            ``build_sample`` and ``PROMPT``.
        --output-dir: directory that receives all generated files.
        --max-cache-length: number of single-position cos/sin steps to
            precompute (must be at least 1).

    Files written to the output directory:
        * seven ``*_fp16.bin`` raw tensors (see the ``binaries`` table below),
        * ``surya_native_constants.json`` with ids, masks, shapes and file names,
        * ``id_to_token.json`` with the tokenizer vocabulary indexed by id.

    A JSON summary listing the output directory contents is printed last.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", default="datalab-to/surya-ocr-2")
    parser.add_argument(
        "--runtime-script",
        type=Path,
        default=Path(".context/scripts/export_surya_coreml_runtime.py"),
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("artifacts/coreml/surya-ocr-2-coreml-8bit/native_assets"),
    )
    parser.add_argument("--max-cache-length", type=int, default=512)
    args = parser.parse_args()
    # Reject an empty table before any expensive model loading: torch.cat on
    # an empty list would otherwise fail much later with an opaque error.
    if args.max_cache_length < 1:
        parser.error("--max-cache-length must be a positive integer")

    rt = load_runtime_module(args.runtime_script)
    output_dir = args.output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model = rt.load_model(args.model_id, torch.float32)
    sample = rt.build_sample(processor)

    core = model.model
    rotary_emb = core.language_model.rotary_emb

    with torch.no_grad():
        input_ids = sample["input_ids"].to(torch.long)
        attention_mask = sample["attention_mask"].to(torch.long)
        mm_token_type_ids = sample["mm_token_type_ids"].to(torch.long)
        # Pixel values follow whatever dtype the model parameters use.
        pixel_values = sample["pixel_values"].to(next(model.parameters()).dtype)
        image_grid_thw = sample["image_grid_thw"].to(torch.long)

        text_embeds = core.get_input_embeddings()(input_ids)
        image_outputs = core.get_image_features(pixel_values, image_grid_thw, return_dict=True)
        image_embeds = torch.cat(image_outputs.pooler_output, dim=0).to(text_embeds.dtype)
        image_mask, _ = core.get_placeholder_mask(
            input_ids,
            inputs_embeds=text_embeds,
            image_features=image_embeds,
        )
        # NOTE(review): assumes the mask reduces to one flag per sequence
        # position after the two squeezes - confirm against the model's
        # get_placeholder_mask implementation.
        image_token_indices = image_mask.squeeze(0).squeeze(-1).nonzero().flatten().to(torch.long)

        # Text embeddings with the image placeholder positions zeroed out.
        prefill_base = text_embeds.clone()
        prefill_base[:, image_token_indices, :] = 0

        # Text embeddings with the image features written into those positions;
        # used here only to derive position ids and the prefill cos/sin.
        merged_embeds = text_embeds.clone()
        merged_embeds[:, image_token_indices, :] = image_embeds.reshape(
            1, image_embeds.shape[0], image_embeds.shape[1]
        )
        position_ids = core.compute_3d_position_ids(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=None,
            inputs_embeds=merged_embeds,
            attention_mask=attention_mask,
            past_key_values=None,
            mm_token_type_ids=mm_token_type_ids,
        )
        prefill_cos, prefill_sin = rotary_emb(merged_embeds, position_ids)

        # One cos/sin pair per position 0..max_cache_length-1, each computed
        # with all three position-id rows set to that position, then joined
        # along dim 1.
        one_embed = torch.zeros((1, 1, merged_embeds.shape[-1]), dtype=merged_embeds.dtype)
        cos_steps = []
        sin_steps = []
        for pos in range(args.max_cache_length):
            step_position_ids = torch.full((3, 1, 1), pos, dtype=torch.long)
            cos, sin = rotary_emb(one_embed, step_position_ids)
            cos_steps.append(cos)
            sin_steps.append(sin)
        step_cos = torch.cat(cos_steps, dim=1)
        step_sin = torch.cat(sin_steps, dim=1)

        token_embedding = core.get_input_embeddings().weight.detach().clone()

    # Manifest key -> (file name, tensor). Insertion order fixes both the write
    # order and the order of the "files" mapping in the manifest.
    binaries = {
        "prefill_text_embeds_base": ("prefill_text_embeds_base_fp16.bin", prefill_base),
        "prefill_cos": ("prefill_cos_fp16.bin", prefill_cos),
        "prefill_sin": ("prefill_sin_fp16.bin", prefill_sin),
        "decode_cos": ("decode_cos_fp16.bin", step_cos),
        "decode_sin": ("decode_sin_fp16.bin", step_sin),
        "token_embedding": ("token_embedding_fp16.bin", token_embedding),
        "canary_pixel_values": ("canary_pixel_values_fp16.bin", pixel_values),
    }
    for filename, tensor in binaries.values():
        write_fp16(output_dir / filename, tensor)

    constants = {
        "model_id": args.model_id,
        "prompt": rt.PROMPT,
        "input_ids": input_ids.squeeze(0).tolist(),
        "attention_mask": attention_mask.squeeze(0).tolist(),
        "mm_token_type_ids": mm_token_type_ids.squeeze(0).tolist(),
        "image_grid_thw": image_grid_thw.squeeze(0).tolist(),
        "image_token_indices": image_token_indices.tolist(),
        "shapes": {
            "prefill_text_embeds_base": list(prefill_base.shape),
            "prefill_cos": list(prefill_cos.shape),
            "prefill_sin": list(prefill_sin.shape),
            "decode_cos": list(step_cos.shape),
            "decode_sin": list(step_sin.shape),
            "token_embedding": list(token_embedding.shape),
            "pixel_values": list(pixel_values.shape),
            "image_embeds": list(image_embeds.shape),
        },
        "dtype": "float16 little-endian raw binaries",
        "files": {key: filename for key, (filename, _) in binaries.items()},
    }
    (output_dir / "surya_native_constants.json").write_text(
        json.dumps(constants, indent=2) + "\n",
        encoding="utf-8",
    )

    # Dense id -> token table; ids absent from the vocabulary stay "".
    vocab = processor.tokenizer.get_vocab()
    id_to_token = [""] * (max(vocab.values(), default=-1) + 1)
    for token, idx in vocab.items():
        id_to_token[idx] = token
    (output_dir / "id_to_token.json").write_text(
        json.dumps(id_to_token, ensure_ascii=False) + "\n",
        encoding="utf-8",
    )

    summary = {
        "output_dir": str(output_dir),
        "files": sorted(p.name for p in output_dir.iterdir()),
    }
    print(json.dumps(summary, indent=2))
|
|
|
|
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|