iammraat commited on
Commit
4efabae
·
verified ·
1 Parent(s): 0b695b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -144
app.py CHANGED
@@ -441,8 +441,7 @@
441
 
442
 
443
 
444
-
445
- # app.py
446
  import gradio as gr
447
  from ultralytics import YOLO
448
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
@@ -450,28 +449,20 @@ from PIL import Image
450
  import torch
451
  import numpy as np
452
 
453
- # Load local models (your uploaded .pt files)
454
- region_model = YOLO("regions.pt") # ← fixed: local file
455
- line_model = YOLO("lines.pt") # ← fixed: local file
456
 
457
- # TrOCR (you can change to large if you have GPU and want better accuracy)
458
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
459
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
460
 
461
- # Move TrOCR to GPU if available (much faster on paid Spaces)
462
  device = "cuda" if torch.cuda.is_available() else "cpu"
463
  model.to(device)
464
 
465
  def get_crop(image: Image.Image, result, idx: int, padding: int = 15):
466
- """
467
- Crop using segmentation mask if available (much more accurate than boxes),
468
- otherwise fall back to bounding box with padding.
469
- Background outside the mask is forced to white → better for OCR.
470
- """
471
  img_np = np.array(image)
472
 
473
  if result.masks is not None:
474
- # Segmentation model → use mask (this is what the original Riksarkivet demo does)
475
  mask = result.masks.data[idx].cpu().numpy()
476
  mask_bool = mask > 0.5
477
 
@@ -482,147 +473,40 @@ def get_crop(image: Image.Image, result, idx: int, padding: int = 15):
482
  y_min, y_max = ys.min(), ys.max()
483
  x_min, x_max = xs.min(), xs.max()
484
 
485
- # Add padding
486
  y_min = max(0, y_min - padding)
487
  y_max = min(img_np.shape[0], y_max + padding + 1)
488
  x_min = max(0, x_min - padding)
489
  x_max = min(img_np.shape[1], x_max + padding + 1)
490
 
 
 
 
 
491
  crop = img_np[y_min:y_max, x_min:x_max]
492
  mask_crop = mask_bool[y_min:y_max, x_min:x_max]
493
 
494
- # Force background to white
495
  crop[~mask_crop] = 255
496
 
497
  return Image.fromarray(crop)
498
 
499
  else:
500
- # Detection only → use bounding box with padding
501
  xyxy = result.boxes.xyxy[idx].cpu().numpy().astype(int)
502
  x1, y1, x2, y2 = xyxy
503
  x1 = max(0, x1 - padding)
504
  y1 = max(0, y1 - padding)
505
  x2 = min(image.width, x2 + padding)
506
  y2 = min(image.height, y2 + padding)
507
- return image.crop((x1, y1, x2, y2))
508
-
509
- def process_image(image: Image.Image):
510
- results = region_model(image)
511
- region_result = results[0]
512
-
513
- if region_result.boxes is None or len(region_result.boxes) == 0:
514
- return "No text regions detected."
515
 
516
- # Collect regions with their vertical position for sorting
517
- regions_with_pos = []
518
- for i in range(len(region_result.boxes)):
519
- y1 = region_result.boxes.xyxy[i][1].item() # top y-coordinate
520
- crop = get_crop(image, region_result, i, padding=20)
521
- if crop:
522
- regions_with_pos.append((y1, crop))
523
-
524
- # Sort regions top → bottom
525
- regions_with_pos.sort(key=lambda x: x[0])
526
-
527
- full_text_parts = []
528
-
529
- for _, region_crop in regions_with_pos:
530
- line_results = line_model(region_crop)
531
- line_result = line_results[0]
532
-
533
- if line_result.boxes is None or len(line_result.boxes) == 0:
534
- continue
535
-
536
- lines_with_pos = []
537
- for j in range(len(line_result.boxes)):
538
- rel_y1 = line_result.boxes.xyxy[j][1].item() # relative to region crop
539
- rel_x1 = line_result.boxes.xyxy[j][0].item()
540
- line_crop = get_crop(region_crop, line_result, j, padding=15)
541
-
542
- if line_crop is None:
543
- continue
544
-
545
- # TrOCR recognition
546
- pixel_values = processor(line_crop, return_tensors="pt").pixel_values.to(device)
547
- generated_ids = model.generate(pixel_values)
548
- text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
549
-
550
- lines_with_pos.append((rel_y1, rel_x1, text))
551
-
552
- # Sort lines: top→bottom, then left→right (handles multi-column reasonably)
553
- lines_with_pos.sort(key=lambda x: (x[0], x[1]))
554
- region_text = "\n".join([item[2] for item in lines_with_pos])
555
- full_text_parts.append(region_text)
556
-
557
- return "\n\n".join(full_text_parts) if full_text_parts else "No text recognized."
558
-
559
- # Gradio interface
560
- demo = gr.Interface(
561
- fn=process_image,
562
- inputs=gr.Image(type="pil", label="Upload handwritten document"),
563
- outputs=gr.Textbox(label="Recognized Text"),
564
- title="Handwritten Text Recognition (YOLO regions/lines + TrOCR)",
565
- description="Uses your local regions.pt and lines.pt (same as Riksarkivet demo) with precise mask-based cropping.",
566
- flagging_mode="never"
567
- )
568
-
569
- if __name__ == "__main__":
570
- demo.launch()# app.py (fixed version)
571
- import gradio as gr
572
- from ultralytics import YOLO
573
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
574
- from PIL import Image
575
- import torch
576
- import numpy as np
577
-
578
- # Load local models
579
- region_model = YOLO("regions.pt")
580
- line_model = YOLO("lines.pt")
581
-
582
- # TrOCR
583
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
584
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
585
-
586
- # Move to GPU if available
587
- device = "cuda" if torch.cuda.is_available() else "cpu"
588
- model.to(device)
589
-
590
- def get_crop(image: Image.Image, result, idx: int, padding: int = 15):
591
- img_np = np.array(image)
592
-
593
- if result.masks is not None:
594
- mask = result.masks.data[idx].cpu().numpy()
595
- mask_bool = mask > 0.5
596
-
597
- ys, xs = np.where(mask_bool)
598
- if len(ys) == 0:
599
  return None
