prithivMLmods committed on
Commit
5f43587
·
verified ·
1 Parent(s): 5780205

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -45
app.py CHANGED
@@ -15,8 +15,7 @@ import cv2
15
 
16
  from transformers import (
17
  Qwen2_5_VLForConditionalGeneration,
18
- AutoModelForImageTextToText,
19
- AutoModelForCausalLM,# Added for PaddleOCR-VL
20
  AutoProcessor,
21
  TextIteratorStreamer,
22
  )
@@ -133,13 +132,13 @@ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
133
  torch_dtype=torch.float16
134
  ).to(device).eval()
135
 
136
- # Load PaddleOCR-VL [PaddlePaddle/PaddleOCR-VL]
137
  MODEL_ID_P = "strangervisionhf/paddle"
138
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
139
  model_p = AutoModelForCausalLM.from_pretrained(
140
  MODEL_ID_P,
141
  trust_remote_code=True,
142
- torch_dtype=torch.float16,
143
  ).to(device).eval()
144
 
145
 
@@ -158,30 +157,22 @@ def generate_image(model_name: str, text: str, image: Image.Image,
158
  if model_name == "Nanonets-OCR2-3B":
159
  processor = processor_v
160
  model = model_v
161
- elif model_name == "PaddleOCR-VL":
162
- processor = processor_p
163
- model = model_p
164
- else:
165
- yield "Invalid model selected.", "Invalid model selected."
166
- return
167
 
168
- messages = [{
169
- "role": "user",
170
- "content": [
171
- {"type": "image"},
172
- {"type": "text", "text": text},
173
- ]
174
- }]
175
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
176
 
177
- inputs = processor(
178
- text=[prompt_full],
179
- images=[image],
180
- return_tensors="pt",
181
- padding=True).to(device)
182
-
183
- # Nanonets model supports streaming
184
- if model_name == "Nanonets-OCR2-3B":
185
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
186
  generation_kwargs = {
187
  **inputs,
@@ -202,34 +193,45 @@ def generate_image(model_name: str, text: str, image: Image.Image,
202
  time.sleep(0.01)
203
  yield buffer, buffer
204
 
205
- # PaddleOCR-VL does not use a streamer, generate full response
206
  elif model_name == "PaddleOCR-VL":
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  generation_kwargs = {
208
  **inputs,
209
  "max_new_tokens": max_new_tokens,
210
- "do_sample": True,
211
- "temperature": temperature,
212
- "top_p": top_p,
213
- "top_k": top_k,
214
- "repetition_penalty": repetition_penalty,
215
  }
216
- generated_ids = model.generate(**generation_kwargs)
217
- generated_ids_trimmed = [
218
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
219
- ]
220
- output_text = processor.batch_decode(
221
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
222
- )[0]
223
 
224
- output_text = output_text.replace("<|im_end|>", "").strip()
225
- yield output_text, output_text
 
 
 
 
 
 
 
 
 
226
 
227
 
228
  # Define examples for image inference
229
  image_examples = [
230
  ["Extract the full page.", "images/ocr.png"],
231
- ["Extract the content.", "images/4.png"],
232
- ["Convert this page to doc [table] precisely for markdown.", "images/0.png"]
233
  ]
234
 
235
 
@@ -238,7 +240,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
238
  gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
239
  with gr.Row():
240
  with gr.Column(scale=2):
241
- image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
242
  image_upload = gr.Image(type="pil", label="Upload Image", height=290)
243
 
244
  image_submit = gr.Button("Submit", variant="primary")
@@ -256,7 +258,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
256
 
257
  with gr.Column(scale=3):
258
  gr.Markdown("## Output", elem_id="output-title")
259
- output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
260
  with gr.Accordion("(Result.md)", open=False):
261
  markdown_output = gr.Markdown(label="(Result.Md)")
262
 
 
15
 
16
  from transformers import (
17
  Qwen2_5_VLForConditionalGeneration,
18
+ AutoModelForCausalLM, # Added for PaddleOCR-VL
 
19
  AutoProcessor,
20
  TextIteratorStreamer,
21
  )
 
132
  torch_dtype=torch.float16
133
  ).to(device).eval()
134
 
135
+ # Load PaddleOCR-VL
136
  MODEL_ID_P = "strangervisionhf/paddle"
137
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
138
  model_p = AutoModelForCausalLM.from_pretrained(
139
  MODEL_ID_P,
140
  trust_remote_code=True,
141
+ torch_dtype=torch.float16
142
  ).to(device).eval()
143
 
144
 
 
157
  if model_name == "Nanonets-OCR2-3B":
158
  processor = processor_v
159
  model = model_v
 
 
 
 
 
 
160
 
161
+ messages = [{
162
+ "role": "user",
163
+ "content": [
164
+ {"type": "image"},
165
+ {"type": "text", "text": text},
166
+ ]
167
+ }]
168
+ prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
169
 
170
+ inputs = processor(
171
+ text=[prompt_full],
172
+ images=[image],
173
+ return_tensors="pt",
174
+ padding=True).to(device)
175
+
 
 
176
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
177
  generation_kwargs = {
178
  **inputs,
 
193
  time.sleep(0.01)
194
  yield buffer, buffer
195
 
 
196
  elif model_name == "PaddleOCR-VL":
197
+ processor = processor_p
198
+ model = model_p
199
+
200
+ # FIX: PaddleOCR-VL expects a simple string content, not a list of dicts.
201
+ messages = [{"role": "user", "content": text}]
202
+ prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
203
+
204
+ inputs = processor(
205
+ text=[prompt_full],
206
+ images=[image],
207
+ return_tensors="pt"
208
+ ).to(device)
209
+
210
  generation_kwargs = {
211
  **inputs,
212
  "max_new_tokens": max_new_tokens,
213
+ "do_sample": False,
214
+ "use_cache": True,
 
 
 
215
  }
 
 
 
 
 
 
 
216
 
217
+ with torch.inference_mode():
218
+ generated_ids = model.generate(**generation_kwargs)
219
+
220
+ resp = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
221
+ # Clean the output by removing the prompt
222
+ answer = resp.split(prompt_full)[-1].strip()
223
+ yield answer, answer
224
+
225
+ else:
226
+ yield "Invalid model selected.", "Invalid model selected."
227
+ return
228
 
229
 
230
  # Define examples for image inference
231
  image_examples = [
232
  ["Extract the full page.", "images/ocr.png"],
233
+ ["OCR:", "images/4.png"], # Example prompt for PaddleOCR-VL
234
+ ["Table Recognition:", "images/0.png"] # Example prompt for PaddleOCR-VL
235
  ]
236
 
237
 
 
240
  gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
241
  with gr.Row():
242
  with gr.Column(scale=2):
243
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here... (e.g., 'OCR:')")
244
  image_upload = gr.Image(type="pil", label="Upload Image", height=290)
245
 
246
  image_submit = gr.Button("Submit", variant="primary")
 
258
 
259
  with gr.Column(scale=3):
260
  gr.Markdown("## Output", elem_id="output-title")
261
+ output = gr.Textbox(label="Raw Output", interactive=False, lines=11, show_copy_button=True)
262
  with gr.Accordion("(Result.md)", open=False):
263
  markdown_output = gr.Markdown(label="(Result.Md)")
264