Commit 5bcfcb2 · verified · committed by prithivMLmods · 1 Parent(s): bbef9e6

update app

Files changed (1):
  1. app.py +103 -80
app.py CHANGED
@@ -1,5 +1,9 @@
 import os
 import sys
+import random
+import uuid
+import json
+import time
 from threading import Thread
 from typing import Iterable
 from huggingface_hub import snapshot_download
@@ -7,7 +11,10 @@ from huggingface_hub import snapshot_download
 import gradio as gr
 import spaces
 import torch
+import numpy as np
 from PIL import Image
+import cv2
+
 from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     Qwen3VLForConditionalGeneration,
@@ -17,6 +24,7 @@ from transformers import (
     TextIteratorStreamer,
 )
 
+from transformers.image_utils import load_image
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
@@ -67,8 +75,14 @@ class SteelBlueTheme(Soft):
         button_primary_text_color_hover="white",
         button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
         button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
-        button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
-        button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
+        button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
+        button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
+        button_secondary_text_color="black",
+        button_secondary_text_color_hover="white",
+        button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
+        button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
+        button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
+        button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
         slider_color="*secondary_500",
         slider_color_dark="*secondary_600",
         block_title_text_weight="600",
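For context, the kwargs touched in this hunk are Gradio theme tokens, which a Soft subclass overrides by calling super().set() in its constructor. A minimal sketch of that same pattern, with a reduced, illustrative set of tokens rather than the Space's full SteelBlueTheme:

from gradio.themes import Soft

class MiniTheme(Soft):
    # Sketch only: same override mechanism as SteelBlueTheme above.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        super().set(
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
            slider_color="*secondary_500",
        )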
@@ -91,6 +105,22 @@ css = """
 }
 """
 
+MAX_MAX_NEW_TOKENS = 4096
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
+print("torch.__version__ =", torch.__version__)
+print("torch.version.cuda =", torch.version.cuda)
+print("cuda available:", torch.cuda.is_available())
+print("cuda device count:", torch.cuda.device_count())
+if torch.cuda.is_available():
+    print("current device:", torch.cuda.current_device())
+    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
+
+print("Using device:", device)
 
 CACHE_PATH = "./model_cache"
 if not os.path.exists(CACHE_PATH):
@@ -131,35 +161,24 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Load chandra
-MODEL_ID_C = "datalab-to/chandra"
-processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
-model_c = Qwen3VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_C,
+# Load Chandra-OCR
+MODEL_ID_V = "datalab-to/chandra"
+processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
+model_v = Qwen3VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_V,
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to(device).eval()
 
 # Load Nanonets-OCR2-3B
-MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
-processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
-model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_M,
+MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
+processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_X,
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to(device).eval()
 
-# Load Nanonets-OCR2-1.5B-exp
-MODEL_ID_N = "strangervisionhf/excess_layer_pruned-nanonets-1.5b"  # -> https://huggingface.co/nanonets/Nanonets-OCR2-1.5B-exp
-processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
-model_n = AutoModelForImageTextToText.from_pretrained(
-    MODEL_ID_N,
-    trust_remote_code=True,
-    torch_dtype=torch.float16,
-    attn_implementation="flash_attention_2"
-).to(device).eval()
-
-
 # Load Dots.OCR from the local, patched directory
 MODEL_PATH_D = model_path_d_local
 processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
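Aside: every checkpoint in this hunk follows the same load-and-infer pattern. A minimal non-streaming sketch against the Nanonets-OCR2-3B checkpoint loaded above; the prompt text and image path are illustrative, not taken from the Space:

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "nanonets/Nanonets-OCR2-3B"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

# One image + one instruction, formatted with the model's chat template.
image = Image.open("page.png").convert("RGB")
messages = [{"role": "user", "content": [
    {"type": "image"},
    {"type": "text", "text": "OCR this page into plain text."},
]}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(device)

with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=512)
# Decode only the newly generated tokens, not the echoed prompt.
print(processor.batch_decode(out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0])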
@@ -171,33 +190,35 @@ model_d = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True
 ).eval()
 
-# Load PaddleOCR
-MODEL_ID_P = "strangervisionhf/paddle"  # -> https://huggingface.co/PaddlePaddle/PaddleOCR-VL
-processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
-model_p = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID_P,
+# Load olmOCR-2-7B-1025
+MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
+processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
+model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_M,
     trust_remote_code=True,
-    torch_dtype=torch.bfloat16
+    torch_dtype=torch.float16
 ).to(device).eval()
 
 @spaces.GPU
 def generate_image(model_name: str, text: str, image: Image.Image,
-                   max_new_tokens: int = 1024,
-                   temperature: float = 0.6,
-                   top_p: float = 0.9,
-                   top_k: int = 50,
-                   repetition_penalty: float = 1.2):
-    """Generate responses for image input using the selected model."""
-    if model_name == "Nanonets-OCR2-3B":
-        processor, model = processor_m, model_m
-    elif model_name == "Nanonets-OCR2-1.5B(exp)":
-        processor, model = processor_n, model_n
-    elif model_name == "Dots.OCR":
-        processor, model = processor_d, model_d
-    elif model_name == "PaddleOCR":
-        processor, model = processor_p, model_p
+                   max_new_tokens: int, temperature: float, top_p: float,
+                   top_k: int, repetition_penalty: float):
+    """
+    Generates responses using the selected model for image input.
+    Yields raw text and Markdown-formatted text.
+    """
+    if model_name == "olmOCR-2-7B-1025":
+        processor = processor_m
+        model = model_m
+    elif model_name == "Nanonets-OCR2-3B":
+        processor = processor_x
+        model = model_x
     elif model_name == "Chandra-OCR":
-        processor, model = processor_c, model_c
+        processor = processor_v
+        model = model_v
+    elif model_name == "Dots.OCR":
+        processor = processor_d
+        model = model_d
     else:
         yield "Invalid model selected.", "Invalid model selected."
         return
@@ -206,40 +227,39 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         yield "Please upload an image.", "Please upload an image."
         return
 
-    images = [image.convert("RGB")]
-
-    if model_name == "PaddleOCR":
-        messages = [
-            {"role": "user", "content": text}
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": text},
         ]
-    else:
-        messages = [
-            {
-                "role": "user",
-                "content": [{"type": "image"}] + [{"type": "text", "text": text}]
-            }
-        ]
-
-    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+    }]
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+    inputs = processor(
+        text=[prompt_full],
+        images=[image],
+        return_tensors="pt",
+        padding=True).to(device)
 
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = {
         **inputs,
         "streamer": streamer,
         "max_new_tokens": max_new_tokens,
+        "do_sample": True,
         "temperature": temperature,
         "top_p": top_p,
         "top_k": top_k,
         "repetition_penalty": repetition_penalty,
-        "do_sample": True
     }
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-
    buffer = ""
     for new_text in streamer:
-        buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
+        buffer += new_text
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)
         yield buffer, buffer
 
 image_examples = [
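The streaming half of this hunk is the standard transformers pattern: generate() runs on a background thread while TextIteratorStreamer hands decoded text back incrementally. A minimal sketch of just that mechanism, assuming `model`, `processor`, and `inputs` prepared as in the diff:

from threading import Thread
from transformers import TextIteratorStreamer

# skip_prompt drops the echoed input; skip_special_tokens drops markers like <|im_end|>.
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate,
                kwargs={**inputs, "streamer": streamer,
                        "max_new_tokens": 512, "do_sample": True})
thread.start()

buffer = ""
for piece in streamer:   # blocks until the next decoded chunk arrives
    buffer += piece
    # a generator-based handler would `yield buffer` here, as generate_image does
thread.join()
print(buffer)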
@@ -253,34 +273,37 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
     with gr.Row():
         with gr.Column(scale=2):
             image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-            image_upload = gr.Image(type="pil", label="Upload Image", height=320)
-            image_submit = gr.Button("Submit", variant="primary")
-            gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
+            image_upload = gr.Image(type="pil", label="Upload Image", height=290)
 
+            image_submit = gr.Button("Submit", variant="primary")
+            gr.Examples(
+                examples=image_examples,
+                inputs=[image_query, image_upload]
+            )
+
             with gr.Accordion("Advanced options", open=False):
                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
+                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7)
                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-
+                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
+
         with gr.Column(scale=3):
-            gr.Markdown("## Output", elem_id="output-title")
-            raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
-            with gr.Accordion("[Result.md]", open=False):
-                formatted_output = gr.Markdown(label="Formatted Result")
-
-            model_choice = gr.Radio(
-                choices=["Nanonets-OCR2-3B", "Chandra-OCR", "Dots.OCR", "Nanonets-OCR2-1.5B(exp)", "PaddleOCR"],
-                label="Select Model",
-                value="Nanonets-OCR2-3B"
-            )
-            gr.Markdown("Note: Currently, PaddleOCR VL only supports OCR inference. Structured OCR document parsing transformer inference is coming soon. [Report – Bug/Issue](https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR3/discussions/1)")
-
+            gr.Markdown("## Output", elem_id="output-title")
+            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
+            with gr.Accordion("(Result.md)", open=False):
+                markdown_output = gr.Markdown(label="(Result.Md)")
+
+            model_choice = gr.Radio(
+                choices=["Nanonets-OCR2-3B", "Chandra-OCR", "olmOCR-2-7B-1025", "Dots.OCR"],
+                label="Select Model",
+                value="Nanonets-OCR2-3B"
+            )
+
     image_submit.click(
         fn=generate_image,
         inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=[raw_output, formatted_output]
+        outputs=[output, markdown_output]
     )
 
 if __name__ == "__main__":
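On the UI side, streaming works because Gradio treats a generator click handler specially: each `yield` pushes fresh values to the bound output components, which is why generate_image yields `buffer, buffer` into the Textbox/Markdown pair. A self-contained sketch of that wiring; the echo handler is illustrative, not the Space's:

import time
import gradio as gr

def stream_echo(text):
    # Yielding repeatedly updates both outputs in place,
    # like generate_image's `yield buffer, buffer`.
    buffer = ""
    for ch in text:
        buffer += ch
        time.sleep(0.05)
        yield buffer, buffer   # one value per bound output component

with gr.Blocks() as demo:
    query = gr.Textbox(label="Query Input")
    raw = gr.Textbox(label="Raw Output Stream")
    md = gr.Markdown()
    gr.Button("Submit").click(fn=stream_echo, inputs=[query], outputs=[raw, md])

if __name__ == "__main__":
    demo.launch()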
 