prithivMLmods committed on
Commit
d50ecd0
·
verified ·
1 Parent(s): 238ed44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -18
app.py CHANGED
@@ -132,15 +132,23 @@ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
132
  torch_dtype=torch.float16
133
  ).to(device).eval()
134
 
135
- # Load PaddleOCR-VL [PaddlePaddle/PaddleOCR-VL]
136
- MODEL_ID_P = "strangervisionhf/paddle"
 
137
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
138
  model_p = AutoModelForCausalLM.from_pretrained(
139
  MODEL_ID_P,
140
  trust_remote_code=True,
141
- torch_dtype=torch.float16
142
  ).to(device).eval()
143
 
 
 
 
 
 
 
 
144
 
145
  @spaces.GPU
146
  def generate_image(model_name: str, text: str, image: Image.Image,
@@ -153,12 +161,11 @@ def generate_image(model_name: str, text: str, image: Image.Image,
153
  if image is None:
154
  yield "Please upload an image.", "Please upload an image."
155
  return
156
-
157
  if model_name == "Nanonets-OCR2-3B":
158
  processor = processor_v
159
  model = model_v
160
 
161
- # Nanonets/Qwen-VL format: content is a list of dicts
162
  messages = [{
163
  "role": "user",
164
  "content": [
@@ -173,7 +180,7 @@ def generate_image(model_name: str, text: str, image: Image.Image,
173
  images=[image],
174
  return_tensors="pt",
175
  padding=True).to(device)
176
-
177
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
178
  generation_kwargs = {
179
  **inputs,
@@ -193,30 +200,31 @@ def generate_image(model_name: str, text: str, image: Image.Image,
193
  buffer = buffer.replace("<|im_end|>", "")
194
  time.sleep(0.01)
195
  yield buffer, buffer
196
-
197
  elif model_name == "PaddleOCR-VL":
198
  processor = processor_p
199
  model = model_p
200
 
201
- # PaddleOCR-VL format: content is a simple string
 
 
202
  messages = [{"role": "user", "content": text}]
203
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
204
-
205
  inputs = processor(text=[prompt_full], images=[image], return_tensors="pt").to(device)
206
 
207
- # Use generation parameters from the reference script for best results
208
  generation_kwargs = {
209
  **inputs,
210
  "max_new_tokens": max_new_tokens,
211
- "do_sample": False,
212
  "use_cache": True,
213
  }
214
-
215
  with torch.inference_mode():
216
  generated_ids = model.generate(**generation_kwargs)
217
-
218
  resp = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
219
- # Extract only the generated part of the response
220
  answer = resp.split(prompt_full)[-1].strip()
221
  yield answer, answer
222
 
@@ -224,12 +232,11 @@ def generate_image(model_name: str, text: str, image: Image.Image,
224
  yield "Invalid model selected.", "Invalid model selected."
225
  return
226
 
227
-
228
- # Define examples for image inference, tailored for both models
229
  image_examples = [
230
  ["OCR:", "images/ocr.png"],
231
  ["Table Recognition:", "images/4.png"],
232
- ["Extract the content from this image.", "images/0.png"]
233
  ]
234
 
235
 
@@ -238,7 +245,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
238
  gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
239
  with gr.Row():
240
  with gr.Column(scale=2):
241
- image_query = gr.Textbox(label="Query Input", placeholder="Enter query or task (e.g., 'OCR:')")
242
  image_upload = gr.Image(type="pil", label="Upload Image", height=290)
243
 
244
  image_submit = gr.Button("Submit", variant="primary")
 
132
  torch_dtype=torch.float16
133
  ).to(device).eval()
134
 
135
+ # Load PaddleOCR-VL
136
+ # Using the corrected model path from your previous attempt
137
+ MODEL_ID_P = "strangervisionhf/paddle"
138
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
139
  model_p = AutoModelForCausalLM.from_pretrained(
140
  MODEL_ID_P,
141
  trust_remote_code=True,
142
+ torch_dtype=torch.float16,
143
  ).to(device).eval()
144
 
145
+ # --- Task Prompts for PaddleOCR-VL ---
146
+ PROMPTS = {
147
+ "ocr": "OCR:",
148
+ "table": "Table Recognition:",
149
+ "chart": "Chart Recognition:",
150
+ "formula": "Formula Recognition:",
151
+ }
152
 
153
  @spaces.GPU
154
  def generate_image(model_name: str, text: str, image: Image.Image,
 
161
  if image is None:
162
  yield "Please upload an image.", "Please upload an image."
163
  return
164
+
165
  if model_name == "Nanonets-OCR2-3B":
166
  processor = processor_v
167
  model = model_v
168
 
 
169
  messages = [{
170
  "role": "user",
171
  "content": [
 
180
  images=[image],
181
  return_tensors="pt",
182
  padding=True).to(device)
183
+
184
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
185
  generation_kwargs = {
186
  **inputs,
 
200
  buffer = buffer.replace("<|im_end|>", "")
201
  time.sleep(0.01)
202
  yield buffer, buffer
203
+
204
  elif model_name == "PaddleOCR-VL":
205
  processor = processor_p
206
  model = model_p
207
 
208
+ # --- CORRECTED LOGIC FOR PADDLEOCR-VL ---
209
+ # It expects a simple string content, not a list of dicts.
210
+ # The user's input `text` should be one of the specific prompts.
211
  messages = [{"role": "user", "content": text}]
212
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
213
+
214
  inputs = processor(text=[prompt_full], images=[image], return_tensors="pt").to(device)
215
 
 
216
  generation_kwargs = {
217
  **inputs,
218
  "max_new_tokens": max_new_tokens,
219
+ "do_sample": False, # As per the reference script for best results
220
  "use_cache": True,
221
  }
222
+
223
  with torch.inference_mode():
224
  generated_ids = model.generate(**generation_kwargs)
225
+
226
  resp = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
227
+ # Extract only the model's answer, excluding the prompt
228
  answer = resp.split(prompt_full)[-1].strip()
229
  yield answer, answer
230
 
 
232
  yield "Invalid model selected.", "Invalid model selected."
233
  return
234
 
235
+ # Define examples for image inference, updated for both models
 
236
  image_examples = [
237
  ["OCR:", "images/ocr.png"],
238
  ["Table Recognition:", "images/4.png"],
239
+ ["Extract the content of this invoice.", "images/0.png"]
240
  ]
241
 
242
 
 
245
  gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
246
  with gr.Row():
247
  with gr.Column(scale=2):
248
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter query. For PaddleOCR, use 'OCR:', 'Table Recognition:', etc.")
249
  image_upload = gr.Image(type="pil", label="Upload Image", height=290)
250
 
251
  image_submit = gr.Button("Submit", variant="primary")