prithivMLmods commited on
Commit
b4d0ff9
·
verified ·
1 Parent(s): 1b38818

update app

Browse files
Files changed (1) hide show
  1. app.py +99 -62
app.py CHANGED
@@ -138,7 +138,7 @@ def process_document_stream(
138
  repetition_penalty: float
139
  ):
140
  """
141
- Main function that handles model inference for general OCR.
142
  """
143
  if image is None:
144
  yield "Please upload an image.", ""
@@ -152,11 +152,9 @@ def process_document_stream(
152
  original_width, original_height = image.size
153
  new_width = int(original_width * image_scale_factor)
154
  new_height = int(original_height * image_scale_factor)
155
- print(f"Scaling image from {image.size} to ({new_width}, {new_height}) with factor {image_scale_factor}.")
156
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
157
  except Exception as e:
158
  print(f"Error during image scaling: {e}")
159
- pass
160
 
161
  temp_image_path = None
162
  try:
@@ -171,11 +169,8 @@ def process_document_stream(
171
  messages = [{'role': 'user', 'content': content}]
172
 
173
  generation_config = {
174
- 'max_new_tokens': max_new_tokens,
175
- 'repetition_penalty': repetition_penalty,
176
- 'temperature': temperature,
177
- 'top_p': top_p,
178
- 'top_k': top_k,
179
  'do_sample': True if temperature > 0 else False
180
  }
181
 
@@ -189,63 +184,93 @@ def process_document_stream(
189
  if temp_image_path and os.path.exists(temp_image_path):
190
  os.remove(temp_image_path)
191
 
192
- # --- Bounding Box Extraction Logic ---
193
  @spaces.GPU
194
- def extract_text_with_coordinates(image: Image.Image):
 
 
 
 
 
 
 
 
195
  """
196
- Runs the model with a specific prompt to get OCR and bounding boxes,
197
- then processes the output to create a visualization.
198
  """
199
  if image is None:
200
- raise gr.Error("Please upload an image first in the main tab.")
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
- prompt = "Please perform OCR on the image and provide the bounding box for each recognized text line. The format should be 'text<box>x1, y1, x2, y2</box>'."
203
  temp_image_path = None
204
  try:
205
  temp_dir = tempfile.gettempdir()
206
  temp_image_path = os.path.join(temp_dir, f"temp_image_{uuid.uuid4()}.png")
207
  image.save(temp_image_path)
208
 
209
- content = [dict(type='image', image=temp_image_path), dict(type='text', text=prompt)]
 
 
 
210
  messages = [{'role': 'user', 'content': content}]
211
- generation_config = {'max_new_tokens': 4096}
 
 
 
 
 
212
 
213
  response = model.chat(messages, tokenizer, image_processor, generation_config)
214
 
215
- original_width, original_height = image.size
 
 
 
 
216
 
217
- # Regex to find coordinates inside <box> tags
218
- pattern_coords = r"<box>(\d+,\s*\d+,\s*\d+,\s*\d+)</box>"
219
- # Regex to split the string by the full box tag to isolate text
220
- pattern_splitter = r"<box>\d+,\s*\d+,\s*\d+,\s*\d+</box>"
221
 
222
- bboxs_raw = re.findall(pattern_coords, response)
223
- lines = [line.strip() for line in re.split(pattern_splitter, response) if line.strip()]
224
-
225
- num_items = min(len(lines), len(bboxs_raw))
226
- vis_image = image.copy()
227
  draw = ImageDraw.Draw(vis_image)
228
- output_text = ""
229
 
230
- for i in range(num_items):
231
- line_text = lines[i]
232
- box_coords = [int(c.strip()) for c in bboxs_raw[i].split(',')]
 
233
 
234
- if len(box_coords) == 4:
235
- x0, y0, x1, y1 = box_coords
236
-
237
- # Scale coordinates from the model's 1000px basis to the original image size
238
- x0_s = int(x0 * original_width / 1000)
239
- y0_s = int(y0 * original_height / 1000)
240
- x1_s = int(x1 * original_width / 1000)
241
- y1_s = int(y1 * original_height / 1000)
 
 
 
 
 
 
242
 
243
- draw.rectangle([x0_s, y0_s, x1_s, y1_s], outline="red", width=2)
244
-
245
- # Format output as a polygon (quadrilateral) and the extracted text
246
- output_text += f"{x0_s},{y0_s},{x1_s},{y0_s},{x1_s},{y1_s},{x0_s},{y1_s},{line_text}\n"
247
 
248
- return output_text.strip(), vis_image
249
 
250
  except Exception as e:
251
  traceback.print_exc()
@@ -277,12 +302,19 @@ def create_gradio_interface():
277
  with gr.Row():
278
  # Left Column (Inputs)
279
  with gr.Column(scale=1):
280
- gr.Textbox(label="Model in Use ⚡", value="tencent/POINTS-Reader", interactive=False)
281
- prompt_input = gr.Textbox(label="Query Input", placeholder="✦︎ Enter the prompt", value="Perform OCR on the image precisely.")
 
 
 
 
282
  image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
283
 
284
  with gr.Accordion("Advanced Settings", open=False):
285
- image_scale_factor = gr.Slider(minimum=1.0, maximum=3.0, value=1.0, step=0.1, label="Image Upscale Factor", info="Increases image size before processing. Can improve OCR on small text.")
 
 
 
286
  max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=2048, step=256, label="Max New Tokens")
287
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.05, value=0.7)
288
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.8)
@@ -302,23 +334,27 @@ def create_gradio_interface():
302
  with gr.Column(scale=2):
