iammraat committed on
Commit
e408019
·
verified ·
1 Parent(s): adb25fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -32
app.py CHANGED
@@ -314,7 +314,6 @@
314
 
315
 
316
 
317
-
318
  import gradio as gr
319
  import torch
320
  import numpy as np
@@ -323,35 +322,51 @@ from PIL import Image
323
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
324
  from craft_text_detector import Craft
325
 
326
- # --- DEFENSIVE MONKEY PATCH ---
327
  import craft_text_detector.craft_utils as craft_utils_module
328
 
329
  def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h):
330
- # This ensures coordinates are scaled correctly and remain as float32
331
- if not polys: return []
332
  adjusted = []
333
  for poly in polys:
334
- p = np.array(poly).reshape(-1, 2).astype(np.float32)
335
- p[:, 0] *= ratio_w
336
- p[:, 1] *= ratio_h
337
- adjusted.append(p)
 
 
 
 
 
 
 
 
 
 
 
 
338
  return adjusted
339
 
340
  craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
341
- # ------------------------------
342
 
343
  device = "cuda" if torch.cuda.is_available() else "cpu"
344
 
345
  # Load Models
 
346
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
347
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten").to(device).eval()
 
 
348
  craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
349
 
350
  def process_page(image):
351
- if image is None: return None, "No image."
 
352
 
353
- # 1. FORCE RESIZE for coordinate 1:1 mapping
354
- # We resize to 1280 width to match CRAFT's default canvas
355
  base_width = 1280
356
  w_percent = (base_width / float(image.size[0]))
357
  h_size = int((float(image.size[1]) * float(w_percent)))
@@ -360,18 +375,21 @@ def process_page(image):
360
  img_np = np.array(image.convert("RGB"))
361
 
362
  # 2. DETECT
363
- # Because we resized the image to 1280, ratio_w/h will be ~1.0
364
  prediction = craft.detect_text(img_np)
365
  boxes = prediction.get("boxes", [])
366
 
367
- if not boxes: return image, "No text found."
 
368
 
369
- # 3. SORT (Improved line grouping)
370
- # We group items within 30 pixels of each other vertically as a single 'line'
371
  items = []
372
  for box in boxes:
373
  items.append({'cy': np.mean(box[:, 1]), 'cx': np.mean(box[:, 0]), 'box': box})
