prithivMLmods committed on
Commit
44691c0
·
verified ·
1 Parent(s): b4d0ff9

update app

Browse files
Files changed (1) hide show
  1. app.py +56 -128
app.py CHANGED
@@ -15,7 +15,7 @@ import tempfile
15
  import gradio as gr
16
  import requests
17
  import torch
18
- from PIL import Image, ImageDraw
19
  import fitz
20
  import numpy as np
21
 
@@ -130,7 +130,7 @@ def generate_and_preview_pdf(image: Image.Image, text_content: str, font_size: i
130
  def process_document_stream(
131
  image: Image.Image,
132
  prompt_input: str,
133
- image_scale_factor: float,
134
  max_new_tokens: int,
135
  temperature: float,
136
  top_p: float,
@@ -138,7 +138,7 @@ def process_document_stream(
138
  repetition_penalty: float
139
  ):
140
  """
141
- Main function for standard OCR, handles model inference using tencent/POINTS-Reader.
142
  """
143
  if image is None:
144
  yield "Please upload an image.", ""
@@ -147,135 +147,66 @@ def process_document_stream(
147
  yield "Please enter a prompt.", ""
148
  return
149
 
 
150
  if image_scale_factor > 1.0:
151
  try:
152
  original_width, original_height = image.size
153
  new_width = int(original_width * image_scale_factor)
154
  new_height = int(original_height * image_scale_factor)
 
 
155
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
156
  except Exception as e:
157
  print(f"Error during image scaling: {e}")
 
 
 
158
 
159
  temp_image_path = None
160
  try:
 
 
161
  temp_dir = tempfile.gettempdir()
162
  temp_image_path = os.path.join(temp_dir, f"temp_image_{uuid.uuid4()}.png")
163
  image.save(temp_image_path)
164
 
 
165
  content = [
166
  dict(type='image', image=temp_image_path),
167
  dict(type='text', text=prompt_input)
168
  ]
169
- messages = [{'role': 'user', 'content': content}]
 
 
 
 
 
170
 
 
171
  generation_config = {
172
- 'max_new_tokens': max_new_tokens, 'repetition_penalty': repetition_penalty,
173
- 'temperature': temperature, 'top_p': top_p, 'top_k': top_k,
 
 
 
174
  'do_sample': True if temperature > 0 else False
175
  }
176
 
177
- response = model.chat(messages, tokenizer, image_processor, generation_config)
 
 
 
 
 
 
 
178
  yield response, response
179
 
180
  except Exception as e:
181
  traceback.print_exc()
182
  yield f"An error occurred during processing: {str(e)}", ""
183
  finally:
184
- if temp_image_path and os.path.exists(temp_image_path):
185
- os.remove(temp_image_path)
186
-
187
- @spaces.GPU
188
- def extract_text_with_boxes(
189
- image: Image.Image,
190
- image_scale_factor: float,
191
- max_new_tokens: int,
192
- temperature: float,
193
- top_p: float,
194
- top_k: int,
195
- repetition_penalty: float
196
- ):
197
- """
198
- Processes an image to extract text and bounding boxes, returning the processed text and a visualization.
199
- """
200
- if image is None:
201
- raise gr.Error("Please upload an image first.")
202
-
203
- original_image = image.copy() # Keep a copy of the original for visualization
204
- prompt_for_boxes = "Perform OCR on the image. For each detected line of text, provide its bounding box in the format <box>x_min,y_min,x_max,y_max</box> followed by the text."
205
-
206
- if image_scale_factor > 1.0:
207
- try:
208
- original_width, original_height = image.size
209
- new_width = int(original_width * image_scale_factor)
210
- new_height = int(original_height * image_scale_factor)
211
- image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
212
- except Exception as e:
213
- print(f"Error during image scaling: {e}")
214
-
215
- temp_image_path = None
216
- try:
217
- temp_dir = tempfile.gettempdir()
218
- temp_image_path = os.path.join(temp_dir, f"temp_image_{uuid.uuid4()}.png")
219
- image.save(temp_image_path)
220
-
221
- content = [
222
- dict(type='image', image=temp_image_path),
223
- dict(type='text', text=prompt_for_boxes)
224
- ]
225
- messages = [{'role': 'user', 'content': content}]
226
-
227
- generation_config = {
228
- 'max_new_tokens': max_new_tokens, 'repetition_penalty': repetition_penalty,
229
- 'temperature': temperature, 'top_p': top_p, 'top_k': top_k,
230
- 'do_sample': True if temperature > 0 else False
231
- }
232
-
233
- response = model.chat(messages, tokenizer, image_processor, generation_config)
234
-
235
- # Post-process to extract boxes and draw them
236
- original_width, original_height = original_image.size
237
- # The model's coordinates are normalized to a 1000x1000 canvas
238
- scale_width = original_width / 1000.0
239
- scale_height = original_height / 1000.0
240
-
241
- pattern = r"<box>(\d+,\d+,\d+,\d+)</box>\s*(.*?)\s*(?=<box>|$)"
242
- matches = re.findall(pattern, response, re.DOTALL)
243
-
244
- formatted_output = []
245
- vis_image = original_image.copy()
246
- draw = ImageDraw.Draw(vis_image)
247
-
248
- for box_str, text in matches:
249
- text = text.strip()
250
- if not text:
251
- continue
252
-
253
- try:
254
- coords = [int(c.strip()) for c in box_str.split(',')]
255
- x0, y0, x1, y1 = coords
256
-
257
- if x0 >= x1 or y0 >= y1:
258
- continue
259
-
260
- scaled_poly = [
261
- int(x0 * scale_width), int(y0 * scale_height),
262
- int(x1 * scale_width), int(y0 * scale_height),
263
- int(x1 * scale_width), int(y1 * scale_height),
264
- int(x0 * scale_width), int(y1 * scale_height)
265
- ]
266
- draw.polygon(scaled_poly, outline="red", width=3)
267
-
268
- formatted_line = f"{','.join(map(str, scaled_poly))},{text}"
269
- formatted_output.append(formatted_line)
270
- except Exception:
271
- continue
272
-
273
- return "\n".join(formatted_output), vis_image
274
-
275
- except Exception as e:
276
- traceback.print_exc()
277
- return f"An error occurred: {str(e)}", None
278
- finally:
279
  if temp_image_path and os.path.exists(temp_image_path):
280
  os.remove(temp_image_path)
281
 
@@ -303,18 +234,28 @@ def create_gradio_interface():
303
  # Left Column (Inputs)
304
  with gr.Column(scale=1):
305
  gr.Textbox(
306
- label="Model in Use ⚡", value="tencent/POINTS-Reader", interactive=False
 
 
307
  )
308
  prompt_input = gr.Textbox(
309
- label="Query Input", placeholder="✦︎ Enter the prompt", value="Perform OCR on the image precisely."
 
 
310
  )
311
  image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
312
 
313
  with gr.Accordion("Advanced Settings", open=False):
 
314
  image_scale_factor = gr.Slider(
315
- minimum=1.0, maximum=3.0, value=1.0, step=0.1, label="Image Upscale Factor",
 
 
 
 
316
  info="Increases image size before processing. Can improve OCR on small text. Default: 1.0 (no change)."
317
  )
 
318
  max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=2048, step=256, label="Max New Tokens")
319
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.05, value=0.7)
320
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.8)
@@ -334,10 +275,14 @@ def create_gradio_interface():
334
  with gr.Column(scale=2):