303
  with gr.Tabs() as tabs:
304
  with gr.Tab("📝 Extracted Content"):
305
- raw_output_stream = gr.Textbox(label="Raw Model Output (max T ≤ 120s)", interactive=False, lines=15, show_copy_button=True)
306
  with gr.Row():
307
- examples = gr.Examples(examples=["examples/1.jpeg", "examples/2.jpeg", "examples/3.jpeg", "examples/4.jpeg", "examples/5.jpeg"], inputs=image_input, label="Examples")
 
 
 
308
  gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/POINTS-Reader-OCR/discussions) | [prithivMLmods🤗](https://huggingface.co/prithivMLmods)")
309
 
310
  with gr.Tab("📰 README.md"):
311
  with gr.Accordion("(Result.md)", open=True):
312
  markdown_output = gr.Markdown()
313
-
314
- with gr.Tab("Bounding Boxes"):
315
- gr.Markdown("Click the button to extract text and visualize its location on the image. This uses a specialized prompt to get coordinates from the model.")
 
316
  with gr.Row():
317
- with gr.Column(scale=1):
318
- ocr_button = gr.Button("🔍 Extract Text with Coordinates", variant="primary")
319
- ocr_text = gr.Textbox(label="Extracted Text with Coordinates", info="Format: x1,y1,x2,y2,x3,y3,x4,y4,text", lines=15, show_copy_button=True)
320
- with gr.Column(scale=1):
321
- ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)")
322
 
323
  with gr.Tab("📋 PDF Preview"):
324
  generate_pdf_btn = gr.Button("📄 Generate PDF & Render", variant="primary")
@@ -326,22 +362,23 @@ def create_gradio_interface():
326
  pdf_preview_gallery = gr.Gallery(label="PDF Page Preview", show_label=True, elem_id="gallery", columns=2, object_fit="contain", height="auto")
327
 
328
  # Event Handlers
 
 
329
  def clear_all_outputs():
330
- # Clear all input and output fields across all tabs
331
  return None, "", "Raw output will appear here.", "", None, None, "", None
332
 