600
 
601
- y_min, y_max = ys.min(), ys.max()
602
- x_min, x_max = xs.min(), xs.max()
603
-
604
- y_min = max(0, y_min - padding)
605
- y_max = min(img_np.shape[0], y_max + padding + 1)
606
- x_min = max(0, x_min - padding)
607
- x_max = min(img_np.shape[1], x_max + padding + 1)
608
-
609
- crop = img_np[y_min:y_max, x_min:x_max]
610
- mask_crop = mask_bool[y_min:y_max, x_min:x_max]
611
-
612
- crop[~mask_crop] = 255
613
-
614
- return Image.fromarray(crop)
615
-
616
- else:
617
- xyxy = result.boxes.xyxy[idx].cpu().numpy().astype(int)
618
- x1, y1, x2, y2 = xyxy
619
- x1 = max(0, x1 - padding)
620
- y1 = max(0, y1 - padding)
621
- x2 = min(image.width, x2 + padding)
622
- y2 = min(image.height, y2 + padding)
623
  return image.crop((x1, y1, x2, y2))
624
 
625
  def process_image(image: Image.Image):
 
 
 
626
  results = region_model(image)
627
  region_result = results[0]
628
 
@@ -633,14 +517,17 @@ def process_image(image: Image.Image):
633
  for i in range(len(region_result.boxes)):
634
  y1 = region_result.boxes.xyxy[i][1].item()
635
  crop = get_crop(image, region_result, i, padding=20)
636
- if crop:
637
  regions_with_pos.append((y1, crop))
638
 
 
 
 
639
  regions_with_pos.sort(key=lambda x: x[0])
640
 
641
  full_text_parts = []
642
 
643
- for _, region_crop in regions_with_pos:
644
  line_results = line_model(region_crop)
645
  line_result = line_results[0]
646
 
@@ -653,29 +540,40 @@ def process_image(image: Image.Image):
653
  rel_x1 = line_result.boxes.xyxy[j][0].item()
654
  line_crop = get_crop(region_crop, line_result, j, padding=15)
655
 
656
- if line_crop is None:
 
 
657
  continue
658
 
659
- pixel_values = processor(line_crop, return_tensors="pt").pixel_values.to(device)
660
- generated_ids = model.generate(pixel_values)
661
- text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
662
-
663
- lines_with_pos.append((rel_y1, rel_x1, text))
 
 
 
 
 
664
 
665
  lines_with_pos.sort(key=lambda x: (x[0], x[1]))
666
- region_text = "\n".join([item[2] for item in lines_with_pos])
667
- full_text_parts.append(region_text)
 
668
 
669
- return "\n\n".join(full_text_parts) if full_text_parts else "No text recognized."
 
670
 
671
- # Gradio interface (fixed: use flagging_mode instead of allow_flagging)
 
 
672
  demo = gr.Interface(
673
  fn=process_image,
674
  inputs=gr.Image(type="pil", label="Upload handwritten document"),
675
  outputs=gr.Textbox(label="Recognized Text"),
676
- title="Handwritten Text Recognition (YOLO regions/lines + TrOCR)",
677
- description="Uses your local regions.pt and lines.pt (same as Riksarkivet demo) with precise mask-based cropping.",
678
- flagging_mode="never" # ← fixed: changed from allow_flagging to flagging_mode
679
  )
680
 
681
  if __name__ == "__main__":
 
441
 
442
 
443
 
444
+ # app.py - FIXED VERSION with empty crop protection
 
445
  import gradio as gr
