iammraat committed on
Commit
d07b368
·
verified ·
1 Parent(s): dd74e1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -83
app.py CHANGED
@@ -311,7 +311,6 @@
311
 
312
 
313
 
314
-
315
  import gradio as gr
316
  import torch
317
  import numpy as np
@@ -320,142 +319,115 @@ from PIL import Image
320
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
321
  from craft_text_detector import Craft
322
 
323
- # --- 1. THE SAFE PATCH ---
324
- # We override the library's internal function with a safe, pure-Python version.
325
- # This prevents the "inhomogeneous shape" crash AND ensures scaling is applied.
326
- import craft_text_detector.craft_utils as craft_utils_module
327
-
328
def safe_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
    """Scale detector polygons back to original-image coordinates.

    Drop-in replacement for craft_text_detector's
    ``adjustResultCoordinates`` that tolerates ragged / malformed polygon
    lists (the upstream version applies ``np.array`` to the whole list and
    crashes with an "inhomogeneous shape" error).

    Args:
        polys: iterable of candidate polygons, each an (N, 2) sequence of
            x/y points. Entries that are None, empty, not 2-D, or have
            fewer than 3 points are silently discarded as noise.
        ratio_w: horizontal resize ratio reported by the detector.
        ratio_h: vertical resize ratio reported by the detector.
        ratio_net: network downsample factor. The upstream implementation
            multiplies coordinates by ``ratio * ratio_net`` (default 2);
            omitting this factor left every box at half size — the "tiny
            boxes" symptom this patch was written to cure.

    Returns:
        list of float32 (N, 2) numpy arrays in full-image coordinates.
    """
    # len() instead of `not polys`: the caller may pass a numpy array,
    # and truth-testing a multi-element array raises ValueError.
    if polys is None or len(polys) == 0:
        return []

    adjusted_polys = []
    for poly in polys:
        # Skip missing or empty entries outright.
        if poly is None or len(poly) == 0:
            continue

        # Validate shape: keep genuine polygons only (N >= 3 points of
        # (x, y) pairs); 1-D lines and dots are detector noise.
        try:
            p = np.array(poly)
            if p.ndim != 2 or p.shape[1] != 2 or p.shape[0] < 3:
                continue
        except Exception:
            continue

        # Map back to original-image space, matching the upstream
        # formula: coord * (ratio * ratio_net).
        p = p.astype(np.float32)
        p[:, 0] *= ratio_w * ratio_net
        p[:, 1] *= ratio_h * ratio_net

        adjusted_polys.append(p)

    return adjusted_polys
357
-
358
# Install the patch: route the library's coordinate adjustment through
# our defensive implementation.
setattr(craft_utils_module, "adjustResultCoordinates", safe_adjustResultCoordinates)
361
 
362
# 2. LOAD MODELS
# Pick the device once; both TrOCR and CRAFT honour it below.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading TrOCR...")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-handwritten')
model = model.to(device).eval()

print("Loading CRAFT...")
# crop_type="box" gives us clean rectangles which are better for OCR than polygons
craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
371
 