333
  process_btn.click(
334
  fn=process_document_stream,
335
- inputs=[image_input, prompt_input, image_scale_factor, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
336
  outputs=[raw_output_stream, markdown_output]
337
  )
338
-
339
  ocr_button.click(
340
- fn=extract_text_with_coordinates,
341
- inputs=[image_input],
342
  outputs=[ocr_text, ocr_vis]
343
  )
344
-
345
  generate_pdf_btn.click(
346
  fn=generate_and_preview_pdf,
347
  inputs=[image_input, raw_output_stream, font_size, line_spacing, alignment, image_size],
 
138
  repetition_penalty: float
139
  ):
140
  """
141
+ Main function for standard OCR, handles model inference using tencent/POINTS-Reader.
142
  """
143
  if image is None:
144
  yield "Please upload an image.", ""
 
152
  original_width, original_height = image.size
153
  new_width = int(original_width * image_scale_factor)
154
  new_height = int(original_height * image_scale_factor)
 
155
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
156
  except Exception as e:
157
  print(f"Error during image scaling: {e}")
 
158
 
159
  temp_image_path = None
160
  try:
 
169
  messages = [{'role': 'user', 'content': content}]
170
 
171
  generation_config = {
172
+ 'max_new_tokens': max_new_tokens, 'repetition_penalty': repetition_penalty,
173
+ 'temperature': temperature, 'top_p': top_p, 'top_k': top_k,
 
 
 
174
  'do_sample': True if temperature > 0 else False
175
  }
176
 
 
184
  if temp_image_path and os.path.exists(temp_image_path):
185
  os.remove(temp_image_path)
186
 
 
187
  @spaces.GPU
188
+ def extract_text_with_boxes(
189
+ image: Image.Image,
190
+ image_scale_factor: float,
191
+ max_new_tokens: int,
192
+ temperature: float,
193
+ top_p: float,
194
+ top_k: int,
195
+ repetition_penalty: float
196
+ ):
197
  """
198
+ Processes an image to extract text and bounding boxes, returning the processed text and a visualization.
 
199
  """
200
  if image is None:
201
+ raise gr.Error("Please upload an image first.")
202
+
203
+ original_image = image.copy() # Keep a copy of the original for visualization
204
+ prompt_for_boxes = "Perform OCR on the image. For each detected line of text, provide its bounding box in the format <box>x_min,y_min,x_max,y_max</box> followed by the text."
205
+
206
+ if image_scale_factor > 1.0:
207
+ try:
208
+ original_width, original_height = image.size
209
+ new_width = int(original_width * image_scale_factor)
210
+ new_height = int(original_height * image_scale_factor)
211
+ image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
212
+ except Exception as e:
213
+ print(f"Error during image scaling: {e}")
214
 
 
215
  temp_image_path = None
216
  try:
217
  temp_dir = tempfile.gettempdir()
218
  temp_image_path = os.path.join(temp_dir, f"temp_image_{uuid.uuid4()}.png")
219
  image.save(temp_image_path)
220
 
221
+ content = [
222
+ dict(type='image', image=temp_image_path),
223
+ dict(type='text', text=prompt_for_boxes)
224
+ ]
225
  messages = [{'role': 'user', 'content': content}]
226
+
227
+ generation_config = {
228
+ 'max_new_tokens': max_new_tokens, 'repetition_penalty': repetition_penalty,
229
+ 'temperature': temperature, 'top_p': top_p, 'top_k': top_k,
230
+ 'do_sample': True if temperature > 0 else False
231
+ }
232
 
233
  response = model.chat(messages, tokenizer, image_processor, generation_config)
234
 
235
+ # Post-process to extract boxes and draw them
236
+ original_width, original_height = original_image.size
237
+ # The model's coordinates are normalized to a 1000x1000 canvas
238
+ scale_width = original_width / 1000.0
239
+ scale_height = original_height / 1000.0
240
 
241
+ pattern = r"<box>(\d+,\d+,\d+,\d+)</box>\s*(.*?)\s*(?=<box>|$)"
242
+ matches = re.findall(pattern, response, re.DOTALL)
 
 
243
 
244
+ formatted_output = []
245
+ vis_image = original_image.copy()
 
 
 
246
  draw = ImageDraw.Draw(vis_image)
 
247
 
248
+ for box_str, text in matches:
249
+ text = text.strip()
250
+ if not text:
251
+ continue
252
 
253
+ try:
254
+ coords = [int(c.strip()) for c in box_str.split(',')]
255
+ x0, y0, x1, y1 = coords
256
+
257
+ if x0 >= x1 or y0 >= y1:
258
+ continue
259
+
260
+ scaled_poly = [
261
+ int(x0 * scale_width), int(y0 * scale_height),
262
+ int(x1 * scale_width), int(y0 * scale_height),
263
+ int(x1 * scale_width), int(y1 * scale_height),
264
+ int(x0 * scale_width), int(y1 * scale_height)
265
+ ]
266
+ draw.polygon(scaled_poly, outline="red", width=3)
267
 
268
+ formatted_line = f"{','.join(map(str, scaled_poly))},{text}"
269
+ formatted_output.append(formatted_line)
270
+ except Exception:
271
+ continue
272
 
273
+ return "\n".join(formatted_output), vis_image
274
 
275
  except Exception as e:
276
  traceback.print_exc()
 
302
  with gr.Row():
303
  # Left Column (Inputs)
304
  with gr.Column(scale=1):
305
+ gr.Textbox(
306
+ label="Model in Use ⚡", value="tencent/POINTS-Reader", interactive=False
307
+ )
308
+ prompt_input = gr.Textbox(
309
+ label="Query Input", placeholder="✦︎ Enter the prompt", value="Perform OCR on the image precisely."
310
+ )
311
  image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
312
 
313
  with gr.Accordion("Advanced Settings", open=False):
314
+ image_scale_factor = gr.Slider(
315
+ minimum=1.0, maximum=3.0, value=1.0, step=0.1, label="Image Upscale Factor",
316
+ info="Increases image size before processing. Can improve OCR on small text. Default: 1.0 (no change)."
317
+ )
318
  max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=2048, step=256, label="Max New Tokens")
