prithivMLmods committed
Commit 0e52c4c · verified · 1 Parent(s): 5f43587

Update app.py

Files changed (1):
  1. app.py +21 -21
app.py CHANGED
@@ -133,11 +133,13 @@ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 ).to(device).eval()
 
 # Load PaddleOCR-VL
-MODEL_ID_P = "strangervisionhf/paddle"
-processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
+MODEL_ID_P = "PaddlePaddle/PaddleOCR-VL"
+SUBFOLDER_P = "PaddleOCR-VL-0.9B"
+processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True, subfolder=SUBFOLDER_P)
 model_p = AutoModelForCausalLM.from_pretrained(
     MODEL_ID_P,
     trust_remote_code=True,
+    subfolder=SUBFOLDER_P,
     torch_dtype=torch.float16
 ).to(device).eval()
 
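For reference, the corrected loading path in isolation: a minimal runnable sketch, assuming only that torch and transformers are installed and that the repo layout is as named in the diff (weights under the PaddleOCR-VL-0.9B subfolder of PaddlePaddle/PaddleOCR-VL).

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# The official repo keeps the 0.9B weights in a subfolder, so the
# processor and the model both need subfolder= pointing at it.
MODEL_ID_P = "PaddlePaddle/PaddleOCR-VL"
SUBFOLDER_P = "PaddleOCR-VL-0.9B"

processor_p = AutoProcessor.from_pretrained(
    MODEL_ID_P, trust_remote_code=True, subfolder=SUBFOLDER_P
)
model_p = AutoModelForCausalLM.from_pretrained(
    MODEL_ID_P,
    trust_remote_code=True,
    subfolder=SUBFOLDER_P,
    torch_dtype=torch.float16,
).to(device).eval()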
@@ -157,7 +159,8 @@ def generate_image(model_name: str, text: str, image: Image.Image,
     if model_name == "Nanonets-OCR2-3B":
         processor = processor_v
         model = model_v
-
+
+        # Nanonets/Qwen-VL format: content is a list of dicts
         messages = [{
             "role": "user",
             "content": [
@@ -172,7 +175,7 @@ def generate_image(model_name: str, text: str, image: Image.Image,
             images=[image],
             return_tensors="pt",
             padding=True).to(device)
-
+
         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = {
             **inputs,
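TextIteratorStreamer only emits text while generate() is running, so the app presumably drives it from a worker thread, which is the standard transformers pattern. A sketch of the part this hunk cuts off:

from threading import Thread

# generate() blocks until completion, so it runs in a background
# thread while the caller iterates the streamer for partial text.
generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
Thread(target=model.generate, kwargs=generation_kwargs).start()

buffer = ""
for chunk in streamer:
    buffer += chunk
    yield buffer, buffer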
@@ -196,42 +199,39 @@ def generate_image(model_name: str, text: str, image: Image.Image,
     elif model_name == "PaddleOCR-VL":
         processor = processor_p
         model = model_p
-
-        # FIX: PaddleOCR-VL expects a simple string content, not a list of dicts.
+
+        # PaddleOCR-VL format: content is a simple string
         messages = [{"role": "user", "content": text}]
         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+        inputs = processor(text=[prompt_full], images=[image], return_tensors="pt").to(device)
 
-        inputs = processor(
-            text=[prompt_full],
-            images=[image],
-            return_tensors="pt"
-        ).to(device)
-
+        # Use generation parameters from the reference script for best results
         generation_kwargs = {
             **inputs,
             "max_new_tokens": max_new_tokens,
-            "do_sample": False,
+            "do_sample": False,
             "use_cache": True,
         }
 
         with torch.inference_mode():
             generated_ids = model.generate(**generation_kwargs)
-
+
         resp = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        # Clean the output by removing the prompt
+        # Extract only the generated part of the response
         answer = resp.split(prompt_full)[-1].strip()
         yield answer, answer
-
+
     else:
         yield "Invalid model selected.", "Invalid model selected."
         return
 
 
-# Define examples for image inference
+# Define examples for image inference, tailored for both models
 image_examples = [
-    ["Extract the full page.", "images/ocr.png"],
-    ["OCR:", "images/4.png"],  # Example prompt for PaddleOCR-VL
-    ["Table Recognition:", "images/0.png"]  # Example prompt for PaddleOCR-VL
+    ["OCR:", "images/ocr.png"],
+    ["Table Recognition:", "images/4.png"],
+    ["Extract the content from this image.", "images/0.png"]
 ]
 
 
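Putting the PaddleOCR-VL branch together: a standalone sketch of the non-streaming path, assuming processor_p/model_p loaded as above and a PIL image in image. The token-slice decode at the end is an alternative to the diff's resp.split(prompt_full) cleanup, not what app.py does:

messages = [{"role": "user", "content": "OCR:"}]
prompt_full = processor_p.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = processor_p(text=[prompt_full], images=[image], return_tensors="pt").to(device)

with torch.inference_mode():
    generated_ids = model_p.generate(
        **inputs,
        max_new_tokens=1024,   # app.py takes this value from the UI
        do_sample=False,       # greedy decoding, per the reference script
        use_cache=True,
    )

# Decoding only the tokens after the prompt avoids string-splitting
# on prompt_full, which can break if decoding normalizes whitespace.
new_tokens = generated_ids[:, inputs["input_ids"].shape[1]:]
answer = processor_p.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()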
@@ -240,7 +240,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
     gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
     with gr.Row():
         with gr.Column(scale=2):
-            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here... (e.g., 'OCR:')")
+            image_query = gr.Textbox(label="Query Input", placeholder="Enter query or task (e.g., 'OCR:')")
             image_upload = gr.Image(type="pil", label="Upload Image", height=290)
 
             image_submit = gr.Button("Submit", variant="primary")
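The hunk shows only the input widgets; below is a sketch of how they are presumably wired to generate_image. The model selector and output components are assumptions (the diff does not show them), and generate_image's trailing generation parameters are omitted:

# Hypothetical wiring; names below are illustrative, not from app.py.
model_choice = gr.Radio(
    choices=["Nanonets-OCR2-3B", "PaddleOCR-VL"],
    label="Select Model",
    value="Nanonets-OCR2-3B",
)
raw_output = gr.Textbox(label="Raw Output", lines=10)
fmt_output = gr.Markdown()

# generate_image is a generator, so Gradio streams each yielded
# (raw, formatted) pair into the two outputs as it arrives.
image_submit.click(
    fn=generate_image,
    inputs=[model_choice, image_query, image_upload],
    outputs=[raw_output, fmt_output],
)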
 