Prince-1 committed on
Commit b7dca43 · verified · 1 parent: 8f6bd14

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ ex1.png filter=lfs diff=lfs merge=lfs -text
+ *.rkllm filter=lfs diff=lfs merge=lfs -text
NuMarkdown-8B-Thinking.rkllm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:558b96520c8ca9037fc7695c26886cffcf7498aacc47f16642be7854aa05ae1a
+ size 15272523868
README.md ADDED
@@ -0,0 +1,266 @@
+ ---
+ license: mit
+ base_model: NuMarkdown-8B-Thinking
+ tags:
+ - OCR
+ - vision-language
+ - VLM
+ - Reasoning
+ - document-to-markdown
+ - qwen2.5
+ - markdown
+ - extraction
+ - RAG
+ - rkllm
+ - onnx
+ model_name: NuMarkdown-8B-Thinking
+ library_name: transformers
+ pipeline_tag: image-to-text
+ ---
+
+ <p align="center">
+ <a href="https://nuextract.ai/">
+ <img src="numind.svg" width="400" height="400"/>
+ </a>
+ </p>
+ <p align="center">
+ 🖥️ <a href="https://nuextract.ai/">API / Platform</a>&nbsp;&nbsp; | &nbsp;&nbsp;🗣️ <a href="https://discord.gg/3tsEtJNCDe">Discord</a>&nbsp;&nbsp; | &nbsp;&nbsp;🔗 <a href="https://github.com/numindai/NuMarkdown">GitHub</a>&nbsp;&nbsp; | &nbsp;&nbsp;🤗 <a href="https://huggingface.co/spaces/numind/NuMarkdown-8b-Thinking">Demo</a>
+ </p>
+
+ ---
+
+ # Reasoning comes to OCR 🧠✨📄🤘
+
+ **NuMarkdown-8B-Thinking** is the first reasoning OCR VLM. It is trained specifically to convert documents into clean Markdown files, well suited for RAG applications. Before generating the Markdown, it generates thinking tokens to work out the layout of the document.
+ It is particularly good at understanding documents with unusual layouts and complex tables. The number of thinking tokens can range from 20% to 500% of the length of the final answer, depending on task difficulty.
+
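Since the model wraps its reasoning in `<think>…</think>` and the result in `<answer>…</answer>`, the thinking-to-answer ratio mentioned above can be estimated from a raw generation. A rough sketch (whitespace word counts stand in for real token counts):

```python
import re

def think_ratio(generation: str) -> float:
    """Rough thinking-to-answer size ratio, using word counts as a token proxy."""
    think = re.search(r"<think>(.*?)</think>", generation, re.S)
    answer = re.search(r"<answer>(.*?)</answer>", generation, re.S)
    if not think or not answer:
        raise ValueError("missing <think> or <answer> section")
    return len(think.group(1).split()) / len(answer.group(1).split())

# A 150% thinking overhead would look like:
print(think_ratio("<think>one two three</think><answer>a b</answer>"))  # 1.5
```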
+ **NuMarkdown-8B-Thinking** is a fine-tune of **Qwen 2.5-VL-7B** on synthetic Doc &rarr; Reasoning &rarr; Markdown examples, followed by an RL phase (GRPO) with a layout-centric reward.
+
+ Try it out in [the 🤗 space!](https://huggingface.co/spaces/numind/NuMarkdown-8b-Thinking)
+
+ ## Results
+
+ **NuMarkdown-8B-Thinking** outperforms generic non-reasoning models like GPT-4o and specialized OCR models like OCRFlux.
+ It is competitive against large closed-source reasoning models like Gemini 2.5.
+
+ ### Arena ranking against popular alternatives (using the TrueSkill-2 rating system, with around 500 model-anonymized votes):
+ <p align="center">
+
+ | Rank | Model | μ | σ | μ − 3σ |
+ | ---- | --------------------------------------- | ----- | ---- | ------ |
+ | 🥇 1 | **gemini-flash-reasoning** | 26.75 | 0.80 | 24.35 |
+ | 🥈 2 | **NuMarkdown-reasoning** | 26.10 | 0.79 | 23.72 |
+ | 🥉 3 | **NuMarkdown-reasoning-w/o\_grpo** | 25.32 | 0.80 | 22.93 |
+ | 4 | **OCRFlux-3B** | 24.63 | 0.80 | 22.22 |
+ | 5 | **gpt-4o** | 24.48 | 0.80 | 22.08 |
+ | 6 | **gemini-flash-w/o\_reasoning** | 24.11 | 0.79 | 21.74 |
+ | 7 | **RolmoOCR** | 23.53 | 0.82 | 21.07 |
+
+ </p>
+
+ *We plan to release a Markdown arena, similar to LMArena, for complex document-to-Markdown tasks, to provide a tool for evaluating different solutions.*
+
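The ranking uses the conservative TrueSkill-style score μ − 3σ (a lower confidence bound on skill), which is what the last column of the table reports. A quick check against the table's values (small rounding differences of ±0.01 are expected):

```python
# (mu, sigma) pairs copied from the arena table above.
ratings = {
    "gemini-flash-reasoning": (26.75, 0.80),
    "NuMarkdown-reasoning": (26.10, 0.79),
    "NuMarkdown-reasoning-w/o_grpo": (25.32, 0.80),
    "OCRFlux-3B": (24.63, 0.80),
    "gpt-4o": (24.48, 0.80),
    "gemini-flash-w/o_reasoning": (24.11, 0.79),
    "RolmoOCR": (23.53, 0.82),
}

# Conservative skill estimate: mu - 3 * sigma.
conservative = {name: mu - 3 * sigma for name, (mu, sigma) in ratings.items()}
for name, score in sorted(conservative.items(), key=lambda kv: -kv[1]):
    print(f"{name:32s} {score:.2f}")
```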
+ ### Win/Draw/Lose rate against other models (image-only):
+ <p align="center">
+ <img src="bar plot.png" width="700"/>
+ </p>
+
+
+ ## Training
+
+ 1. **SFT**: single-epoch supervised fine-tuning on synthetic reasoning traces generated from public PDFs.
+ 2. **RL (GRPO)**: an RL phase using a layout-centric reward on difficult image examples.
+
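The layout-centric reward is not published in detail. As a toy illustration of the idea only (an assumption, not the actual reward used in training), one could score a candidate Markdown against a reference by comparing table structure rather than raw text:

```python
def layout_reward(candidate: str, reference: str) -> float:
    """Toy layout reward (NOT the actual training reward): fraction of table
    rows whose column count matches the reference at the same position."""
    def table_shape(md: str):
        # Cells per table row, for every Markdown table line in the document.
        rows = [line for line in md.splitlines() if line.strip().startswith("|")]
        return [row.count("|") - 1 for row in rows]

    cand, ref = table_shape(candidate), table_shape(reference)
    if not ref:
        return 1.0 if not cand else 0.0
    matches = sum(1 for a, b in zip(cand, ref) if a == b)
    return matches / max(len(cand), len(ref))

print(layout_reward("| a | b |\n|---|---|", "| a | b |\n|---|---|"))  # 1.0
```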
+ ## Example:
+
+ <p align="center">
+ <img src="ex1.png" width="500"/>
+ </p>
+
+ ```
+ <think>
+ 1. **Analyze the overall structure:** The document consists of two main sections, each containing a header, a set of bullet points, a title, a table, and a footer. The layout is consistent within each section but different between the two.
+
+ 2. **Handle the first section header and bullet points:**
+    - The top left corner has "Generalitat de Catalunya Departament d'Educació Institut Gal·lecs". This is a consistent header and should be transcribed as a level 3 heading.
+    - The top right corner has "Curs 2021-22". This is a distinct header and should be transcribed as a level 3 heading.
+    - Below the header, there are four bullet points. These are standard list items. I will use Markdown's unordered list syntax (`- `).
+
+ 3. **Process the first section's main content - Table 1:**
+    - **Identify table structure:** This is a complex table with a clear header row: "MATERIAL", "TÍTOL", "AUTOR", "EDITORIAL", "ISBN".
+    - **Determine columns:** There are 5 columns.
+    - **Handle merged cells:**
+      - "Llengua Catalana" spans two rows.
+      - "Llengua Castellana" spans three rows.
+      - "Anglès" spans two rows.
+    - **Populate table data:** Transcribe all text, including multi-line content within cells (e.g., "Babulinka Books\nEditorialCol·lecció: Pettes\nJoies per a Grans Lectors.\n2013").
+    - **Markdown table creation:** Use standard Markdown table syntax, ensuring correct alignment and handling of multi-line content within cells.
+
+ 4. **Handle the first section's footer:**
+    - The footer contains "Codí: 04mp02", "Responsable: Coordinador Qualitat", "Versió: 5", "Full d'Informació a l'alumnat i famílies", "Aquest document pot quedar obsolet una vegada imprès", and "Pàgina 1 de 2". This is a simple text block at the bottom. I will transcribe it as a separate paragraph or list, maintaining the original layout if possible, but a simple paragraph is sufficient for a footer.
+
+ 5. **Handle the second section header and bullet points:**
+    - The top left corner has "Generalitat de Catalunya Departament d'Educació Institut Gal·lecs". This is a consistent header and should be transcribed as a level 3 heading.
+    - The top right corner has "Curs 2021-22". This is a distinct header and should be transcribed as a level 3 heading.
+    - Below the header, there are no bullet points, but there is a table.
+
+ 6. **Process the second section's main content - Table 2:**
+    - **Identify table structure:** This is a simpler table with a clear header row: "MATERIAL", "TÍTOL", "AUTOR", "EDITORIAL", "ISBN".
+    - **Determine columns:** There are 5 columns.
+    - **Populate table data:** Transcribe all text, including multi-line content within cells (e.g., "Projecte MOTIMATS-ONMAT (llicència digital) Tekman Books\nES GESTIONA A TRAVÉS DE L'AMPA AL SETEMBRE").
+    - **Markdown table creation:** Use standard Markdown table syntax.
+
+ 7. **Handle the second section's footer:**
+    - The footer contains "Codí: 04mp02", "Responsable: Coordinador Qualitat", "Versió: 5", "Full d'Informació a l'alumnat i famílies", "Aquest document pot quedar obsolet una vegada imprès", and "Pàgina 2 de 2". This is a simple text block at the bottom. I will transcribe it as a separate paragraph or list, maintaining the original layout if possible, but a simple paragraph is sufficient for a footer.
+
+ 8. **Final review:** Check for any missed text, incorrect formatting, or issues with table alignment. Ensure all original content is present and accurately represented in GitHub-Flavored Markdown. The image placeholders are not needed as there are no images in the document.
+ </think>
+ <answer>
+ ### Generalitat de Catalunya
+ ### Departament d'Educació
+ ### Institut Gal·lecs
+
+ ### Curs 2021-22
+
+ - Els llibres de color blau indiquen que es manté respecte al curs anterior.
+ - Els llibres de color groc indiquen que es tracta d'un canvi per a aquest curs.
+ - Els llibres de color vermell indiquen que no s'han de comprar perquè van a càrrec del centre.
+ - Matèries sense llibre: Dossiers, llicències i el Kit de tecnologia s'hauran de pagar i de recollir a través de l'AMPA a partir de setembre.
+
+ # 1 ESO
+
+ | MATERIAL | TÍTOL | AUTOR | EDITORIAL | ISBN |
+ |---|---|---|---|---|
+ | Llengua Catalana | Punt Volat | | Castellnou (Didacta +) | 9788417803124 |
+ | | Duna, Diari d'un estiu. | Muriel Villanueva | Babulinka Books<br>EditorialCol·lecció: Pettes<br>Joies per a Grans Lectors.<br>2013 | 9788494159077 |
+ | | El nen que xatejava amb Jack Sparrow. | Francesc Puigpelat | Bromera<br>Col·lecció: L'Elefant. 2015 | 9788490264072 |
+ | Llengua Castellana | Proyecto Asterisco | | Castellnou (Didacta +) | 9788417803186 |
+ | | Manzanas rojas | Luis Matilla | Ed. Anaya | 978846673989 |
+ | | Fàbulas de Esopo | Jerry Pinkney | Vicens Vives | 978843671648 |
+ | Anglès | Think Ahead ESO 1. Student's book.<br>Think Ahead ESO 1. Workbook (cat). | | Burlington Books<br>Burlington Books | 9788925300662<br>9789925300686 |
+
+ Codí: 04mp02
+ Responsable: Coordinador Qualitat
+ Versió: 5
+ Full d'Informació a l'alumnat i famílies
+ Aquest document pot quedar obsolet una vegada imprès
+ Pàgina 1 de 2
+
+ ### Generalitat de Catalunya
+ ### Departament d'Educació
+ ### Institut Gal·lecs
+
+ ### Curs 2021-22
+
+ | MATERIAL | TÍTOL | AUTOR | EDITORIAL | ISBN |
+ |---|---|---|---|---|
+ | FRANCÈS | Nouvelle Génération A1-A2 | | Santillana | 9788490494745 |
+ | CIÈNCIES EXPERIMENTALS | Science Bits<br>ES GESTIONA A TRAVÉS DE L'AMPA AL SETEMBRE | | | 9788412213485 (llicència digital) |
+ | MATEMÀTIQUES | Projecte MOTIMATS-ONMAT (llicència digital) Tekman Books<br>ES GESTIONA A TRAVÉS DE L'AMPA AL SETEMBRE | | | |
+ | TECNOLOGIA | Tecnologia 1 ESO | TEIDE | | 9788430783175 |
+ | VISUAL I PLÀSTICA | SENSE LLIBRE-KIT DE MATERIAL | | | |
+ | CIÈNCIES SOCIALS | SENSE LLIBRE-dossier | | | |
+
+ Codí: 04mp02
+ Responsable: Coordinador Qualitat
+ Versió: 5
+ Full d'Informació a l'alumnat i famílies
+ Aquest document pot quedar obsolet una vegada imprès
+ Pàgina 2 de 2
+ </answer>
+ ```
+
+ ## Quick start:
+
+ ### vLLM:
+ ```
+ vllm serve numind/NuMarkdown-8B-Thinking --trust_remote_code --limit-mm-per-prompt image=1
+ ```
+
+ ```python
+ from openai import OpenAI
+ import base64
+
+ openai_api_key = "EMPTY"
+ openai_api_base = "http://localhost:8000/v1"
+
+ client = OpenAI(
+     api_key=openai_api_key,
+     base_url=openai_api_base,
+ )
+
+ def encode_image(image_path):
+     """Encode an image file as a base64 string."""
+     with open(image_path, "rb") as image_file:
+         return base64.b64encode(image_file.read()).decode("utf-8")
+
+ base64_image = encode_image("image.png")
+ data_url = f"data:image/png;base64,{base64_image}"
+
+ chat_response = client.chat.completions.create(
+     model="numind/NuMarkdown-8B-Thinking",
+     temperature=0.7,
+     messages=[
+         {
+             "role": "user",
+             "content": [
+                 {
+                     "type": "image_url",
+                     "image_url": {"url": data_url},
+                     "min_pixels": 100 * 28 * 28,
+                     "max_pixels": 5000 * 28 * 28,
+                 },
+             ],
+         },
+     ],
+ )
+
+ result = chat_response.choices[0].message.content
+ reasoning = result.split("<think>")[1].split("</think>")[0]
+ answer = result.split("<answer>")[1].split("</answer>")[0]
+ print(answer)
+ ```
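The split-based parsing above raises `IndexError` if the generation stops before a closing tag (for example when the token limit is hit mid-reasoning). A more defensive variant (a hypothetical helper, not part of the model's API):

```python
import re

def split_thinking(generation: str) -> tuple[str, str]:
    """Return (reasoning, answer); fall back gracefully if tags are absent."""
    think = re.search(r"<think>(.*?)</think>", generation, re.S)
    answer = re.search(r"<answer>(.*?)</answer>", generation, re.S)
    return (
        think.group(1).strip() if think else "",
        # If the model never emitted <answer>, return the raw text as-is.
        answer.group(1).strip() if answer else generation.strip(),
    )

reasoning, markdown = split_thinking("<think>two tables</think><answer># Title</answer>")
print(markdown)  # # Title
```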
+
+
+ ### 🤗 Transformers:
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+ model_id = "numind/NuMarkdown-8B-Thinking"
+
+ processor = AutoProcessor.from_pretrained(
+     model_id,
+     trust_remote_code=True,
+     min_pixels=100 * 28 * 28,
+     max_pixels=5000 * 28 * 28,
+ )
+
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     attn_implementation="flash_attention_2",
+     device_map="auto",
+     trust_remote_code=True,
+ )
+
+ img = Image.open("image.png").convert("RGB")
+ messages = [{
+     "role": "user",
+     "content": [
+         {"type": "image"},
+     ],
+ }]
+ prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ model_input = processor(text=prompt, images=[img], return_tensors="pt").to(model.device)
+
+ with torch.no_grad():
+     model_output = model.generate(**model_input, do_sample=True, temperature=0.7, max_new_tokens=5000)
+
+ result = processor.decode(model_output[0])
+ reasoning = result.split("<think>")[1].split("</think>")[0]
+ answer = result.split("<answer>")[1].split("</answer>")[0]
+ print(answer)
+ ```
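The `min_pixels`/`max_pixels` settings bound the resolution the processor feeds the vision tower. In Qwen2.5-VL, each 28×28 pixel block corresponds to one visual token (14-pixel patches merged 2×2), so the values above cap the image at roughly 100 to 5000 visual tokens:

```python
def visual_token_bounds(min_pixels: int, max_pixels: int, block: int = 28) -> tuple[int, int]:
    """Approximate visual-token range implied by Qwen2.5-VL pixel limits.

    One visual token covers a `block` x `block` pixel area (14-px patches,
    2x2 spatial merge), so tokens ~= pixels / block^2.
    """
    return min_pixels // (block * block), max_pixels // (block * block)

lo, hi = visual_token_bounds(100 * 28 * 28, 5000 * 28 * 28)
print(lo, hi)  # 100 5000
```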
bar plot.png ADDED
ex1.png ADDED

Git LFS Details

  • SHA256: 9ab65794a94eae69f761e65fa4731829251b907b60488bab62b47b5c16cc7000
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
export_vision.py ADDED
@@ -0,0 +1,263 @@
+ import torch
+ import numpy as np
+ import os
+ import math
+ import argparse
+ import torch.nn.functional as F
+ from transformers import AutoModel
+
+ class minicpm_v_2_6_vision(torch.nn.Module):
+     def __init__(self, vlm, batch_size, in_h, in_w):
+         super(minicpm_v_2_6_vision, self).__init__()
+         self.vpm = vlm.vpm
+         self.resampler = vlm.resampler
+         patch_size = vlm.config.patch_size
+         num_patches_per_side = vlm.vpm.embeddings.num_patches_per_side
+         tgt_sizes = torch.Tensor([[(in_h // patch_size), math.ceil(in_w / patch_size)]]).type(torch.int32)
+         patch_attention_mask = torch.ones(
+             size=(batch_size, in_h // patch_size, in_w // patch_size),
+             dtype=torch.bool, device=vlm.device,
+         )
+         max_im_h, max_im_w = in_h, in_w
+         max_nb_patches_h, max_nb_patches_w = max_im_h // patch_size, max_im_w // patch_size
+         boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
+         position_ids = torch.full(
+             size=(batch_size, max_nb_patches_h * max_nb_patches_w),
+             fill_value=0,
+         )
+         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+             if tgt_sizes is not None:
+                 nb_patches_h = tgt_sizes[batch_idx][0]
+                 nb_patches_w = tgt_sizes[batch_idx][1]
+             else:
+                 nb_patches_h = p_attn_mask[:, 0].sum()
+                 nb_patches_w = p_attn_mask[0].sum()
+
+             fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+             fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+             bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+             pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
+             position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+         position_ids = position_ids.to(vlm.device)
+         self.position_ids = position_ids
+
+         patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+         max_patch_len = torch.max(patch_len)
+         key_padding_mask = torch.zeros((batch_size, max_patch_len), dtype=torch.bool, device=vlm.device)
+         pos_embed = []
+         for i in range(batch_size):
+             tgt_h, tgt_w = tgt_sizes[i]
+             pos_embed.append(self.resampler.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(torch.float32))  # patches * D
+             key_padding_mask[i, patch_len[i]:] = True
+
+         self.pos_embed = torch.nn.utils.rnn.pad_sequence(
+             pos_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2)  # BLD => L * B * D
+
+     def forward(self, pixel_values):
+         batch_size = pixel_values.size(0)
+         # patch embedding
+         patch_embeds = self.vpm.embeddings.patch_embedding(pixel_values)
+         embeddings = patch_embeds.flatten(2).transpose(1, 2)
+         hidden_states = embeddings + self.vpm.embeddings.position_embedding(self.position_ids)
+         # encoder
+         encoder_outputs = self.vpm.encoder(inputs_embeds=hidden_states)
+         last_hidden_state = encoder_outputs[0]
+         last_hidden_state = self.vpm.post_layernorm(last_hidden_state)
+         # resampler
+         x = self.resampler.kv_proj(last_hidden_state)  # B * L * D
+         x = self.resampler.ln_kv(x).permute(1, 0, 2)  # L * B * D
+
+         q = self.resampler.ln_q(self.resampler.query)  # Q * D
+
+         out = self.resampler.attn(
+             self.resampler._repeat(q, batch_size),  # Q * B * D
+             x + self.pos_embed,  # L * B * D + L * B * D
+             x)[0]
+         # out: Q * B * D
+         x = out.permute(1, 0, 2)  # B * Q * D
+
+         x = self.resampler.ln_post(x)
+         x = x @ self.resampler.proj
+         return x
+
+ class qwen2_5_vl_3b_vision(torch.nn.Module):
+     def __init__(self, vlm, batch_size):
+         super(qwen2_5_vl_3b_vision, self).__init__()
+         self.merge_size = 2
+         self.temporal_patch_size = 2
+         self.patch_size = 14
+         self.channel = 3
+         self.vpm = vlm.visual
+         self.batch_size = batch_size
+
+     def forward(self, pixel_value, grid_thw):
+         if self.batch_size == 1:
+             patches = pixel_value.repeat(self.temporal_patch_size, 1, 1, 1)
+         elif self.batch_size % self.temporal_patch_size == 1:
+             repeat_image = pixel_value[-1:, ...].repeat(2, 1, 1, 1)
+             patches = torch.cat((pixel_value, repeat_image), dim=0)
+         else:
+             patches = pixel_value
+         grid_t, grid_h, grid_w = grid_thw[0][0], grid_thw[0][1], grid_thw[0][2]
+         patches = patches.reshape(grid_t, self.temporal_patch_size, self.channel,
+                                   grid_h // self.merge_size, self.merge_size, self.patch_size,
+                                   grid_w // self.merge_size, self.merge_size, self.patch_size)
+         patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
+         flatten_patches = patches.reshape(grid_t * grid_h * grid_w,
+                                           self.channel * self.temporal_patch_size * self.patch_size * self.patch_size)
+
+         return self.vpm(flatten_patches, grid_thw)
+
+ class smolvlm_vision(torch.nn.Module):
+     def __init__(self, vlm):
+         super(smolvlm_vision, self).__init__()
+         self.vpm = vlm.model.vision_model
+         self.connector = vlm.model.connector
+
+     def forward(self, pixel_values):
+         # Get sequence from the vision encoder
+         image_hidden_states = self.vpm(pixel_values).last_hidden_state
+         # Modality projection & resampling
+         image_hidden_states = self.connector(image_hidden_states)
+         print("image_features:", image_hidden_states.shape)
+         return image_hidden_states
+
+ class vila1_5_3b_vision(torch.nn.Module):
+     def __init__(self, vlm):
+         super(vila1_5_3b_vision, self).__init__()
+         self.vlm = vlm
+
+     def forward(self, pixel_values):
+         # Get sequence from the vision encoder
+         out = self.vlm.encode_images(pixel_values)
+         return out
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--path', type=str, default='CKPT/MiniCPM-V-2_6', help='model path', required=False)
+     parser.add_argument('--model_name', type=str, default='minicpm-v-2_6', help='model name', required=False)
+     parser.add_argument('--batch_size', type=int, default=1, help='batch size', required=False)
+     parser.add_argument('--height', type=int, default=448, help='image height', required=False)
+     parser.add_argument('--width', type=int, default=448, help='image width', required=False)
+     parser.add_argument('--device', type=str, default="cpu", help='cpu or cuda', required=False)
+     args = parser.parse_args()
+
+     path = args.path
+     model_name = args.model_name
+     savepath = os.path.join("./onnx", model_name + "_vision.onnx")
+     device_type = args.device
+     os.makedirs(os.path.dirname(savepath), exist_ok=True)
+     if model_name == 'minicpm-v-2_6':
+         model = AutoModel.from_pretrained(
+             path, trust_remote_code=True, torch_dtype=torch.float32,
+         )
+         model = model.to(device=device_type, dtype=torch.float32)
+         model.eval()
+         model = minicpm_v_2_6_vision(model, args.batch_size, args.height, args.width)
+         # Note: the wrapper is a plain nn.Module with no .device attribute,
+         # so create the dummy input on the requested device directly.
+         pixel_values = torch.randn(args.batch_size, 3, args.height, args.width, device=device_type, dtype=torch.float32)
+         out = model(pixel_values)
+         print("Output shape:", out.shape)
+         torch.onnx.export(model,
+                           pixel_values,
+                           savepath,
+                           input_names=['pixel'],
+                           opset_version=15)
+     elif model_name == 'qwen2_5-vl-3b':
+         from transformers import Qwen2_5_VLForConditionalGeneration
+         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             path,
+             low_cpu_mem_usage=True,
+             _attn_implementation="eager",
+             trust_remote_code=True,
+         )
+
+         model = model.to(device=device_type, dtype=torch.float32).eval()
+
+         model = qwen2_5_vl_3b_vision(model, args.batch_size)
+
+         def get_window_index_static(self, grid_thw):
+             # grid_thw: [1, 3] holding (T, H, W) (int64, static)
+             device = grid_thw.device
+             T, H, W = grid_thw[0]
+
+             total = T * H * W
+
+             # window_index: [total]
+             window_index = torch.arange(total, device=device)
+
+             # cu_window_seqlens: [0, total]
+             cu_window_seqlens = torch.tensor([0, total], device=device)
+
+             return window_index, cu_window_seqlens
+
+         # Patch the window-index computation with a static version so the
+         # graph exports with fixed shapes. The wrapper exposes the vision
+         # tower as `model.vpm`, so the patch is bound there.
+         model.vpm.get_window_index = get_window_index_static.__get__(
+             model.vpm, type(model.vpm)
+         )
+
+         print(model.vpm.get_window_index)
+
+         pixel_values = torch.randn(args.batch_size, 3, args.height, args.width, device=device_type, dtype=torch.float32)
+         # Static grid derived from the CLI arguments (14-px patches); must be
+         # defined before the forward pass that traces the graph.
+         grid_t = args.batch_size // 2 if args.batch_size % 2 == 0 else args.batch_size // 2 + 1
+         grid_thw = torch.tensor([[grid_t, args.height // 14, args.width // 14]], dtype=torch.int64)
+
+         out = model(pixel_values, grid_thw)
+         print("Output shape:", out.shape)
+
+         torch.onnx.export(
+             model,
+             (pixel_values, grid_thw),
+             savepath,
+             input_names=["pixel", "grid_thw"],
+             opset_version=18,
+         )
+     elif model_name == 'smolvlm':
+         from transformers import SmolVLMForConditionalGeneration
+         model = SmolVLMForConditionalGeneration.from_pretrained(
+             path,
+             torch_dtype=torch.float32,
+             _attn_implementation="eager",
+         ).to(device_type)
+         pixel_values = torch.randn(args.batch_size, 3, args.height, args.width, device=model.device, dtype=torch.float32)
+         print("pixel_values:", pixel_values.shape)
+         model = smolvlm_vision(model)
+         model = model.to(torch.float32).eval()
+         out = model(pixel_values)
+         torch.onnx.export(model,
+                           pixel_values,
+                           savepath,
+                           input_names=['pixel'],
+                           dynamic_axes={'pixel': {2: 'height', 3: 'width'}},
+                           opset_version=15)
+     elif model_name == 'internvl3-1b':
+         model = AutoModel.from_pretrained(
+             path,
+             torch_dtype=torch.float32,
+             low_cpu_mem_usage=True,
+             trust_remote_code=True).eval().to(device_type)
+         pixel_values = torch.randn(args.batch_size, 3, args.height, args.width, device=model.device, dtype=torch.float32)
+         model.forward = model.extract_feature
+         model = model.to(torch.float32).eval()
+         torch.onnx.export(model, pixel_values, savepath)
+     else:
+         raise ValueError(f"Unsupported model name: {model_name}")
+
+     print(f"Exported to {savepath}")
export_vision_rknn.py ADDED
@@ -0,0 +1,58 @@
+ from rknn.api import RKNN
+ import numpy as np
+ import os
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--path', type=str, default='./onnx/qwen2_5-vl-3b_vision.onnx', help='model path', required=False)
+ parser.add_argument('--model_name', type=str, default='qwen2_5-vl-3b', help='model name', required=False)
+ parser.add_argument('--target-platform', type=str, default='rk3588', help='target platform', required=False)
+ parser.add_argument('--batch_size', type=int, default=1, help='batch size', required=False)
+ parser.add_argument('--height', type=int, default=448, help='image height', required=False)
+ parser.add_argument('--width', type=int, default=448, help='image width', required=False)
+
+ args = parser.parse_args()
+
+ model_path = args.path
+ target_platform = args.target_platform
+ modelname = args.model_name
+
+ if 'qwen2' in model_path.lower():
+     mean_value = [[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255]]
+     std_value = [[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255]]
+ elif 'internvl3' in model_path.lower():
+     mean_value = [[0.485 * 255, 0.456 * 255, 0.406 * 255]]
+     std_value = [[0.229 * 255, 0.224 * 255, 0.225 * 255]]
+ else:
+     mean_value = [[0.5 * 255, 0.5 * 255, 0.5 * 255]]
+     std_value = [[0.5 * 255, 0.5 * 255, 0.5 * 255]]
+
+ if modelname == 'qwen2_5-vl-3b':
+     inputs = ['pixel', 'grid_thw']
+     input_size_list = [[args.batch_size, 3, args.height, args.width], [1, 3]]
+     grid_t = args.batch_size // 2 if args.batch_size % 2 == 0 else (args.batch_size + 1) // 2
+     input_initial_val = [None, np.array([[grid_t, args.height // 14, args.width // 14]], dtype=np.int64)]
+     op_target = {"/vpm/patch_embed/proj/Conv_output_0_conv_tp_sw": 'cpu'}
+ elif modelname == 'qwen3-vl':
+     inputs = ['pixel', 'grid_thw']
+     input_size_list = [[args.batch_size, 3, args.height, args.width], [1, 3]]
+     grid_t = args.batch_size // 2 if args.batch_size % 2 == 0 else (args.batch_size + 1) // 2
+     input_initial_val = [None, np.array([[grid_t, args.height // 16, args.width // 16]], dtype=np.int64)]
+     op_target = None
+ else:
+     inputs = ['pixel']
+     input_size_list = [[args.batch_size, 3, args.height, args.width]]
+     input_initial_val = None
+     op_target = None
+
+ if modelname == 'deepseekocr':
+     disable_rules = ['convert_rs_add_rs_to_rs_gather_elements']
+ else:
+     disable_rules = []
+
+ rknn = RKNN(verbose=False)
+ rknn.config(disable_rules=disable_rules, target_platform=target_platform, mean_values=mean_value, std_values=std_value, op_target=op_target)
+ rknn.load_onnx(model_path, inputs=inputs, input_size_list=input_size_list, input_initial_val=input_initial_val)
+ rknn.build(do_quantization=False, dataset=None)
+ os.makedirs("rknn", exist_ok=True)
+ rknn.export_rknn("./rknn/" + os.path.splitext(os.path.basename(model_path))[0] + "_{}.rknn".format(target_platform))
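The Qwen branch's `mean_value`/`std_value` above are the standard CLIP image-normalization constants rescaled to the 0–255 input range RKNN applies `(pixel - mean) / std` over. A quick consistency check:

```python
# Standard CLIP normalization constants (0-1 range).
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

# RKNN normalizes raw 0-255 pixels, so the constants are scaled by 255
# before being handed to rknn.config(mean_values=..., std_values=...).
mean_value = [[m * 255 for m in CLIP_MEAN]]
std_value = [[s * 255 for s in CLIP_STD]]

print([round(v, 3) for v in mean_value[0]])
```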
matrix.png ADDED