374
- items.sort(key=lambda x: (int(x['cy'] // 30), x['cx']))
 
 
 
375
 
376
  annotated = img_np.copy()
377
  full_text = []
@@ -380,37 +398,46 @@ def process_page(image):
380
  for item in items:
381
  box = item['box'].astype(np.int32)
382
 
383
- # Draw on image
384
  cv2.polylines(annotated, [box], True, (255, 0, 0), 2)
385
 
386
- # Crop
387
  x_min, y_min = np.min(box, axis=0)
388
  x_max, y_max = np.max(box, axis=0)
389
 
390
- # Clamp to image boundaries
391
  x_min, y_min = max(0, x_min), max(0, y_min)
392
  x_max, y_max = min(img_np.shape[1], x_max), min(img_np.shape[0], y_max)
393
 
394
- if (x_max - x_min) < 5 or (y_max - y_min) < 5: continue
 
395
 
396
- crop = Image.fromarray(img_np[y_min:y_max, x_min:x_max])
397
 
398
  with torch.no_grad():
399
- pixel_values = processor(images=crop, return_tensors="pt").pixel_values.to(device)
400
  out_ids = model.generate(pixel_values)
401
  txt = processor.batch_decode(out_ids, skip_special_tokens=True)[0]
402
- if txt.strip(): full_text.append(txt)
 
403
 
404
  return Image.fromarray(annotated), " ".join(full_text)
405
 
406
- # UI
407
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
408
- gr.Markdown("# 🚀 Final Fix: Full-Page OCR")
 
 
409
  with gr.Row():
410
- input_i = gr.Image(type="pil", label="Input")
411
- output_i = gr.Image(label="Detections (Scale Fixed)")
412
- output_t = gr.Textbox(label="Result", lines=10)
413
- btn = gr.Button("Transcribe", variant="primary")
414
- btn.click(process_page, input_i, [output_i, output_t])
 
 
 
 
415
 
416
- demo.launch()
 
 
314
 
315
 
316
 
 
317
  import gradio as gr
318
  import torch
319
  import numpy as np
 
322
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
323
  from craft_text_detector import Craft
324
 
325
+ # --- THE FORTIFIED MONKEY PATCH ---
326
  import craft_text_detector.craft_utils as craft_utils_module
327
 
328
  def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h):
329
+ if not polys:
330
+ return []
331
  adjusted = []
332
  for poly in polys:
333
+ try:
334
+ # Convert to numpy array first
335
+ p = np.array(poly)
336
+
337
+ # SANITY CHECK: A coordinate pair needs 2 numbers per point.
338
+ # If the total size is less than 2 or not even, it's noise.
339
+ if p.size < 2 or p.size % 2 != 0:
340
+ continue
341
+
342
+ p = p.reshape(-1, 2).astype(np.float32)
343
+ p[:, 0] *= ratio_w
344
+ p[:, 1] *= ratio_h
345
+ adjusted.append(p)
346
+ except Exception:
347
+ # If any mathematical error occurs for a specific noisy box, skip it
348
+ continue
349
  return adjusted
350
 
351
  craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
352
+ # ----------------------------------
353
 
354
  device = "cuda" if torch.cuda.is_available() else "cpu"
355
 
356
  # Load Models
357
+ print("Loading TrOCR...")
358
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
359
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten").to(device).eval()
360
+
361
+ print("Loading CRAFT...")
362
  craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
363
 
364
  def process_page(image):
365
+ if image is None:
366
+ return None, "Please upload an image."
367
 
368
+ # 1. FORCE RESIZE to align coordinate systems (Crucial for the 'tiny box' fix)
369
+ # We use 1280px as the standard width
370
  base_width = 1280
371
  w_percent = (base_width / float(image.size[0]))
372
  h_size = int((float(image.size[1]) * float(w_percent)))
 
375
  img_np = np.array(image.convert("RGB"))
376
 
377
  # 2. DETECT
 
378
  prediction = craft.detect_text(img_np)
379
  boxes = prediction.get("boxes", [])
380
 
381
+ if not boxes:
382
+ return image, "No text detected."
383
 
384
+ # 3. SORT (Line-by-line grouping)
385
+ # This logic groups boxes into lines if they overlap vertically
386
  items = []
387
  for box in boxes:
388
  items.append({'cy': np.mean(box[:, 1]), 'cx': np.mean(box[:, 0]), 'box': box})
389
+
390
+ # Sort by Y (approximate lines) then X (left to right)
391
+ # We increase the grouping factor to 40 to handle handwriting slant
392
+ items.sort(key=lambda x: (int(x['cy'] // 40), x['cx']))
393
 
394
  annotated = img_np.copy()
395
  full_text = []
 
398
  for item in items:
399
  box = item['box'].astype(np.int32)
400
 
401
+ # Draw on display image
402
  cv2.polylines(annotated, [box], True, (255, 0, 0), 2)
403
 
404
+ # Crop coordinates
405
  x_min, y_min = np.min(box, axis=0)
406
  x_max, y_max = np.max(box, axis=0)
407
 
408
+ # Clip to image boundaries
409
  x_min, y_min = max(0, x_min), max(0, y_min)
410
  x_max, y_max = min(img_np.shape[1], x_max), min(img_np.shape[0], y_max)
411
 
412
+ if (x_max - x_min) < 10 or (y_max - y_min) < 10:
413
+ continue
414
 
415
+ crop_pil = Image.fromarray(img_np[y_min:y_max, x_min:x_max])
416
 
417
  with torch.no_grad():
418
+ pixel_values = processor(images=crop_pil, return_tensors="pt").pixel_values.to(device)
419
  out_ids = model.generate(pixel_values)
420
  txt = processor.batch_decode(out_ids, skip_special_tokens=True)[0]
421
+ if txt.strip():
422
+ full_text.append(txt)
423
 
424
  return Image.fromarray(annotated), " ".join(full_text)
425
 
426
+ # --- Gradio UI ---
427
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
428
+ gr.Markdown("# ✍️ Full-Page Handwritten Recognition")
429
+ gr.Markdown("Pipeline: **CRAFT** (Detection) ➡️ **TrOCR** (Recognition)")
430
+
431
  with gr.Row():
432
+ with gr.Column():
433
+ input_img = gr.Image(type="pil", label="Step 1: Upload Image")
434
+ btn = gr.Button("Transcribe Page", variant="primary")
435
+
436
+ with gr.Column():
437
+ output_img = gr.Image(label="Step 2: Review Detections")
438
+ output_txt = gr.Textbox(label="Step 3: Extracted Text", lines=12)
439
+
440
+ btn.click(fn=process_page, inputs=input_img, outputs=[output_img, output_txt])
441
 
442
+ if __name__ == "__main__":
443
+ demo.launch()