319
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.05, value=0.7)
320
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.8)
 
334
  with gr.Column(scale=2):
335
  with gr.Tabs() as tabs:
336
  with gr.Tab("📝 Extracted Content"):
337
+ raw_output_stream = gr.Textbox(label="Raw Model Output (max T ≤ 120s)", interactive=False, lines=20, show_copy_button=True)
338
  with gr.Row():
339
+ examples = gr.Examples(
340
+ examples=["examples/1.jpeg", "examples/2.jpeg", "examples/3.jpeg", "examples/4.jpeg", "examples/5.jpeg"],
341
+ inputs=image_input, label="Examples"
342
+ )
343
  gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/POINTS-Reader-OCR/discussions) | [prithivMLmods🤗](https://huggingface.co/prithivMLmods)")
344
 
345
  with gr.Tab("📰 README.md"):
346
  with gr.Accordion("(Result.md)", open=True):
347
  markdown_output = gr.Markdown()
348
+
349
+ # --- NEW TAB FOR BOUNDING BOXES ---
350
+ with gr.Tab("🖼️ Bounding Boxes"):
351
+ ocr_button = gr.Button("Extract Text with Coordinates", variant="primary")
352
  with gr.Row():
353
+ ocr_text = gr.Textbox(
354
+ label="Extracted Text with Polygon Coordinates", lines=15, show_copy_button=True, scale=1
355
+ )
356
+ ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)", scale=2)
357
+ # --- END NEW TAB ---
358
 
359
  with gr.Tab("📋 PDF Preview"):
360
  generate_pdf_btn = gr.Button("📄 Generate PDF & Render", variant="primary")
 
362
  pdf_preview_gallery = gr.Gallery(label="PDF Page Preview", show_label=True, elem_id="gallery", columns=2, object_fit="contain", height="auto")
363
 
364
  # Event Handlers
365
+ advanced_settings = [image_scale_factor, max_new_tokens, temperature, top_p, top_k, repetition_penalty]
366
+
367
  def clear_all_outputs():
 
368
  return None, "", "Raw output will appear here.", "", None, None, "", None
369
 
370
  process_btn.click(
371
  fn=process_document_stream,
372
+ inputs=[image_input, prompt_input] + advanced_settings,
373
  outputs=[raw_output_stream, markdown_output]
374
  )
375
+
376
  ocr_button.click(
377
+ fn=extract_text_with_boxes,
378
+ inputs=[image_input] + advanced_settings,
379
  outputs=[ocr_text, ocr_vis]
380
  )
381
+
382
  generate_pdf_btn.click(
383
  fn=generate_and_preview_pdf,
384
  inputs=[image_input, raw_output_stream, font_size, line_spacing, alignment, image_size],