Visual Document Retrieval
Transformers
ONNX
ColPali
English
pretraining
kitsuneb committed on
Commit 1221e23 · 1 Parent(s): 33206ff

small change for text

Files changed (3)
  1. README.md +2 -0
  2. conversion/README.md +12 -3
  3. conversion/convert.py +110 -70
README.md CHANGED
@@ -19,6 +19,8 @@ pipeline_tag: visual-document-retrieval
 
 # ColPali: Visual Retriever based on PaliGemma-3B with ColBERT strategy
 
+> Please read `conversion/README.md` for details about the conversion process and notes.
+
 ColPali is a model based on a novel model architecture and training strategy based on Vision Language Models (VLMs) to efficiently index documents from their visual features.
 It is a [PaliGemma-3B](https://huggingface.co/google/paligemma-3b-mix-448) extension that generates [ColBERT](https://arxiv.org/abs/2004.12832)-style multi-vector representations of text and images.
 It was introduced in the paper [ColPali: Efficient Document Retrieval with Vision Language Models](https://arxiv.org/abs/2407.01449) and first released in [this repository](https://github.com/ManuelFay/colpali).
conversion/README.md CHANGED
@@ -1,6 +1,15 @@
 # ONNX Model Conversion Notes
 
 First of all, this was rather fun to do!
+
+It may not have been obvious that the convert script originally covered only image inputs, and that the process for text inputs is almost exactly the same, so I have extended it slightly and left this note. This matters because the intended use of ColPali is to run image embedding offline (when preparing your vector DB), while the text model is intended for online (query) time.
+
+The text export is now included in the `convert.py` script as well; it is not much of a change. I have not uploaded the text model files, since the conversion process is exactly the same as for the vision model (so the results will match) and uploading takes a long time on my home WiFi, unfortunately.
+
+I opted for two models. In theory you could split the image and text inputs into several graphs and call them in the correct order, since they share (some) weights for each input type. However, given the intended offline/online use of ColPali, this is not necessary and probably overkill for this exercise.
+
+
+## Some practical notes
 The `convert.py` script is based on code I made on Google Colab in order to have access to a GPU.
 The `requirements.txt` might not be perfect; I'd much rather use UV, which I use on a daily basis, but this was created in Google Colab in a hurry.
 
@@ -9,7 +18,7 @@ Also note that I checked the output of the converted models and the original to
 - The fp32 (default ONNX) is nearly the same as the original HF model.
 - However, the FP16 converted ONNX model is not exactly the same; there is a margin of error.
 
-Below is a code snippet that showcases the comparison:
+Below is a code snippet that showcases the comparison for image input:
 
 ```python
 import torch
@@ -25,8 +34,8 @@ DEVICE = "cpu"
 
 hf = (
     ColPaliForRetrieval
-    # NOTE change this to torch.float32 when we are comparing to ONNX fp32
-    # same for fpt16 to make it fair comparison
+    # NOTE: use torch.float32 here when comparing against the ONNX fp32 export,
+    # and keep torch.float16 (as below) for a fair fp16 comparison
     .from_pretrained(MODEL_ID, torch_dtype=torch.float16)
     .to(DEVICE)
     .eval()
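The FP16 margin of error noted above is expected from half-precision rounding alone. A self-contained sketch of the kind of tolerance check involved (synthetic data and a hypothetical tolerance, not the actual ColPali outputs):

```python
import numpy as np

rng = np.random.default_rng(0)
# stand-in for a batch of multi-vector embeddings: (batch, tokens, dim)
emb_fp32 = rng.standard_normal((1, 1030, 128)).astype(np.float32)
# simulate what an fp16 export does to the values: round-trip through half precision
emb_fp16 = emb_fp32.astype(np.float16).astype(np.float32)

# bitwise equality fails, but the values agree within a small absolute tolerance
assert not np.array_equal(emb_fp32, emb_fp16)
assert np.allclose(emb_fp32, emb_fp16, atol=1e-2)
print("max abs error after fp16 round-trip:", np.abs(emb_fp32 - emb_fp16).max())
```

The same pattern (exact match for fp32, `np.allclose` with a loose tolerance for fp16) applies when comparing the real HF and ONNX outputs.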
conversion/convert.py CHANGED
@@ -12,11 +12,16 @@ from onnxconverter_common import float16
 from onnx.external_data_helper import convert_model_to_external_data
 
 
-def export_model(model_id, output_dir, device, fp16=False):
-    """Export HuggingFace model ColPaliForRetrieval to ONNX format"""
+def export_model(
+    model_id: str,
+    output_dir: str,
+    device: str,
+    fp16: bool = False,
+    export_type: str = "both",
+):
+    """Export ColPaliForRetrieval to ONNX (vision, text, or both)."""
     os.makedirs(output_dir, exist_ok=True)
 
-    # Load HF model & processor
     model = (
         ColPaliForRetrieval.from_pretrained(
             model_id,
@@ -27,83 +32,118 @@ def export_model(model_id, output_dir, device, fp16=False):
         .eval()
     )
     processor = ColPaliProcessor.from_pretrained(model_id)
-
-    # Save HF artifacts
     model.config.save_pretrained(output_dir)
     processor.save_pretrained(output_dir)
 
-    # patched forward method
     _orig_forward = model.forward
 
-    def _patched_forward(
-        self, pixel_values=None, input_ids=None, attention_mask=None, **kwargs
-    ):
-        # Call the original .forward
-        out = _orig_forward(
-            pixel_values=pixel_values,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            **kwargs,
-        )
-        return out.embeddings
-
-    model.forward = _patched_forward.__get__(model, model.__class__)
-
-    # check with dummy image batch
+    # dummy inputs
     dummy_img = Image.new("RGB", (32, 32), color="white")
     vision_pt = processor(images=[dummy_img], return_tensors="pt").to(device)
-
-    pv = vision_pt["pixel_values"]
-    ids = vision_pt["input_ids"]
-    msk = vision_pt["attention_mask"]
-
-    with torch.no_grad():
-        emb = model(pv, ids, msk)
-    print("Sanity-check embedding shape:", emb.shape)
-
-    # Export to ONNX + external data
-    GLOBALS.onnx_shape_inference = False  # Workaround shape bugs
-    onnx_path = os.path.join(output_dir, "model.onnx")
-    external_binfile = os.path.join(output_dir, "model.onnx_data")
-
-    torch.onnx.export(
-        model,
-        (pv, ids, msk),
-        onnx_path,
-        export_params=True,
-        opset_version=14,
-        do_constant_folding=True,
-        use_external_data_format=True,
-        all_tensors_to_one_file=True,
-        size_threshold=0,
-        external_data_filename=os.path.basename(external_binfile),
-        input_names=["pixel_values", "input_ids", "attention_mask"],
-        output_names=["embeddings"],
-        dynamic_axes={
-            "pixel_values": {0: "batch_size"},
-            "input_ids": {0: "batch_size", 1: "seq_len"},
-            "attention_mask": {0: "batch_size", 1: "seq_len"},
-            "embeddings": {0: "batch_size", 1: "seq_len"},
-        },
-    )
-    print("Exported ONNX to", onnx_path)
-
-    # Shape-infer & fix external-data refs
-    onnx_model = onnx.shape_inference.infer_shapes_path(onnx_path)
-    onnx_model = onnx.load(onnx_path)
-    check_and_save_model(onnx_model, onnx_path)
-    print("Shape-inference + external refs fixed")
-
-    # Minify tokenizer.json
-    tok = os.path.join(output_dir, "tokenizer.json")
-    if os.path.isfile(tok):
-        data = json.load(open(tok))
-        with open(tok, "w") as f:
-            json.dump(data, f, separators=(",", ":"))
-        print("✔ Minified tokenizer.json")
-
-    print("✅ ONNX + HF artifacts exported to", output_dir)
-    return onnx_path
+    pv, ids, msk = (
+        vision_pt["pixel_values"],
+        vision_pt["input_ids"],
+        vision_pt["attention_mask"],
+    )
+    fake_ids = torch.zeros((pv.size(0), 1), device=device, dtype=torch.long)
+    fake_mask = torch.zeros_like(fake_ids, device=device)
+    fake_pv = torch.zeros_like(pv)
+
+    out_paths = {}
+
+    # vision model
+    if export_type in ("vision", "both"):
+
+        def vision_forward(
+            self, pixel_values=None, input_ids=None, attention_mask=None, **kw
+        ):
+            return _orig_forward(
+                pixel_values=pixel_values,
+                input_ids=None,
+                attention_mask=None,
+                **kw,
+            ).embeddings
+
+        model.forward = vision_forward.__get__(model, model.__class__)
+
+        vision_onnx = os.path.join(output_dir, "model_vision.onnx")
+        vision_bin = "model_vision.onnx_data"
+        GLOBALS.onnx_shape_inference = False  # workaround for shape bugs
+        torch.onnx.export(
+            model,
+            (pv, fake_ids, fake_mask),
+            vision_onnx,
+            export_params=True,
+            opset_version=14,
+            do_constant_folding=True,
+            use_external_data_format=True,
+            all_tensors_to_one_file=True,
+            size_threshold=0,
+            external_data_filename=vision_bin,
+            input_names=["pixel_values", "input_ids", "attention_mask"],
+            output_names=["embeddings"],
+            dynamic_axes={
+                "pixel_values": {0: "batch_size"},
+                "embeddings": {0: "batch_size", 1: "seq_len"},
+            },
+        )
+        print("✅ Exported VISION ONNX to", vision_onnx)
+
+        # fix shapes & external refs
+        onnx.shape_inference.infer_shapes_path(vision_onnx)
+        m = onnx.load(vision_onnx, load_external_data=True)
+        check_and_save_model(m, vision_onnx)
+        print("  (shape-inferred + external-data fixed)")
+
+        out_paths["vision"] = vision_onnx
+
+    # text model
+    if export_type in ("text", "both"):
+
+        def text_forward(
+            self, pixel_values=None, input_ids=None, attention_mask=None, **kw
+        ):
+            return _orig_forward(
+                pixel_values=None,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                **kw,
+            ).embeddings
+
+        model.forward = text_forward.__get__(model, model.__class__)
+
+        text_onnx = os.path.join(output_dir, "model_text.onnx")
+        text_bin = "model_text.onnx_data"
+        torch.onnx.export(
+            model,
+            (fake_pv, ids, msk),
+            text_onnx,
+            export_params=True,
+            opset_version=14,
+            do_constant_folding=True,
+            use_external_data_format=True,
+            all_tensors_to_one_file=True,
+            size_threshold=0,
+            external_data_filename=text_bin,
+            input_names=["pixel_values", "input_ids", "attention_mask"],
+            output_names=["embeddings"],
+            dynamic_axes={
+                "input_ids": {0: "batch_size", 1: "seq_len"},
+                "attention_mask": {0: "batch_size", 1: "seq_len"},
+                "embeddings": {0: "batch_size", 1: "seq_len"},
+            },
+        )
+        print("✅ Exported TEXT ONNX to", text_onnx)
+
+        onnx.shape_inference.infer_shapes_path(text_onnx)
+        m = onnx.load(text_onnx, load_external_data=True)
+        check_and_save_model(m, text_onnx)
+        print("  (shape-inferred + external-data fixed)")
+
+        out_paths["text"] = text_onnx
+
+    print("🎉 Done exporting model(s):", out_paths)
+    return out_paths
 
 
 def quantize_fp16_and_externalize(
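The export above works by rebinding `model.forward` via `__get__`, so ONNX tracing sees a normal method that returns only the embeddings tensor instead of the full output object. A minimal standalone sketch of that binding trick (toy class standing in for the model, no ColPali dependency):

```python
class Model:
    def forward(self, x):
        # the real model returns a rich output object; a dict stands in here
        return {"embeddings": x * 2, "other": None}

m = Model()
_orig_forward = m.forward  # keep a handle to the original bound method

def patched_forward(self, x):
    # call the original forward and expose only the field we want to export
    return _orig_forward(x)["embeddings"]

# __get__ binds the plain function to the instance as a normal method
m.forward = patched_forward.__get__(m, Model)
print(m.forward(3))  # → 6
```

Because `_orig_forward` is captured as a bound method before the patch, the patched function can delegate to it without recursing, which is exactly what lets the vision and text exports reuse one loaded model.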