446
  from ultralytics import YOLO
447
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 
449
  import torch
450
  import numpy as np
451
 
452
+ # Load models
453
+ region_model = YOLO("regions.pt")
454
+ line_model = YOLO("lines.pt")
455
 
 
456
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
457
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
458
 
 
459
  device = "cuda" if torch.cuda.is_available() else "cpu"
460
  model.to(device)
461
 
462
  def get_crop(image: Image.Image, result, idx: int, padding: int = 15):
 
 
 
 
 
463
  img_np = np.array(image)
464
 
465
  if result.masks is not None:
 
466
  mask = result.masks.data[idx].cpu().numpy()
467
  mask_bool = mask > 0.5
468
 
 
473
  y_min, y_max = ys.min(), ys.max()
474
  x_min, x_max = xs.min(), xs.max()
475
 
 
476
  y_min = max(0, y_min - padding)
477
  y_max = min(img_np.shape[0], y_max + padding + 1)
478
  x_min = max(0, x_min - padding)
479
  x_max = min(img_np.shape[1], x_max + padding + 1)
480
 
481
+ # Safety: if after padding still degenerate
482
+ if y_max <= y_min or x_max <= x_min:
483
+ return None
484
+
485
  crop = img_np[y_min:y_max, x_min:x_max]
486
  mask_crop = mask_bool[y_min:y_max, x_min:x_max]
487
 
 
488
  crop[~mask_crop] = 255
489
 
490
  return Image.fromarray(crop)
491
 
492
  else:
493
+ # Bounding box fallback
494
  xyxy = result.boxes.xyxy[idx].cpu().numpy().astype(int)
495
  x1, y1, x2, y2 = xyxy
496
  x1 = max(0, x1 - padding)
497
  y1 = max(0, y1 - padding)
498
  x2 = min(image.width, x2 + padding)
499
  y2 = min(image.height, y2 + padding)
 
 
 
 
 
 
 
 
500
 
501
+ if x2 <= x1 or y2 <= y1:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
  return None
503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  return image.crop((x1, y1, x2, y2))
505
 
506
  def process_image(image: Image.Image):
507
+ if image is None:
508
+ return "Please upload an image."
509
+
510
  results = region_model(image)
511
  region_result = results[0]
512
 
 
517
  for i in range(len(region_result.boxes)):
518
  y1 = region_result.boxes.xyxy[i][1].item()
519
  crop = get_crop(image, region_result, i, padding=20)
520
+ if crop and crop.size[0] > 0 and crop.size[1] > 0:
521
  regions_with_pos.append((y1, crop))
522
 
523
+ if not regions_with_pos:
524
+ return "No valid text regions after cropping."
525
+
526
  regions_with_pos.sort(key=lambda x: x[0])
527
 
528
  full_text_parts = []
529
 
530
+ for region_idx, (_, region_crop) in enumerate(regions_with_pos):
531
  line_results = line_model(region_crop)
532
  line_result = line_results[0]
533
 
 
540
  rel_x1 = line_result.boxes.xyxy[j][0].item()
541
  line_crop = get_crop(region_crop, line_result, j, padding=15)
542
 
543
+ if line_crop is None or line_crop.size[0] < 10 or line_crop.size[1] < 8:
544
+ # Skip tiny/invalid crops to prevent TrOCR crash
545
+ # print(f"Skipped tiny line {j} in region {region_idx}")
546
  continue
547
 
548
+ try:
549
+ pixel_values = processor(line_crop, return_tensors="pt").pixel_values.to(device)
550
+ generated_ids = model.generate(pixel_values)
551
+ text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
552
+ if text: # only add non-empty
553
+ lines_with_pos.append((rel_y1, rel_x1, text))
554
+ except Exception as e:
555
+ # Catch any remaining processing errors
556
+ # print(f"TrOCR error on line {j}: {e}")
557
+ continue
558
 
559
  lines_with_pos.sort(key=lambda x: (x[0], x[1]))
560
+ region_text = "\n".join([item[2] for item in lines_with_pos if item[2]])
561
+ if region_text:
562
+ full_text_parts.append(region_text)
563
 
564
+ if not full_text_parts:
565
+ return "No readable text recognized (possibly due to small/tiny lines or model limitations). Try a clearer document or larger padding."
566
 
567
+ return "\n\n".join(full_text_parts)
568
+
569
+ # Gradio interface
570
  demo = gr.Interface(
571
  fn=process_image,
572
  inputs=gr.Image(type="pil", label="Upload handwritten document"),
573
  outputs=gr.Textbox(label="Recognized Text"),
574
+ title="Handwritten Text Recognition (YOLO + TrOCR)",
575
+ description="Local models: regions.pt / lines.pt + microsoft/trocr-base-handwritten. Mask-based cropping + safeguards against empty crops.",
576
+ flagging_mode="never"
577
  )
578
 
579
  if __name__ == "__main__":