iammraat committed on
Commit
e5d3222
·
verified ·
1 Parent(s): 8414f8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py CHANGED
@@ -566,5 +566,117 @@ demo = gr.Interface(
566
  allow_flagging="never"
567
  )
568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  if __name__ == "__main__":
570
  demo.launch()
 
566
  allow_flagging="never"
567
  )
568
 
569
# app.py (fixed version) — appended rewrite of the app.
# NOTE(review): the committed diff re-ran `if __name__ == "__main__": demo.launch()`
# HERE, before any of the definitions below executed. That launched the *old*
# interface (defined earlier in the file) and blocked, so the fixed app never
# started. The premature launch is removed; the single entry point lives at
# the bottom of the file.
import gradio as gr
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
import numpy as np

# Local YOLO weights: one detector for text regions, one for text lines.
region_model = YOLO("regions.pt")
line_model = YOLO("lines.pt")

# TrOCR handwritten-text recognizer; the processor converts PIL images to
# model pixel values and decodes generated token ids back to text.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# Run recognition on GPU when available; the YOLO models manage their own device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
590
def get_crop(image: Image.Image, result, idx: int, padding: int = 15):
    """Cut detection *idx* out of *image*, preferring the segmentation mask.

    When the YOLO result carries masks, everything outside the mask is
    painted white and the padded bounding box of the mask is returned.
    Otherwise the padded detection box is cropped directly.

    Returns a PIL image, or None when the mask contains no pixels.
    """
    arr = np.array(image)

    if result.masks is None:
        # Box-only fallback: pad the xyxy box, clamp to the image, crop via PIL.
        bx1, by1, bx2, by2 = result.boxes.xyxy[idx].cpu().numpy().astype(int)
        left = max(0, bx1 - padding)
        top = max(0, by1 - padding)
        right = min(image.width, bx2 + padding)
        bottom = min(image.height, by2 + padding)
        return image.crop((left, top, right, bottom))

    # NOTE(review): assumes the mask tensor is at full image resolution so its
    # pixel coordinates line up with `arr` — confirm against the YOLO config.
    keep = result.masks.data[idx].cpu().numpy() > 0.5
    rows, cols = np.where(keep)
    if rows.size == 0:
        return None

    top = max(0, rows.min() - padding)
    bottom = min(arr.shape[0], rows.max() + padding + 1)
    left = max(0, cols.min() - padding)
    right = min(arr.shape[1], cols.max() + padding + 1)

    window = arr[top:bottom, left:right]
    # White out everything the mask excludes; `arr` is a private copy of the
    # input image, so this in-place write has no external effect.
    window[~keep[top:bottom, left:right]] = 255
    return Image.fromarray(window)
625
def process_image(image: Image.Image) -> str:
    """Run full-page handwritten OCR on *image*.

    Pipeline: detect text regions (YOLO) -> sort regions top-to-bottom ->
    detect lines inside each region (YOLO) -> sort lines by (y, x) ->
    recognize each line with TrOCR.

    Returns the recognized text: lines joined with "\n" inside a region,
    regions separated by blank lines; or a short status string when
    nothing is detected/recognized.
    """
    region_result = region_model(image)[0]

    if region_result.boxes is None or len(region_result.boxes) == 0:
        return "No text regions detected."

    # Pair each region crop with its top y-coordinate to establish reading order.
    regions_with_pos = []
    for i in range(len(region_result.boxes)):
        y1 = region_result.boxes.xyxy[i][1].item()
        crop = get_crop(image, region_result, i, padding=20)
        # Fixed: explicit None check — get_crop returns Optional[Image];
        # truthiness testing a PIL Image is the wrong idiom here.
        if crop is not None:
            regions_with_pos.append((y1, crop))

    regions_with_pos.sort(key=lambda item: item[0])

    full_text_parts = []
    for _, region_crop in regions_with_pos:
        line_result = line_model(region_crop)[0]

        if line_result.boxes is None or len(line_result.boxes) == 0:
            continue

        lines_with_pos = []
        for j in range(len(line_result.boxes)):
            rel_y1 = line_result.boxes.xyxy[j][1].item()
            rel_x1 = line_result.boxes.xyxy[j][0].item()
            line_crop = get_crop(region_crop, line_result, j, padding=15)

            if line_crop is None:
                continue

            pixel_values = processor(line_crop, return_tensors="pt").pixel_values.to(device)
            # Fixed: wrap generation in inference_mode so decoding does not
            # build autograd state (lower memory per line).
            with torch.inference_mode():
                generated_ids = model.generate(pixel_values)
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            lines_with_pos.append((rel_y1, rel_x1, text))

        # Top-to-bottom, then left-to-right within the region.
        lines_with_pos.sort(key=lambda item: (item[0], item[1]))
        full_text_parts.append("\n".join(item[2] for item in lines_with_pos))

    return "\n\n".join(full_text_parts) if full_text_parts else "No text recognized."
671
# Gradio interface. `flagging_mode` is the current name of the kwarg that
# replaced the removed `allow_flagging` (Gradio 5.x); the earlier copy of
# this Interface crashed on modern Gradio because of that rename.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload handwritten document"),
    outputs=gr.Textbox(label="Recognized Text"),
    title="Handwritten Text Recognition (YOLO regions/lines + TrOCR)",
    description="Uses your local regions.pt and lines.pt (same as Riksarkivet demo) with precise mask-based cropping.",
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()