335
  with gr.Tabs() as tabs:
336
  with gr.Tab("📝 Extracted Content"):
337
- raw_output_stream = gr.Textbox(label="Raw Model Output (max T ≤ 120s)", interactive=False, lines=20, show_copy_button=True)
338
  with gr.Row():
339
  examples = gr.Examples(
340
- examples=["examples/1.jpeg", "examples/2.jpeg", "examples/3.jpeg", "examples/4.jpeg", "examples/5.jpeg"],
 
 
 
 
341
  inputs=image_input, label="Examples"
342
  )
343
  gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/POINTS-Reader-OCR/discussions) | [prithivMLmods🤗](https://huggingface.co/prithivMLmods)")
@@ -346,38 +291,21 @@ def create_gradio_interface():
346
  with gr.Accordion("(Result.md)", open=True):
347
  markdown_output = gr.Markdown()
348
 
349
- # --- NEW TAB FOR BOUNDING BOXES ---
350
- with gr.Tab("🖼️ Bounding Boxes"):
351
- ocr_button = gr.Button("Extract Text with Coordinates", variant="primary")
352
- with gr.Row():
353
- ocr_text = gr.Textbox(
354
- label="Extracted Text with Polygon Coordinates", lines=15, show_copy_button=True, scale=1
355
- )
356
- ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)", scale=2)
357
- # --- END NEW TAB ---
358
-
359
  with gr.Tab("📋 PDF Preview"):
