prithivMLmods committed
Commit 1e73c4d · verified · 1 Parent(s): 846f854

update app

Files changed (1)
  1. app.py +60 -36
app.py CHANGED
@@ -14,9 +14,8 @@ from PIL import Image
 import cv2
 
 from transformers import (
-    Qwen2VLForConditionalGeneration,
     Qwen2_5_VLForConditionalGeneration,
-    AutoModelForImageTextToText,
+    AutoModelForCausalLM,  # Added for Dots.OCR
     AutoProcessor,
     TextIteratorStreamer,
 )
@@ -133,13 +132,14 @@ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-# Load Qwen2-VL-OCR-2B-Instruct
-MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
-processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
-model_x = Qwen2VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_X,
+# Load Dots.OCR
+MODEL_ID_D = "rednote-hilab/dots.ocr"
+processor_d = AutoProcessor.from_pretrained(MODEL_ID_D, trust_remote_code=True)
+model_d = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID_D,
     trust_remote_code=True,
-    torch_dtype=torch.float16
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2"
 ).to(device).eval()
 
 
@@ -151,20 +151,20 @@ def generate_image(model_name: str, text: str, image: Image.Image,
     Generates responses using the selected model for image input.
     Yields raw text and Markdown-formatted text.
     """
-    if model_name == "Qwen2-VL-OCR-2B":
-        processor = processor_x
-        model = model_x
-    elif model_name == "Nanonets-OCR2-3B":
+    if image is None:
+        yield "Please upload an image.", "Please upload an image."
+        return
+
+    if model_name == "Nanonets-OCR2-3B":
         processor = processor_v
         model = model_v
+    elif model_name == "Dots.OCR":
+        processor = processor_d
+        model = model_d
     else:
         yield "Invalid model selected.", "Invalid model selected."
         return
 
-    if image is None:
-        yield "Please upload an image.", "Please upload an image."
-        return
-
     messages = [{
         "role": "user",
         "content": [
@@ -180,25 +180,49 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         return_tensors="pt",
         padding=True).to(device)
 
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        **inputs,
-        "streamer": streamer,
-        "max_new_tokens": max_new_tokens,
-        "do_sample": True,
-        "temperature": temperature,
-        "top_p": top_p,
-        "top_k": top_k,
-        "repetition_penalty": repetition_penalty,
-    }
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        buffer = buffer.replace("<|im_end|>", "")
-        time.sleep(0.01)
-        yield buffer, buffer
+    # Nanonets model supports streaming
+    if model_name == "Nanonets-OCR2-3B":
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer, buffer
+
+    # Dots.OCR does not use the streamer in the same way, generate full response
+    elif model_name == "Dots.OCR":
+        generation_kwargs = {
+            **inputs,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        generated_ids = model.generate(**generation_kwargs)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+
+        output_text = output_text.replace("<|im_end|>", "").strip()
+        yield output_text, output_text
 
 
 # Define examples for image inference
@@ -237,7 +261,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
         markdown_output = gr.Markdown(label="(Result.Md)")
 
         model_choice = gr.Radio(
-            choices=["Nanonets-OCR2-3B", "Qwen2-VL-OCR-2B"],
+            choices=["Nanonets-OCR2-3B", "Dots.OCR"],
             label="Select Model",
             value="Nanonets-OCR2-3B"
        )
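
For reference, here is a minimal standalone sketch of the Dots.OCR path this commit adds to app.py. The model ID, dtype, input packing (return_tensors="pt", padding=True), and the trim/decode steps mirror the hunks above; the apply_chat_template call, the sample image path, the prompt text, and the token budget are assumptions not shown in this diff, and attn_implementation="flash_attention_2" is commented out so the sketch runs without flash-attn installed.

# Sketch only: mirrors the Dots.OCR branch added to app.py in this commit.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_ID_D = "rednote-hilab/dots.ocr"
processor_d = AutoProcessor.from_pretrained(MODEL_ID_D, trust_remote_code=True)
model_d = AutoModelForCausalLM.from_pretrained(
    MODEL_ID_D,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",  # used in the commit; requires flash-attn
).to(device).eval()

image = Image.open("sample_page.png")  # placeholder image, not from the commit
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Extract the text from this page."},  # placeholder prompt
    ],
}]
# The prompt-building lines sit outside this diff's hunks; apply_chat_template is assumed here.
prompt = processor_d.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor_d(text=[prompt], images=[image], return_tensors="pt", padding=True).to(device)

# Non-streaming generation, as in the new Dots.OCR branch of generate_image().
with torch.no_grad():
    generated_ids = model_d.generate(**inputs, max_new_tokens=1024, do_sample=False)

# Drop the prompt tokens and decode only the newly generated text.
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
print(processor_d.batch_decode(trimmed, skip_special_tokens=True)[0].strip())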