prithivMLmods commited on
Commit
7abceaa
·
verified ·
1 Parent(s): 33cd763

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -60
app.py CHANGED
@@ -1,29 +1,20 @@
1
  import os
2
  import random
3
- import uuid
4
- import json
5
- import time
6
- import asyncio
7
  from threading import Thread
8
  from typing import Iterable
9
 
10
  import gradio as gr
11
  import spaces
12
  import torch
13
- import numpy as np
14
- from PIL import Image, ImageOps
15
- import requests
16
-
17
  from transformers import (
18
  Qwen2_5_VLForConditionalGeneration,
19
  AutoModelForCausalLM,
20
  AutoProcessor,
21
  TextIteratorStreamer,
22
  )
23
- from transformers.image_utils import load_image
24
  from gradio.themes import Soft
25
  from gradio.themes.utils import colors, fonts, sizes
26
- from huggingface_hub import snapshot_download
27
 
28
  # --- Theme and CSS Definition ---
29
 
@@ -106,7 +97,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
106
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
107
 
108
  # Load Nanonets-OCR-s
109
- MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
110
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
111
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
112
  MODEL_ID_M,
@@ -115,31 +106,26 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
115
  ).to(device).eval()
116
 
117
  # Load Dots.OCR
118
- MODEL_ID_D = "rednote-hilab/dots.ocr"
119
- model_path_d = "./models/dots-ocr-local"
120
- snapshot_download(
121
- repo_id=MODEL_ID_D,
122
- local_dir=model_path_d,
123
- local_dir_use_symlinks=False,
124
- )
125
  model_d = AutoModelForCausalLM.from_pretrained(
126
- model_path_d,
127
- attn_implementation="flash_attention_2" if "cuda" in device.type else "eager",
128
  torch_dtype=torch.bfloat16,
129
  device_map="auto",
130
  trust_remote_code=True
131
- )
132
- processor_d = AutoProcessor.from_pretrained(
133
- model_path_d,
134
- trust_remote_code=True
135
- )
136
 
137
 
138
  @spaces.GPU