360
  generate_pdf_btn = gr.Button("📄 Generate PDF & Render", variant="primary")
361
  pdf_output_file = gr.File(label="Download Generated PDF", interactive=False)
362
  pdf_preview_gallery = gr.Gallery(label="PDF Page Preview", show_label=True, elem_id="gallery", columns=2, object_fit="contain", height="auto")
363
 
364
  # Event Handlers
365
- advanced_settings = [image_scale_factor, max_new_tokens, temperature, top_p, top_k, repetition_penalty]
366
-
367
  def clear_all_outputs():
368
- return None, "", "Raw output will appear here.", "", None, None, "", None
369
 
370
  process_btn.click(
371
  fn=process_document_stream,
372
- inputs=[image_input, prompt_input] + advanced_settings,
 
373
  outputs=[raw_output_stream, markdown_output]
374
  )
375
-
376
- ocr_button.click(
377
- fn=extract_text_with_boxes,
378
- inputs=[image_input] + advanced_settings,
379
- outputs=[ocr_text, ocr_vis]
380
- )
381
 
382
  generate_pdf_btn.click(
383
  fn=generate_and_preview_pdf,
@@ -387,7 +315,7 @@ def create_gradio_interface():
387
 
388
  clear_btn.click(
389
  clear_all_outputs,
390
- outputs=[image_input, prompt_input, raw_output_stream, markdown_output, pdf_output_file, pdf_preview_gallery, ocr_text, ocr_vis]
391
  )
392
  return demo
393
 
 
15
  import gradio as gr
16
  import requests
17
  import torch
18
+ from PIL import Image
19
  import fitz
20
  import numpy as np
21
 
 
130
  def process_document_stream(
131
  image: Image.Image,
132
  prompt_input: str,
133
+ image_scale_factor: float, # New parameter for image scaling
134
  max_new_tokens: int,
135
  temperature: float,
136
  top_p: float,
 
138
  repetition_penalty: float
139
  ):
140
  """
141
+ Main function that handles model inference using tencent/POINTS-Reader.
142
  """
143
  if image is None:
144
  yield "Please upload an image.", ""
 
147
  yield "Please enter a prompt.", ""
148
  return
149
 
150
+ # --- IMPLEMENTATION: Image Scaling based on user input ---
151
  if image_scale_factor > 1.0:
152
  try:
153
  original_width, original_height = image.size
154
  new_width = int(original_width * image_scale_factor)
155
  new_height = int(original_height * image_scale_factor)
156
+ print(f"Scaling image from {image.size} to ({new_width}, {new_height}) with factor {image_scale_factor}.")
157
+ # Use a high-quality resampling filter for better results
158
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
159
  except Exception as e:
160
  print(f"Error during image scaling: {e}")
161
+ # Continue with the original image if scaling fails
162
+ pass
163
+ # --- END IMPLEMENTATION ---
164
 
165
  temp_image_path = None
166
  try:
167
+ # --- FIX: Save the PIL Image to a temporary file ---
168
+ # The model expects a file path, not a PIL object.
169
  temp_dir = tempfile.gettempdir()
170
  temp_image_path = os.path.join(temp_dir, f"temp_image_{uuid.uuid4()}.png")
171
  image.save(temp_image_path)
172
 
173
+ # Prepare content for the model using the temporary file path
174
  content = [
175
  dict(type='image', image=temp_image_path),
176
  dict(type='text', text=prompt_input)
177
  ]
178
+ messages = [
179
+ {
180
+ 'role': 'user',
181
+ 'content': content
182
+ }
183
+ ]
184
 
185
+ # Prepare generation configuration from UI inputs
186
  generation_config = {
187
+ 'max_new_tokens': max_new_tokens,
188
+ 'repetition_penalty': repetition_penalty,
189
+ 'temperature': temperature,
190
+ 'top_p': top_p,
191
+ 'top_k': top_k,
192
  'do_sample': True if temperature > 0 else False
193
  }
194
 
195
+ # Run inference
196
+ response = model.chat(
197
+ messages,
198
+ tokenizer,
199
+ image_processor,
200
+ generation_config
201
+ )
202
+ # Yield the full response at once
203
  yield response, response
204
 
205
  except Exception as e:
206
  traceback.print_exc()
207
  yield f"An error occurred during processing: {str(e)}", ""
208
  finally:
209
+ # --- Clean up the temporary image file ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  if temp_image_path and os.path.exists(temp_image_path):
211
  os.remove(temp_image_path)
212
 
 
234
  # Left Column (Inputs)
235
  with gr.Column(scale=1):
236
  gr.Textbox(
237
+ label="Model in Use ⚡",
238
+ value="tencent/POINTS-Reader",
239
+ interactive=False
240
  )
241
  prompt_input = gr.Textbox(
242
+ label="Query Input",
243
+ placeholder="✦︎ Enter the prompt",
244
+ value="Perform OCR on the image precisely.",
245
  )
246
  image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
247
 
248
  with gr.Accordion("Advanced Settings", open=False):
249
+ # --- NEW UI ELEMENT: Image Scaling Slider ---
250
  image_scale_factor = gr.Slider(
251
+ minimum=1.0,
252
+ maximum=3.0,
253
+ value=1.0,
254
+ step=0.1,
255
+ label="Image Upscale Factor",
256
  info="Increases image size before processing. Can improve OCR on small text. Default: 1.0 (no change)."
257
  )
258
+ # --- END NEW UI ELEMENT ---
259
  max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=2048, step=256, label="Max New Tokens")