372
# 3. HELPER: Sort Boxes
def get_sorted_boxes(boxes):
    """Order detection boxes in natural reading order.

    Boxes are grouped into lines by their vertical centre (50 px
    tolerance) and sorted left-to-right within each line.

    Args:
        boxes: sequence of (N, 2) point arrays (list or numpy array).

    Returns:
        list of the same boxes, top-to-bottom then left-to-right.
    """
    # len() instead of `not boxes`: the detector may return a numpy
    # array, and truth-testing a multi-element array raises ValueError.
    if boxes is None or len(boxes) == 0:
        return []

    items = []
    for box in boxes:
        # The box centroid gives its line (y) and column (x) position.
        cy = np.mean(box[:, 1])
        cx = np.mean(box[:, 0])
        items.append((cy, cx, box))

    # Bucket Y into 50 px bands so slightly tilted lines stay together,
    # then order by X within each band.
    items.sort(key=lambda x: (int(x[0] // 50), x[1]))
    return [x[2] for x in items]
386
 
387
# 4. MAIN PIPELINE
def process_image(image):
    """Detect and transcribe handwritten text in *image*.

    Args:
        image: input PIL image, or None when nothing was uploaded.

    Returns:
        Tuple of (annotated PIL image, recognised text). When there is
        nothing to process, the second element is a status message.
    """
    if image is None:
        return None, "Please upload an image."

    # Work at full resolution so detection coordinates match the
    # displayed image 1:1; the detector handles internal resizing and
    # reports the scaling ratios itself.
    image_np = np.array(image.convert("RGB"))

    prediction = craft.detect_text(image_np)
    boxes = prediction.get("boxes", [])

    # len() instead of `not boxes`: the detector may hand back a numpy
    # array, and truth-testing a multi-element array raises ValueError.
    if boxes is None or len(boxes) == 0:
        return image, "No text detected."

    sorted_boxes = get_sorted_boxes(boxes)
    annotated_img = image_np.copy()
    results = []

    for box in sorted_boxes:
        # Integer coordinates for safe drawing/cropping; np.asarray
        # also accepts plain point lists, not just ndarrays.
        box = np.asarray(box).astype(np.int32)

        # Outline the region (blue in RGB order, thickness 4).
        # NOTE(review): cv2 must be imported earlier in the file — confirm.
        cv2.polylines(annotated_img, [box], True, (255, 0, 0), 4)

        # Axis-aligned crop window, clamped to the image bounds.
        x_min = max(0, np.min(box[:, 0]))
        x_max = min(image_np.shape[1], np.max(box[:, 0]))
        y_min = max(0, np.min(box[:, 1]))
        y_max = min(image_np.shape[0], np.max(box[:, 1]))

        # Noise filter: boxes smaller than 20x10 px (dust specks)
        # produce garbage single-letter output.
        if (x_max - x_min) < 20 or (y_max - y_min) < 10:
            continue

        crop = image_np[y_min:y_max, x_min:x_max]
        if crop.size == 0:
            continue

        pil_crop = Image.fromarray(crop)

        # TrOCR inference on the cropped region.
        with torch.no_grad():
            pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
            generated_ids = model.generate(pixel_values)
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Keep only results longer than one character (drops stray marks).
        if len(text.strip()) > 1:
            results.append(text)

    full_text = "\n".join(results)
    return Image.fromarray(annotated_img), full_text
444
 
445
# 5. UI SETUP
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📝 Handwritten Document OCR")

    with gr.Row():
        # Left column: upload control plus the action button.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Upload Document")
            run_btn = gr.Button("Transcribe", variant="primary")

        # Right column: annotated preview and the transcription box.
        with gr.Column(scale=1):
            output_img = gr.Image(label="Detected Regions")
            output_txt = gr.Textbox(label="Recognized Text", lines=20, show_copy_button=True)

    run_btn.click(fn=process_image, inputs=input_img, outputs=[output_img, output_txt])

if __name__ == "__main__":
    demo.launch()
 
311
 
312
 
313
 
 
314
  import gradio as gr
315
  import torch
316
  import numpy as np
 
319
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
320
  from craft_text_detector import Craft
321
 
322
+ # --- SETUP ---
323
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
 
325
# Model initialisation: TrOCR handles recognition, CRAFT handles detection.
print("Loading TrOCR...")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-handwritten')
model = model.to(device).eval()

print("Loading CRAFT...")
# We use crop_type="box" to get standard rectangles
craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
332
 
 
333
def get_sorted_boxes(boxes):
    """Order detection boxes in natural reading order.

    Boxes are grouped into lines by their vertical centre (40 px
    tolerance) and sorted left-to-right within each line.

    Args:
        boxes: sequence of (N, 2) point arrays (list or numpy array).

    Returns:
        list of the same boxes, top-to-bottom then left-to-right.
    """
    # len() instead of `not boxes`: the detector may return a numpy
    # array, and truth-testing a multi-element array raises ValueError.
    if boxes is None or len(boxes) == 0:
        return []

    items = []
    for box in boxes:
        # The box centroid gives its line (y) and column (x) position.
        cy = np.mean(box[:, 1])
        cx = np.mean(box[:, 0])
        items.append((cy, cx, box))

    # Bucket Y into 40 px bands so slightly tilted lines stay together,
    # then order by X within each band.
    items.sort(key=lambda x: (int(x[0] // 40), x[1]))
    return [x[2] for x in items]
346
 
 
347
def process_image(image):
    """Detect and transcribe handwritten text in *image*.

    Args:
        image: input PIL image, or None when nothing was uploaded.

    Returns:
        Tuple of (annotated PIL image, recognised text). When there is
        nothing to process, the second element is a status message.
    """
    if image is None:
        return None, "Please upload an image."

    # 1. UNIFIED RESIZING: one 1280px-wide image is used for detection,
    # cropping AND display, so coordinates always line up.
    target_width = 1280
    w_percent = target_width / float(image.size[0])
    # max(1, ...) guards against extreme aspect ratios collapsing the
    # height to zero, which PIL's resize would reject.
    h_size = max(1, int(float(image.size[1]) * w_percent))

    # High-quality resize.
    working_image = image.resize((target_width, h_size), Image.Resampling.LANCZOS)

    # Convert once to numpy; this is the ONLY image array used below.
    img_np = np.array(working_image.convert("RGB"))

    # 2. DETECT. CRAFT's default canvas is 1280px, so with a 1280px
    # input the internal scaling ratio is 1.0 and coordinates match.
    prediction = craft.detect_text(img_np)
    boxes = prediction.get("boxes", [])

    # len() instead of `not boxes`: truth-testing a multi-element numpy
    # array raises ValueError.
    if boxes is None or len(boxes) == 0:
        return working_image, "No text detected."

    # 3. PROCESS & RECOGNIZE, box by box in reading order.
    sorted_boxes = get_sorted_boxes(boxes)
    annotated_img = img_np.copy()
    results = []

    for box in sorted_boxes:
        # Integer coordinates for safe drawing/cropping; np.asarray
        # also accepts plain point lists, not just ndarrays.
        box_np = np.asarray(box).astype(np.int32)

        # Draw on the working image (blue in RGB order, thickness 3).
        # NOTE(review): cv2 must be imported earlier in the file — confirm.
        cv2.polylines(annotated_img, [box_np], True, (255, 0, 0), 3)

        # Axis-aligned crop window, clamped to the image bounds.
        x_min = max(0, np.min(box_np[:, 0]))
        x_max = min(img_np.shape[1], np.max(box_np[:, 0]))
        y_min = max(0, np.min(box_np[:, 1]))
        y_max = min(img_np.shape[0], np.max(box_np[:, 1]))

        # Noise filter: skip tiny specks (under 15x10 px).
        if (x_max - x_min) < 15 or (y_max - y_min) < 10:
            continue

        # Crop from the same working image the boxes were detected on.
        crop = img_np[y_min:y_max, x_min:x_max]
        if crop.size == 0:
            continue

        pil_crop = Image.fromarray(crop)

        # TrOCR inference on the cropped region.
        with torch.no_grad():
            pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
            generated_ids = model.generate(pixel_values)
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        if text.strip():
            results.append(text)

    full_text = "\n".join(results)

    # Return the annotated working image alongside the transcription.
    return Image.fromarray(annotated_img), full_text
415
 
416
# --- UI ---
with gr.Blocks(title="Handwritten OCR") as demo:
    gr.Markdown("## 📝 Robust Handwritten OCR")

    with gr.Row():
        # Left column: upload control plus the action button.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Upload Image")
            btn = gr.Button("Extract Text", variant="primary")

        # Right column: the 1280px working image with boxes drawn on it,
        # and the recognised text.
        with gr.Column():
            output_img = gr.Image(label="Detected Regions")
            output_txt = gr.Textbox(label="Result", lines=20)

    btn.click(fn=process_image, inputs=input_img, outputs=[output_img, output_txt])

if __name__ == "__main__":
    demo.launch()