import os
import torch
import onnx
import argparse
from PIL import Image
from torch.onnx._globals import GLOBALS
from transformers import ColPaliForRetrieval, ColPaliProcessor
from optimum.onnx.graph_transformations import check_and_save_model
import onnx_graphsurgeon as gs
from onnxconverter_common import float16
from onnx.external_data_helper import convert_model_to_external_data

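# Strategy: ColPaliForRetrieval exposes a single forward() covering both
# modalities. Each export below temporarily monkey-patches forward() so the
# traced graph follows only one path (vision or text), with zero-filled
# dummies standing in for the unused inputs.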
def export_model(
    model_id: str,
    output_dir: str,
    device: str,
    fp16: bool = False,
    export_type: str = "both",
):
    """Export ColPaliForRetrieval to ONNX: the vision path, the text path, or both."""
    os.makedirs(output_dir, exist_ok=True)

    # device_map="auto" (accelerate dispatch) conflicts with a later .to(device),
    # so load normally and move the model explicitly.
    model = (
        ColPaliForRetrieval.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if fp16 else torch.float32,
        )
        .to(device)
        .eval()
    )
    processor = ColPaliProcessor.from_pretrained(model_id)
    model.config.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Keep a handle to the original forward() so the per-modality wrappers
    # below can delegate to it.
    _orig_forward = model.forward

    # Build real inputs from a tiny dummy image (the ColPali processor emits
    # pixel_values plus the accompanying prompt input_ids/attention_mask),
    # and zero-filled stand-ins for whichever modality is being ignored.
    dummy_img = Image.new("RGB", (32, 32), color="white")
    vision_pt = processor(images=[dummy_img], return_tensors="pt").to(device)
    pv, ids, msk = (
        vision_pt["pixel_values"],
        vision_pt["input_ids"],
        vision_pt["attention_mask"],
    )
    fake_ids = torch.zeros((pv.size(0), 1), device=device, dtype=torch.long)
    fake_mask = torch.zeros_like(fake_ids)
    fake_pv = torch.zeros_like(pv)

    out_paths = {}

    # ---- vision export ----
    if export_type in ("vision", "both"):

        def vision_forward(
            self, pixel_values=None, input_ids=None, attention_mask=None, **kw
        ):
            # Trace only the image path; the text inputs are dropped.
            return _orig_forward(
                pixel_values=pixel_values,
                input_ids=None,
                attention_mask=None,
                **kw,
            ).embeddings

        model.forward = vision_forward.__get__(model, model.__class__)

        vision_onnx = os.path.join(output_dir, "model_vision.onnx")
        # Skip the exporter's built-in ONNX shape inference; it is re-run
        # explicitly on the saved file below.
        GLOBALS.onnx_shape_inference = False
        torch.onnx.export(
            model,
            (pv, fake_ids, fake_mask),
            vision_onnx,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=["pixel_values", "input_ids", "attention_mask"],
            output_names=["embeddings"],
            dynamic_axes={
                "pixel_values": {0: "batch_size"},
                "embeddings": {0: "batch_size", 1: "seq_len"},
            },
        )
        print("✅ Exported VISION ONNX to", vision_onnx)

        # infer_shapes_path works on the file in place and returns None;
        # reload and re-save through the checker, which also re-externalizes
        # large weights alongside the .onnx file.
        onnx.shape_inference.infer_shapes_path(vision_onnx)
        m = onnx.load(vision_onnx, load_external_data=True)
        check_and_save_model(m, vision_onnx)
        print("  (shape-inferred + external-data fixed)")

        out_paths["vision"] = vision_onnx

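        # Quick sanity check (illustrative; assumes onnxruntime is installed).
        # The traced graph may prune inputs the wrapper ignores, so feed only
        # the names the session reports:
        #   import onnxruntime as ort
        #   sess = ort.InferenceSession(vision_onnx)
        #   tensors = {"pixel_values": pv, "input_ids": fake_ids, "attention_mask": fake_mask}
        #   feeds = {i.name: tensors[i.name].cpu().numpy() for i in sess.get_inputs()}
        #   print(sess.run(["embeddings"], feeds)[0].shape)
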
    # ---- text export ----
    if export_type in ("text", "both"):

        def text_forward(
            self, pixel_values=None, input_ids=None, attention_mask=None, **kw
        ):
            # Trace only the text path; the image inputs are dropped.
            return _orig_forward(
                pixel_values=None,
                input_ids=input_ids,
                attention_mask=attention_mask,
                **kw,
            ).embeddings

        model.forward = text_forward.__get__(model, model.__class__)

        text_onnx = os.path.join(output_dir, "model_text.onnx")
        torch.onnx.export(
            model,
            (fake_pv, ids, msk),
            text_onnx,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=["pixel_values", "input_ids", "attention_mask"],
            output_names=["embeddings"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "seq_len"},
                "attention_mask": {0: "batch_size", 1: "seq_len"},
                "embeddings": {0: "batch_size", 1: "seq_len"},
            },
        )
        print("✅ Exported TEXT ONNX to", text_onnx)

        onnx.shape_inference.infer_shapes_path(text_onnx)
        m = onnx.load(text_onnx, load_external_data=True)
        check_and_save_model(m, text_onnx)
        print("  (shape-inferred + external-data fixed)")

        out_paths["text"] = text_onnx

    print("🎉 Done exporting model(s):", out_paths)
    return out_paths


def quantize_fp16_and_externalize(
    input_path,
    output_path,
    external_data_filename="model.onnx_data",
    op_block_list=None,
):
    """
    Quantize an ONNX model from FP32 to FP16:
    1) Load the FP32 ONNX (+ its .onnx_data)
    2) Cast weight tensors to FP16
    3) Topo-sort / clean up the graph
    4) Copy opset_import from the original model
    5) Mark ALL tensors for external data
    6) Save the new ONNX + .onnx_data
    """
    # Load two copies: the conversion works on `model`, while the untouched
    # `orig` is kept to restore the original opset imports afterwards.
    orig = onnx.load(input_path, load_external_data=True)
    model = onnx.load(input_path, load_external_data=True)

    # Shape inference cannot serialize models above the 2 GB protobuf limit,
    # so disable it for them during conversion.
    disable_si = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF
    blocked = set(float16.DEFAULT_OP_BLOCK_LIST)
    if op_block_list:
        blocked.update(op_block_list)
    # Keep numerically sensitive ops in FP32.
    blocked.update(["LayerNormalization", "Softmax", "Div"])

    model_fp16 = float16.convert_float_to_float16(
        model,
        max_finite_val=65504.0,  # largest finite FP16 value
        keep_io_types=True,
        disable_shape_infer=disable_si,
        op_block_list=list(blocked),
    )

    # Cast insertion can leave nodes out of topological order; sort via
    # graphsurgeon and re-export.
    graph = gs.import_onnx(model_fp16)
    graph.toposort()
    model_fp16 = gs.export_onnx(graph)

    # graphsurgeon does not preserve all opset imports; restore them from the
    # originally loaded model.
    model_fp16.ClearField("opset_import")
    model_fp16.opset_import.extend(orig.opset_import)

    # Mark every tensor for external data so the FP16 weights land in one file.
    convert_model_to_external_data(
        model_fp16,
        all_tensors_to_one_file=True,
        location=external_data_filename,
        size_threshold=0,
    )

    # Safety net: make sure at least the default opset is declared.
    if not model_fp16.opset_import:
        model_fp16.opset_import.extend(
            [
                onnx.helper.make_opsetid("", 14),
            ]
        )

    check_and_save_model(model_fp16, output_path)

| | print("✅ FP16 model quantized and saved:") |
| | print(f" ONNX: {output_path}") |
| | print( |
| | f" DATA: {os.path.join(os.path.dirname(output_path), external_data_filename)}" |
| | ) |
| | return True |
| |
|
| |
|
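# Example call (illustrative paths, mirroring what main() does below):
#   quantize_fp16_and_externalize(
#       "output/vidore_colpali-v1.3-hf/model_vision.onnx",
#       "output/vidore_colpali-v1.3-hf/model_vision.onnx",
#       "model_vision.onnx_data",
#   )

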
def main():
    parser = argparse.ArgumentParser(
        description="Convert a ColPali model to ONNX, optionally with FP16 quantization"
    )
    parser.add_argument(
        "--model-id", default="vidore/colpali-v1.3-hf", help="HuggingFace model ID"
    )
    parser.add_argument("--output-dir", default=None, help="Output directory")
    parser.add_argument(
        "--quantize", action="store_true", help="Apply FP16 quantization after export"
    )
    parser.add_argument(
        "--export-type",
        choices=["vision", "text", "both"],
        default="both",
        help="Which ONNX graph(s) to export",
    )
    parser.add_argument("--device", default=None, help="Device for model (cuda/cpu)")
    args = parser.parse_args()

    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.output_dir is None:
        args.output_dir = os.path.join("output", args.model_id.replace("/", "_"))

    # Always export in FP32; FP16 is handled by the post-export quantization pass.
    out_paths = export_model(
        args.model_id,
        args.output_dir,
        args.device,
        fp16=False,
        export_type=args.export_type,
    )

    if args.quantize:
        print("Starting FP16 quantization")
        for path in out_paths.values():
            binname = os.path.basename(path).replace(".onnx", ".onnx_data")
            quantize_fp16_and_externalize(path, path, binname)


if __name__ == "__main__":
    main()
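
# Example CLI usage (the script filename is illustrative):
#   python export_colpali.py --model-id vidore/colpali-v1.3-hf --export-type both --quantize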