260
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.05, value=0.7)
261
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.8)
 
275
  with gr.Column(scale=2):
276
  with gr.Tabs() as tabs:
277
  with gr.Tab("📝 Extracted Content"):
278
+ raw_output_stream = gr.Textbox(label="Raw Model Output (max T ≤ 120s)", interactive=False, lines=15, show_copy_button=True)
279
  with gr.Row():
280
  examples = gr.Examples(
281
+ examples=["examples/1.jpeg",
282
+ "examples/2.jpeg",
283
+ "examples/3.jpeg",
284
+ "examples/4.jpeg",
285
+ "examples/5.jpeg"],
286
  inputs=image_input, label="Examples"
287
  )
288
  gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/POINTS-Reader-OCR/discussions) | [prithivMLmods🤗](https://huggingface.co/prithivMLmods)")
 
291
  with gr.Accordion("(Result.md)", open=True):
292
  markdown_output = gr.Markdown()
293
 
 
 
 
 
 
 
 
 
 
 
294
  with gr.Tab("📋 PDF Preview"):
295
  generate_pdf_btn = gr.Button("📄 Generate PDF & Render", variant="primary")
296
  pdf_output_file = gr.File(label="Download Generated PDF", interactive=False)
297
  pdf_preview_gallery = gr.Gallery(label="PDF Page Preview", show_label=True, elem_id="gallery", columns=2, object_fit="contain", height="auto")
298
 
299
  # Event Handlers
 
 
300
  def clear_all_outputs():
301
+ return None, "", "Raw output will appear here.", "", None, None
302
 
303
  process_btn.click(
304
  fn=process_document_stream,
305
+ # --- UPDATE: Add the new slider to the inputs list ---
306
+ inputs=[image_input, prompt_input, image_scale_factor, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
307
  outputs=[raw_output_stream, markdown_output]
308
  )
 
 
 
 
 
 
309
 
310
  generate_pdf_btn.click(
311
  fn=generate_and_preview_pdf,
 
315
 
316
  clear_btn.click(
317
  clear_all_outputs,
318
+ outputs=[image_input, prompt_input, raw_output_stream, markdown_output, pdf_output_file, pdf_preview_gallery]
319
  )
320
  return demo
321