139
  def generate_image(model_name: str, text: str, image: Image.Image,
140
- max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
 
 
 
 
141
  """Generate responses for image input using the selected model."""
142
- if model_name == "Nanonets-OCR2-3B":
143
  processor, model = processor_m, model_m
144
  elif model_name == "Dots.OCR":
145
  processor, model = processor_d, model_d
@@ -151,18 +137,16 @@ def generate_image(model_name: str, text: str, image: Image.Image,
151
  yield "Please upload an image.", "Please upload an image."
152
  return
153
 
154
- images = [image]
155
 
156
  messages = [
157
  {
158
  "role": "user",
159
- "content": [{"type": "image"}] * len(images) + [
160
- {"type": "text", "text": text}
161
- ]
162
  }
163
  ]
164
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
165
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)
166
 
167
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
168
  generation_kwargs = {
@@ -174,49 +158,31 @@ def generate_image(model_name: str, text: str, image: Image.Image,
174
  "top_k": top_k,
175
  "repetition_penalty": repetition_penalty,
176
  }
177
-
178
- # Dots.OCR uses a different generation parameter name for end-of-sequence
179
- if "dots.ocr" in model.config.name_or_path.lower():
180
- generation_kwargs["eos_token_id"] = processor.tokenizer.eos_token_id
181
-
182
-
183
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
184
  thread.start()
185
 
186
  buffer = ""
187
  for new_text in streamer:
188
- buffer += new_text.replace("<|im_end|>", "").replace("</s>", "")
189
  yield buffer, buffer
190
 
191
- # The formatted output is the same as the raw output in this version
192
- yield buffer, buffer
193
-
194
-
195
  # Define examples for image inference
196
  image_examples = [
197
  ["Reconstruct the doc [table] as it is.", "images/0.png"],
198
  ["Describe the image!", "images/8.png"],
199
  ["OCR the image", "images/2.jpg"],
200
- ["Convert this page to markdown", "images/1.png"],
201
  ]
202
 
203
-
204
  # Create the Gradio Interface
205
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
206
- gr.Markdown("# **Multimodal Image OCR**", elem_id="main-title")
207
  with gr.Row():
208
  with gr.Column(scale=2):
209
- model_choice = gr.Radio(
210
- choices=["Nanonets-OCR2-3B", "Dots.OCR"],
211
- label="Select Model",
212
- value="Nanonets-OCR-s"
213
- )
214
- query_input = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
215
  image_upload = gr.Image(type="pil", label="Upload Image", height=320)
216
- submit_button = gr.Button("Submit", variant="primary")
 
217
 
218
- gr.Examples(examples=image_examples, inputs=[query_input, image_upload])
219
-
220
  with gr.Accordion("Advanced options", open=False):
221
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
222
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
@@ -226,14 +192,21 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
226
 
227
  with gr.Column(scale=3):
228
  gr.Markdown("## Output", elem_id="output-title")
229
- raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=18, show_copy_button=True)
230
- formatted_output = gr.Markdown(label="Formatted Output (Result.md)")
 
 
 
 
 
 
 
231
 
232
- submit_button.click(
233
  fn=generate_image,
234
- inputs=[model_choice, query_input, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
235
  outputs=[raw_output, formatted_output]
236
  )
237
 
238
  if __name__ == "__main__":
239
- demo.queue(max_size=50).launch(ssr_mode=False, show_error=True)
 
1
  import os
2
  import random
 
 
 
 
3
  from threading import Thread
4
  from typing import Iterable
5
 
6
  import gradio as gr
7
  import spaces
8
  import torch
9
+ from PIL import Image
 
 
 
10
  from transformers import (
11
  Qwen2_5_VLForConditionalGeneration,
12
  AutoModelForCausalLM,
13
  AutoProcessor,
14
  TextIteratorStreamer,
15
  )
 
16
  from gradio.themes import Soft
17
  from gradio.themes.utils import colors, fonts, sizes
 
18
 
19
  # --- Theme and CSS Definition ---
20
 
 
97
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
98
 
99
  # Load Nanonets-OCR-s
100
+ MODEL_ID_M = "nanonets/Nanonets-OCR-s"
101
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
102
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
103
  MODEL_ID_M,
 
106
  ).to(device).eval()
107
 
108
  # Load Dots.OCR
109
+ MODEL_PATH_D = "rednote-hilab/dots.ocr"
110
+ processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
 
 
 
 
 
111
  model_d = AutoModelForCausalLM.from_pretrained(
112
+ MODEL_PATH_D,
113
+ attn_implementation="flash_attention_2",
114
  torch_dtype=torch.bfloat16,
115
  device_map="auto",
116
  trust_remote_code=True
117
+ ).eval()
 
 
 
 
118
 
119
 
120
  @spaces.GPU
121
  def generate_image(model_name: str, text: str, image: Image.Image,
122
+ max_new_tokens: int = 1024,
123
+ temperature: float = 0.6,
124
+ top_p: float = 0.9,
125
+ top_k: int = 50,
126
+ repetition_penalty: float = 1.2):
127
  """Generate responses for image input using the selected model."""
128
+ if model_name == "Nanonets-OCR-s":
129
  processor, model = processor_m, model_m
130
  elif model_name == "Dots.OCR":
131
  processor, model = processor_d, model_d
 
137
  yield "Please upload an image.", "Please upload an image."
138
  return
139
 
140
+ images = [image.convert("RGB")]
141
 
142
  messages = [
143
  {
144
  "role": "user",
145
+ "content": [{"type": "image"}] + [{"type": "text", "text": text}]
 
 
146
  }
147
  ]
148
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
149
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
150
 
151
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
152
  generation_kwargs = {
 
158
  "top_k": top_k,
159
  "repetition_penalty": repetition_penalty,
160
  }
 
 
 
 
 
 
161
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
162
  thread.start()
163
 
164
  buffer = ""
165
  for new_text in streamer:
166
+ buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
167
  yield buffer, buffer
168
 
 
 
 
 
169
  # Define examples for image inference
170
  image_examples = [
171
  ["Reconstruct the doc [table] as it is.", "images/0.png"],
172
  ["Describe the image!", "images/8.png"],
173
  ["OCR the image", "images/2.jpg"],
 
174
  ]
175
 
 
176
  # Create the Gradio Interface
177
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
178
+ gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
179
  with gr.Row():
180
  with gr.Column(scale=2):
181
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
 
 
 
 
 
182
  image_upload = gr.Image(type="pil", label="Upload Image", height=320)
183
+ image_submit = gr.Button("Submit", variant="primary")
184
+ gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
185
 
 
 
186
  with gr.Accordion("Advanced options", open=False):
187
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
188
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
 
192
 
193
  with gr.Column(scale=3):
194
  gr.Markdown("## Output", elem_id="output-title")
195
+ raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=13, show_copy_button=True)
196
+ with gr.Accordion("Formatted Result", open=True):
197
+ formatted_output = gr.Markdown(label="Formatted Result")
198
+
199
+ model_choice = gr.Radio(
200
+ choices=["Nanonets-OCR-s", "Dots.OCR"],
201
+ label="Select Model",
202
+ value="Nanonets-OCR-s"
203
+ )
204
 
205
+ image_submit.click(
206
  fn=generate_image,
207
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
208
  outputs=[raw_output, formatted_output]
209
  )
210
 
211
  if __name__ == "__main__":
212
+ demo.queue(max_size=50).launch(show_error=True)