iammraat commited on
Commit
116621e
·
verified ·
1 Parent(s): 227f1ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +249 -546
app.py CHANGED
@@ -260,607 +260,335 @@
260
 
261
 
262
 
263
- # import gradio as gr
264
- # from ultralytics import YOLO
265
- # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
266
- # from PIL import Image, ImageDraw, ImageFont
267
- # import torch
268
- # import logging
269
- # from datetime import datetime
270
- # import os
271
- # import warnings
272
- # import time
273
 
274
- # # Suppress progress bar and unnecessary logs
275
- # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
276
- # warnings.filterwarnings('ignore')
277
- # logging.getLogger('transformers').setLevel(logging.ERROR)
278
- # logging.getLogger('ultralytics').setLevel(logging.ERROR)
279
 
280
- # # Setup logging
281
- # logging.basicConfig(
282
- # level=logging.INFO,
283
- # format='%(asctime)s - %(levelname)s - %(message)s'
284
- # )
285
- # logger = logging.getLogger(__name__)
286
 
287
- # logger.info("Starting model loading...")
288
- # device = "cuda" if torch.cuda.is_available() else "cpu"
289
- # logger.info(f"Using device: {device}")
290
 
291
- # # --- ROBUST MODEL LOADING FUNCTION ---
292
- # def load_model_with_retry(model_class, model_name, token=None, retries=5, delay=5):
293
- # """Attempts to load a HF model with retries to handle network timeouts."""
294
- # for attempt in range(retries):
295
- # try:
296
- # logger.info(f"Loading {model_name} (Attempt {attempt + 1}/{retries})...")
297
- # if "Processor" in str(model_class):
298
- # return model_class.from_pretrained(model_name, token=token)
299
- # else:
300
- # return model_class.from_pretrained(model_name, token=token).to(device)
301
- # except Exception as e:
302
- # logger.warning(f"Failed to load {model_name}: {e}")
303
- # if attempt < retries - 1:
304
- # logger.info(f"Retrying in {delay} seconds...")
305
- # time.sleep(delay)
306
- # else:
307
- # logger.error(f"Given up on loading {model_name} after {retries} attempts.")
308
- # raise e
309
 
310
- # try:
311
- # # 1. Load YOLO Models (Local Files)
312
- # region_model_file = 'regions.pt'
313
- # line_model_file = 'lines.pt'
314
 
315
- # # Simple check for local files
316
- # if not os.path.exists(region_model_file):
317
- # for file in os.listdir('.'):
318
- # if 'region' in file.lower() and file.endswith('.pt'): region_model_file = file
319
- # elif 'line' in file.lower() and file.endswith('.pt'): line_model_file = file
320
 
321
- # if not os.path.exists(region_model_file) or not os.path.exists(line_model_file):
322
- # raise FileNotFoundError("YOLO .pt files (regions.pt/lines.pt) not found.")
323
 
324
- # logger.info("Loading YOLO models...")
325
- # region_model = YOLO(region_model_file)
326
- # line_model = YOLO(line_model_file)
327
- # logger.info("✓ YOLO models loaded")
328
 
329
- # # 2. Load TrOCR with Retries
330
- # hf_token = os.getenv("HF_TOKEN")
331
 
332
- # processor = load_model_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", token=hf_token)
333
- # logger.info("✓ TrOCR processor loaded")
334
 
335
- # trocr_model = load_model_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", token=hf_token)
336
- # logger.info("✓ TrOCR model loaded")
337
 
338
- # logger.info("All models loaded successfully!")
339
 
340
- # except Exception as e:
341
- # logger.error(f"CRITICAL ERROR loading models: {str(e)}")
342
- # raise
343
-
344
- # # --- OCR HELPER ---
345
- # def run_trocr(image_slice, processor, model, device):
346
- # """Runs TrOCR on a single cropped image slice."""
347
- # pixel_values = processor(images=image_slice, return_tensors="pt").pixel_values.to(device)
348
- # generated_ids = model.generate(pixel_values)
349
- # return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
350
 
351
- # def process_document(image, enable_debug_crops=False):
352
- # """Process uploaded document image and extract handwritten text with visualization."""
353
- # timestamp = datetime.now().strftime("%H:%M:%S")
354
- # log_output = []
 
 
 
 
 
 
 
355
 
356
- # def add_log(message, level="INFO"):
357
- # log_msg = f"[{timestamp}] {level}: {message}"
358
- # log_output.append(log_msg)
359
- # if level == "ERROR":
360
- # logger.error(message)
361
- # else:
362
- # logger.info(message)
363
 
364
- # add_log("Starting document processing")
365
 
366
- # if image is None:
367
- # add_log("No image provided", "ERROR")
368
- # return None, "Please upload an image", "\n".join(log_output)
369
 
370
- # try:
371
- # # Prepare Image
372
- # if not isinstance(image, Image.Image):
373
- # img = Image.open(image).convert("RGB")
374
- # else:
375
- # img = image.convert("RGB")
376
 
377
- # # Create a drawing context for the debug image
378
- # debug_img = img.copy()
379
- # draw = ImageDraw.Draw(debug_img)
380
 
381
- # width, height = img.size
382
- # add_log(f"Image size: {width}x{height} pixels")
383
 
384
- # all_lines = []
385
- # debug_crops_dir = "debug_crops"
386
 
387
- # if enable_debug_crops:
388
- # os.makedirs(debug_crops_dir, exist_ok=True)
389
- # add_log(f"Debug crops will be saved to {debug_crops_dir}/")
390
 
391
- # # --- STRATEGY 1: Region Detection ---
392
- # add_log("Strategy 1: Running region detection...")
393
- # region_results = region_model(img, conf=0.2, imgsz=1024, verbose=False)
394
- # regions = region_results[0].boxes
395
- # num_regions = len(regions)
396
- # add_log(f"✓ Found {num_regions} potential text region(s)")
397
 
398
- # found_lines_in_regions = False
399
 
400
- # if num_regions > 0:
401
- # for region_idx, region in enumerate(regions):
402
- # add_log(f"Processing region {region_idx + 1}/{num_regions}")
403
 
404
- # # FIX 1: Use round() instead of int() to minimize precision loss
405
- # rx1, ry1, rx2, ry2 = map(round, region.xyxy[0].tolist())
406
 
407
- # # Calculate region dimensions
408
- # region_width = rx2 - rx1
409
- # region_height = ry2 - ry1
410
 
411
- # add_log(f" Region coords: ({rx1}, {ry1}) → ({rx2}, {ry2}), size: {region_width}x{region_height}")
412
 
413
- # # Filter small artifacts
414
- # if region_width < 50 or region_height < 50:
415
- # add_log(f" Skipping tiny artifact: {region_width}x{region_height} px")
416
- # continue
417
 
418
- # # FIX 2: Add padding to region crops to avoid edge effects
419
- # padding = 10
420
- # padded_rx1 = max(0, rx1 - padding)
421
- # padded_ry1 = max(0, ry1 - padding)
422
- # padded_rx2 = min(width, rx2 + padding)
423
- # padded_ry2 = min(height, ry2 + padding)
424
 
425
- # add_log(f" Padded coords: ({padded_rx1}, {padded_ry1}) → ({padded_rx2}, {padded_ry2})")
426
 
427
- # # Draw GREEN box for Region (original bounds, not padded)
428
- # draw.rectangle([rx1, ry1, rx2, ry2], outline="green", width=5)
429
 
430
- # # Crop Region with padding
431
- # region_crop = img.crop((padded_rx1, padded_ry1, padded_rx2, padded_ry2))
432
 
433
- # if enable_debug_crops:
434
- # region_crop.save(f"{debug_crops_dir}/region_{region_idx:02d}.png")
435
 
436
- # # Detect lines in this region
437
- # add_log(f" Running line detection on region crop ({region_crop.size[0]}x{region_crop.size[1]})...")
438
- # line_results = line_model(region_crop, conf=0.2, imgsz=1024, verbose=False)
439
- # lines_data = line_results[0].boxes.xyxy.cpu().numpy()
440
- # num_lines = len(lines_data)
441
- # add_log(f" ✓ Found {num_lines} line(s) in region")
442
 
443
- # if num_lines > 0:
444
- # found_lines_in_regions = True
445
 
446
- # # Sort lines by Y position (index 1 of xyxy)
447
- # sorted_indices = lines_data[:, 1].argsort()
448
 
449
- # for line_idx, idx in enumerate(sorted_indices):
450
- # # FIX 3: Use round() for line coordinates too
451
- # lx1, ly1, lx2, ly2 = map(round, lines_data[idx].tolist())
452
 
453
- # line_width = lx2 - lx1
454
- # line_height = ly2 - ly1
455
 
456
- # add_log(f" Line {line_idx + 1} (local coords): ({lx1}, {ly1}) → ({lx2}, {ly2}), size: {line_width}x{line_height}")
457
 
458
- # # FIX 4: Translate line coordinates back to original image space
459
- # # Account for padding offset
460
- # global_lx1 = padded_rx1 + lx1
461
- # global_ly1 = padded_ry1 + ly1
462
- # global_lx2 = padded_rx1 + lx2
463
- # global_ly2 = padded_ry1 + ly2
464
 
465
- # # FIX 5: Validate coordinates are within image bounds
466
- # global_lx1 = max(0, min(width, global_lx1))
467
- # global_ly1 = max(0, min(height, global_ly1))
468
- # global_lx2 = max(0, min(width, global_lx2))
469
- # global_ly2 = max(0, min(height, global_ly2))
470
 
471
- # add_log(f" Line {line_idx + 1} (global coords): ({global_lx1}, {global_ly1}) → ({global_lx2}, {global_ly2})")
472
 
473
- # # Draw RED box for Line
474
- # draw.rectangle([global_lx1, global_ly1, global_lx2, global_ly2], outline="red", width=3)
475
 
476
- # # OCR on the line crop from region_crop
477
- # line_crop = region_crop.crop((lx1, ly1, lx2, ly2))
478
 
479
- # if enable_debug_crops:
480
- # line_crop.save(f"{debug_crops_dir}/region_{region_idx:02d}_line_{line_idx:02d}.png")
481
 
482
- # text = run_trocr(line_crop, processor, trocr_model, device)
483
- # add_log(f" Line {line_idx + 1} OCR: '{text}'")
484
- # all_lines.append(text)
485
-
486
- # # --- STRATEGY 2: Fallback to Full Page ---
487
- # if not found_lines_in_regions:
488
- # add_log("⚠️ Region detection yielded no lines. Switching to Fallback Strategy...", "WARNING")
489
- # add_log("Strategy 2: Running line detection on full page")
490
 
491
- # line_results = line_model(img, conf=0.2, imgsz=1024, verbose=False)
492
- # lines_data = line_results[0].boxes.xyxy.cpu().numpy()
493
- # num_lines = len(lines_data)
494
- # add_log(f"✓ Fallback found {num_lines} line(s) on full page")
495
 
496
- # if num_lines > 0:
497
- # sorted_indices = lines_data[:, 1].argsort()
498
 
499
- # for line_idx, idx in enumerate(sorted_indices):
500
- # # FIX 6: Use round() consistently
501
- # lx1, ly1, lx2, ly2 = map(round, lines_data[idx].tolist())
502
 
503
- # line_width = lx2 - lx1
504
- # line_height = ly2 - ly1
505
 
506
- # add_log(f" Fallback Line {line_idx + 1}: ({lx1}, {ly1}) → ({lx2}, {ly2}), size: {line_width}x{line_height}")
507
 
508
- # # FIX 7: Validate coordinates
509
- # lx1 = max(0, min(width, lx1))
510
- # ly1 = max(0, min(height, ly1))
511
- # lx2 = max(0, min(width, lx2))
512
- # ly2 = max(0, min(height, ly2))
513
 
514
- # # Draw RED box for Line (on full image)
515
- # draw.rectangle([lx1, ly1, lx2, ly2], outline="red", width=3)
516
 
517
- # line_crop = img.crop((lx1, ly1, lx2, ly2))
518
 
519
- # if enable_debug_crops:
520
- # line_crop.save(f"{debug_crops_dir}/fullpage_line_{line_idx:02d}.png")
521
 
522
- # text = run_trocr(line_crop, processor, trocr_model, device)
523
- # add_log(f" Fallback Line {line_idx + 1} OCR: '{text}'")
524
- # all_lines.append(text)
525
 
526
- # if not all_lines:
527
- # add_log("Failed to detect any text lines in both strategies", "ERROR")
528
- # return debug_img, "No text could be extracted.", "\n".join(log_output)
529
 
530
- # add_log(f"✓ Success! Extracted {len(all_lines)} total line(s)")
531
 
532
- # if enable_debug_crops:
533
- # add_log(f"✓ Debug crops saved to {debug_crops_dir}/")
534
 
535
- # final_text = '\n'.join(all_lines)
536
 
537
- # return debug_img, final_text, "\n".join(log_output)
538
 
539
- # except Exception as e:
540
- # error_msg = f"Error processing image: {str(e)}"
541
- # add_log(error_msg, "ERROR")
542
- # logger.exception("Full error traceback:")
543
- # return image, f"Error: {str(e)}", "\n".join(log_output)
544
 
545
- # # Create Gradio interface
546
- # demo = gr.Interface(
547
- # fn=process_document,
548
- # inputs=[
549
- # gr.Image(type="pil", label="Upload Handwritten Document"),
550
- # gr.Checkbox(label="Save debug crops to disk", value=False)
551
- # ],
552
- # outputs=[
553
- # gr.Image(type="pil", label="Debug Visualization (Green=Region, Red=Lines)"),
554
- # gr.Textbox(label="Extracted Text", lines=10),
555
- # gr.Textbox(label="Processing Logs", lines=15)
556
- # ],
557
- # title="📝 Handwritten Text Recognition (HTR) with Enhanced Debugging",
558
- # description="""
559
- # Upload an image of a handwritten document.
560
 
561
- # **Visualization Key:**
562
- # - 🟩 **Green Box:** The broad region identified as containing text (original bounds).
563
- # - 🟥 **Red Box:** The specific line of text sent to the OCR engine (with coordinate validation).
564
 
565
- # **Improvements:**
566
- # - Fixed coordinate rounding (eliminates truncation errors)
567
- # - Added 10px padding to region crops (reduces edge effects)
568
- # - Coordinate validation (ensures all boxes are within image bounds)
569
- # - Enhanced logging with detailed coordinate tracking
570
- # - Optional debug crop saving
571
- # """,
572
- # flagging_mode="never",
573
- # theme=gr.themes.Soft()
574
- # )
575
-
576
- # if __name__ == "__main__":
577
- # logger.info("Launching Gradio interface...")
578
- # demo.launch()
579
-
580
-
581
-
582
-
583
-
584
-
585
-
586
-
587
-
588
-
589
-
590
-
591
-
592
-
593
-
594
-
595
-
596
-
597
-
598
-
599
-
600
-
601
- import gradio as gr
602
- from ultralytics import YOLO
603
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
604
- from PIL import Image, ImageDraw
605
- import torch
606
- import logging
607
- import os
608
- import warnings
609
- import time
610
- from datetime import datetime
611
-
612
- # Suppress noisy logs
613
- os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
614
- warnings.filterwarnings('ignore')
615
- logging.getLogger('transformers').setLevel(logging.ERROR)
616
- logging.getLogger('ultralytics').setLevel(logging.WARNING) # still allow important warnings
617
-
618
- # Setup clean logging
619
- logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)-5s | %(message)s')
620
- logger = logging.getLogger(__name__)
621
-
622
- logger.info("Initializing models...")
623
- device = "cuda" if torch.cuda.is_available() else "cpu"
624
- logger.info(f"Device: {device}")
625
-
626
- def load_with_retry(cls, name, token=None, retries=4, delay=6):
627
- for attempt in range(1, retries + 1):
628
- try:
629
- logger.info(f"Loading {name} (attempt {attempt}/{retries})")
630
- if "Processor" in str(cls):
631
- return cls.from_pretrained(name, token=token)
632
- return cls.from_pretrained(name, token=token).to(device)
633
- except Exception as e:
634
- logger.warning(f"Load failed: {e}")
635
- if attempt < retries:
636
- time.sleep(delay)
637
- raise RuntimeError(f"Failed to load {name} after {retries} attempts")
638
-
639
- try:
640
- # Locate local YOLO weights
641
- region_pt = 'regions.pt'
642
- line_pt = 'lines.pt'
643
-
644
- if not os.path.exists(region_pt):
645
- for f in os.listdir('.'):
646
- name = f.lower()
647
- if 'region' in name and name.endswith('.pt'): region_pt = f
648
- if 'line' in name and name.endswith('.pt'): line_pt = f
649
-
650
- if not all(os.path.exists(p) for p in [region_pt, line_pt]):
651
- raise FileNotFoundError("Could not find regions.pt and lines.pt (or similar)")
652
-
653
- logger.info("Loading YOLO models...")
654
- region_model = YOLO(region_pt)
655
- line_model = YOLO(line_pt)
656
- logger.info("YOLO models loaded")
657
-
658
- hf_token = os.getenv("HF_TOKEN")
659
- processor = load_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", hf_token)
660
- trocr = load_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", hf_token)
661
- logger.info("TrOCR loaded → ready")
662
-
663
- except Exception as e:
664
- logger.error(f"Model loading failed: {e}", exc_info=True)
665
- raise
666
-
667
-
668
-
669
-
670
-
671
- def run_ocr(crop: Image.Image) -> str:
672
- if crop.width < 20 or crop.height < 12:
673
- return ""
674
- pixels = processor(images=crop, return_tensors="pt").pixel_values.to(device)
675
- ids = trocr.generate(pixels, max_new_tokens=128)
676
- return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
677
-
678
-
679
- def process_document(
680
- image,
681
- enable_debug_crops: bool = False,
682
- region_imgsz: int = 1024,
683
- line_imgsz_base: int = 768,
684
- conf_thresh: float = 0.25,
685
- ):
686
- start_ts = datetime.now().strftime("%H:%M:%S")
687
- logs = []
688
-
689
- def log(msg: str, level: str = "INFO"):
690
- line = f"[{start_ts}] {level:5} {msg}"
691
- logs.append(line)
692
- if level == "ERROR":
693
- logger.error(msg)
694
- else:
695
- logger.info(msg)
696
-
697
- log("Start processing")
698
-
699
- if image is None:
700
- log("No image uploaded", "ERROR")
701
- return None, "Upload an image", "\n".join(logs)
702
-
703
- try:
704
- # ── Prepare ─────────────────────────────────────────────────────────────
705
- if not isinstance(image, Image.Image):
706
- img = Image.open(image).convert("RGB")
707
- else:
708
- img = image.convert("RGB")
709
-
710
- debug_img = img.copy()
711
- draw = ImageDraw.Draw(debug_img)
712
- w, h = img.size
713
- log(f"Input image: {w} × {h} px")
714
-
715
- debug_dir = "debug_crops"
716
- if enable_debug_crops:
717
- os.makedirs(debug_dir, exist_ok=True)
718
- log(f"Debug crops → {debug_dir}/")
719
-
720
- extracted = []
721
- used_fallback = False
722
-
723
- # ── Strategy 1: Region → Lines ──────────────────────────────────────────
724
- log(f"Running region detection (imgsz={region_imgsz}) …")
725
- res_region = region_model(img, conf=conf_thresh, imgsz=region_imgsz, verbose=False)[0]
726
- boxes_region = res_region.boxes
727
-
728
- log(f"→ {len(boxes_region)} region candidate(s) (conf ≥ {conf_thresh})")
729
-
730
- found_any_line = False
731
-
732
- for i, box in enumerate(boxes_region, 1):
733
- conf = float(box.conf)
734
- xyxy = box.xyxy[0].cpu().tolist()
735
- rx1, ry1, rx2, ry2 = map(round, xyxy)
736
-
737
- rw, rh = rx2 - rx1, ry2 - ry1
738
- log(f"Region {i}/{len(boxes_region)} conf={conf:.3f} {rx1},{ry1} → {rx2},{ry2} ({rw}×{rh})")
739
-
740
- if rw < 60 or rh < 40:
741
- log(f" → skipped (too small)")
742
- continue
743
-
744
- # Padding
745
- pad = 12
746
- px1 = max(0, rx1 - pad)
747
- py1 = max(0, ry1 - pad)
748
- px2 = min(w, rx2 + pad)
749
- py2 = min(h, ry2 + pad)
750
-
751
- log(f" Padded crop: {px1},{py1} → {px2},{py2}")
752
-
753
- draw.rectangle((rx1, ry1, rx2, ry2), outline="green", width=4)
754
-
755
- crop_region = img.crop((px1, py1, px2, py2))
756
- crop_w, crop_h = crop_region.size
757
-
758
- if enable_debug_crops:
759
- crop_region.save(f"{debug_dir}/region_{i:02d}.png")
760
-
761
- # Adaptive line imgsz: bigger crops → bigger inference size
762
- line_sz = line_imgsz_base
763
- if max(crop_w, crop_h) > 1400:
764
- line_sz = 1280
765
- elif max(crop_w, crop_h) < 400:
766
- line_sz = 640
767
-
768
- log(f" → line detection (imgsz={line_sz}) on {crop_w}×{crop_h} crop …")
769
- res_line = line_model(crop_region, conf=conf_thresh, imgsz=line_sz, verbose=False)[0]
770
- line_boxes = res_line.boxes
771
-
772
- log(f" → {len(line_boxes)} line candidate(s)")
773
-
774
- if len(line_boxes) == 0:
775
- continue
776
-
777
- found_any_line = True
778
-
779
- # Sort top → bottom
780
- ys = line_boxes.xyxy[:, 1].cpu().numpy()
781
- order = ys.argsort()
782
-
783
- for j, idx in enumerate(order, 1):
784
- conf_line = float(line_boxes.conf[idx])
785
- lx1, ly1, lx2, ly2 = map(round, line_boxes.xyxy[idx].cpu().tolist())
786
-
787
- lw, lh = lx2 - lx1, ly2 - ly1
788
- log(f" Line {j} conf={conf_line:.3f} local {lx1},{ly1} → {lx2},{ly2} ({lw}×{lh})")
789
-
790
- # Back to global coordinates
791
- gx1 = px1 + lx1
792
- gy1 = py1 + ly1
793
- gx2 = px1 + lx2
794
- gy2 = py1 + ly2
795
-
796
- # Safety clamp
797
- gx1, gy1 = max(0, gx1), max(0, gy1)
798
- gx2, gy2 = min(w, gx2), min(h, gy2)
799
-
800
- log(f" → global {gx1},{gy1} → {gx2},{gy2}")
801
-
802
- draw.rectangle((gx1, gy1, gx2, gy2), outline="red", width=3)
803
-
804
- line_crop = crop_region.crop((lx1, ly1, lx2, ly2))
805
-
806
- if enable_debug_crops:
807
- line_crop.save(f"{debug_dir}/reg{i:02d}_line{j:02d}_conf{conf_line:.2f}.png")
808
-
809
- text = run_ocr(line_crop)
810
- log(f" OCR → '{text}'")
811
- if text:
812
- extracted.append(text)
813
 
814
- # ── Strategy 2: Fallback full-page line detection ───────────────────────
815
- if not found_any_line:
816
- used_fallback = True
817
- log("No lines found in regions → fallback: full-page line detection")
818
 
819
- line_sz = 1024 if max(w, h) > 1800 else line_imgsz_base
820
- log(f"Full-page line detection (imgsz={line_sz}) …")
821
 
822
- res = line_model(img, conf=conf_thresh, imgsz=line_sz, verbose=False)[0]
823
- boxes = res.boxes
824
 
825
- log(f"→ {len(boxes)} line(s) on full page")
826
 
827
- if len(boxes) > 0:
828
- ys = boxes.xyxy[:, 1].cpu().numpy()
829
- order = ys.argsort()
830
 
831
- for j, idx in enumerate(order, 1):
832
- conf = float(boxes.conf[idx])
833
- x1, y1, x2, y2 = map(round, boxes.xyxy[idx].cpu().tolist())
834
- log(f" Line {j} conf={conf:.3f} {x1},{y1} → {x2},{y2}")
835
 
836
- draw.rectangle((x1,y1,x2,y2), outline="red", width=3)
837
 
838
- crop = img.crop((x1,y1,x2,y2))
839
 
840
- if enable_debug_crops:
841
- crop.save(f"{debug_dir}/fallback_line{j:02d}_conf{conf:.2f}.png")
842
 
843
- text = run_ocr(crop)
844
- log(f" OCR → '{text}'")
845
- if text:
846
- extracted.append(text)
847
 
848
- # ── Finalize ────────────────────────────────────────────────────────────
849
- if not extracted:
850
- msg = "No readable text lines detected in either strategy"
851
- log(msg, "WARNING")
852
- return debug_img, msg, "\n".join(logs)
853
 
854
- log(f"Success — extracted {len(extracted)} line(s)")
855
- if enable_debug_crops:
856
- log(f"Debug crops saved to {debug_dir}/")
857
 
858
- return debug_img, "\n".join(extracted), "\n".join(logs)
859
 
860
- except Exception as e:
861
- log(f"Processing failed: {e}", "ERROR")
862
- logger.exception("Traceback:")
863
- return debug_img, f"Error: {str(e)}", "\n".join(logs)
864
 
865
 
866
 
@@ -868,29 +596,4 @@ def process_document(
868
 
869
 
870
 
871
- demo = gr.Interface(
872
- fn=process_document,
873
- inputs=[
874
- gr.Image(type="pil", label="Handwritten document"),
875
- gr.Checkbox(label="Save debug crops", value=False),
876
- gr.Slider(640, 1600, step=64, value=1024, label="Region detection size (imgsz)"),
877
- gr.Slider(512, 1280, step=64, value=768, label="Base line detection size"),
878
- gr.Slider(0.15, 0.5, step=0.05, value=0.25, label="Confidence threshold"),
879
- ],
880
- outputs=[
881
- gr.Image(label="Debug (green=region, red=line)"),
882
- gr.Textbox(label="Extracted Text", lines=10),
883
- gr.Textbox(label="Detailed Logs (copy these if boxes look wrong)", lines=18),
884
- ],
885
- title="Handwritten Text → OCR + Debug",
886
- description=(
887
- "Green = detected text regions • Red = individual text lines sent to TrOCR\n\n"
888
- "Copy the **Detailed Logs** if alignment still looks off — especially coords, sizes & confidences."
889
- ),
890
- theme=gr.themes.Soft(),
891
- flagging_mode="never",
892
- )
893
 
894
- if __name__ == "__main__":
895
- logger.info("Launching interface…")
896
- demo.launch()
 
260
 
261
 
262
 
263
+ import gradio as gr
264
+ from ultralytics import YOLO
265
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
266
+ from PIL import Image, ImageDraw, ImageFont
267
+ import torch
268
+ import logging
269
+ from datetime import datetime
270
+ import os
271
+ import warnings
272
+ import time
273
 
274
+ # Suppress progress bar and unnecessary logs
275
+ os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
276
+ warnings.filterwarnings('ignore')
277
+ logging.getLogger('transformers').setLevel(logging.ERROR)
278
+ logging.getLogger('ultralytics').setLevel(logging.ERROR)
279
 
280
+ # Setup logging
281
+ logging.basicConfig(
282
+ level=logging.INFO,
283
+ format='%(asctime)s - %(levelname)s - %(message)s'
284
+ )
285
+ logger = logging.getLogger(__name__)
286
 
287
+ logger.info("Starting model loading...")
288
+ device = "cuda" if torch.cuda.is_available() else "cpu"
289
+ logger.info(f"Using device: {device}")
290
 
291
+ # --- ROBUST MODEL LOADING FUNCTION ---
292
+ def load_model_with_retry(model_class, model_name, token=None, retries=5, delay=5):
293
+ """Attempts to load a HF model with retries to handle network timeouts."""
294
+ for attempt in range(retries):
295
+ try:
296
+ logger.info(f"Loading {model_name} (Attempt {attempt + 1}/{retries})...")
297
+ if "Processor" in str(model_class):
298
+ return model_class.from_pretrained(model_name, token=token)
299
+ else:
300
+ return model_class.from_pretrained(model_name, token=token).to(device)
301
+ except Exception as e:
302
+ logger.warning(f"Failed to load {model_name}: {e}")
303
+ if attempt < retries - 1:
304
+ logger.info(f"Retrying in {delay} seconds...")
305
+ time.sleep(delay)
306
+ else:
307
+ logger.error(f"Given up on loading {model_name} after {retries} attempts.")
308
+ raise e
309
 
310
+ try:
311
+ # 1. Load YOLO Models (Local Files)
312
+ region_model_file = 'regions.pt'
313
+ line_model_file = 'lines.pt'
314
 
315
+ # Simple check for local files
316
+ if not os.path.exists(region_model_file):
317
+ for file in os.listdir('.'):
318
+ if 'region' in file.lower() and file.endswith('.pt'): region_model_file = file
319
+ elif 'line' in file.lower() and file.endswith('.pt'): line_model_file = file
320
 
321
+ if not os.path.exists(region_model_file) or not os.path.exists(line_model_file):
322
+ raise FileNotFoundError("YOLO .pt files (regions.pt/lines.pt) not found.")
323
 
324
+ logger.info("Loading YOLO models...")
325
+ region_model = YOLO(region_model_file)
326
+ line_model = YOLO(line_model_file)
327
+ logger.info("✓ YOLO models loaded")
328
 
329
+ # 2. Load TrOCR with Retries
330
+ hf_token = os.getenv("HF_TOKEN")
331
 
332
+ processor = load_model_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", token=hf_token)
333
+ logger.info("✓ TrOCR processor loaded")
334
 
335
+ trocr_model = load_model_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", token=hf_token)
336
+ logger.info("✓ TrOCR model loaded")
337
 
338
+ logger.info("All models loaded successfully!")
339
 
340
+ except Exception as e:
341
+ logger.error(f"CRITICAL ERROR loading models: {str(e)}")
342
+ raise
 
 
 
 
 
 
 
343
 
344
+ # --- OCR HELPER ---
345
+ def run_trocr(image_slice, processor, model, device):
346
+ """Runs TrOCR on a single cropped image slice."""
347
+ pixel_values = processor(images=image_slice, return_tensors="pt").pixel_values.to(device)
348
+ generated_ids = model.generate(pixel_values)
349
+ return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
350
+
351
+ def process_document(image, enable_debug_crops=False):
352
+ """Process uploaded document image and extract handwritten text with visualization."""
353
+ timestamp = datetime.now().strftime("%H:%M:%S")
354
+ log_output = []
355
 
356
+ def add_log(message, level="INFO"):
357
+ log_msg = f"[{timestamp}] {level}: {message}"
358
+ log_output.append(log_msg)
359
+ if level == "ERROR":
360
+ logger.error(message)
361
+ else:
362
+ logger.info(message)
363
 
364
+ add_log("Starting document processing")
365
 
366
+ if image is None:
367
+ add_log("No image provided", "ERROR")
368
+ return None, "Please upload an image", "\n".join(log_output)
369
 
370
+ try:
371
+ # Prepare Image
372
+ if not isinstance(image, Image.Image):
373
+ img = Image.open(image).convert("RGB")
374
+ else:
375
+ img = image.convert("RGB")
376
 
377
+ # Create a drawing context for the debug image
378
+ debug_img = img.copy()
379
+ draw = ImageDraw.Draw(debug_img)
380
 
381
+ width, height = img.size
382
+ add_log(f"Image size: {width}x{height} pixels")
383
 
384
+ all_lines = []
385
+ debug_crops_dir = "debug_crops"
386
 
387
+ if enable_debug_crops:
388
+ os.makedirs(debug_crops_dir, exist_ok=True)
389
+ add_log(f"Debug crops will be saved to {debug_crops_dir}/")
390
 
391
+ # --- STRATEGY 1: Region Detection ---
392
+ add_log("Strategy 1: Running region detection...")
393
+ region_results = region_model(img, conf=0.2, imgsz=1024, verbose=False)
394
+ regions = region_results[0].boxes
395
+ num_regions = len(regions)
396
+ add_log(f"✓ Found {num_regions} potential text region(s)")
397
 
398
+ found_lines_in_regions = False
399
 
400
+ if num_regions > 0:
401
+ for region_idx, region in enumerate(regions):
402
+ add_log(f"Processing region {region_idx + 1}/{num_regions}")
403
 
404
+ # FIX 1: Use round() instead of int() to minimize precision loss
405
+ rx1, ry1, rx2, ry2 = map(round, region.xyxy[0].tolist())
406
 
407
+ # Calculate region dimensions
408
+ region_width = rx2 - rx1
409
+ region_height = ry2 - ry1
410
 
411
+ add_log(f" Region coords: ({rx1}, {ry1}) → ({rx2}, {ry2}), size: {region_width}x{region_height}")
412
 
413
+ # Filter small artifacts
414
+ if region_width < 50 or region_height < 50:
415
+ add_log(f" Skipping tiny artifact: {region_width}x{region_height} px")
416
+ continue
417
 
418
+ # FIX 2: Add padding to region crops to avoid edge effects
419
+ padding = 10
420
+ padded_rx1 = max(0, rx1 - padding)
421
+ padded_ry1 = max(0, ry1 - padding)
422
+ padded_rx2 = min(width, rx2 + padding)
423
+ padded_ry2 = min(height, ry2 + padding)
424
 
425
+ add_log(f" Padded coords: ({padded_rx1}, {padded_ry1}) → ({padded_rx2}, {padded_ry2})")
426
 
427
+ # Draw GREEN box for Region (original bounds, not padded)
428
+ draw.rectangle([rx1, ry1, rx2, ry2], outline="green", width=5)
429
 
430
+ # Crop Region with padding
431
+ region_crop = img.crop((padded_rx1, padded_ry1, padded_rx2, padded_ry2))
432
 
433
+ if enable_debug_crops:
434
+ region_crop.save(f"{debug_crops_dir}/region_{region_idx:02d}.png")
435
 
436
+ # Detect lines in this region
437
+ add_log(f" Running line detection on region crop ({region_crop.size[0]}x{region_crop.size[1]})...")
438
+ line_results = line_model(region_crop, conf=0.2, imgsz=1024, verbose=False)
439
+ lines_data = line_results[0].boxes.xyxy.cpu().numpy()
440
+ num_lines = len(lines_data)
441
+ add_log(f" ✓ Found {num_lines} line(s) in region")
442
 
443
+ if num_lines > 0:
444
+ found_lines_in_regions = True
445
 
446
+ # Sort lines by Y position (index 1 of xyxy)
447
+ sorted_indices = lines_data[:, 1].argsort()
448
 
449
+ for line_idx, idx in enumerate(sorted_indices):
450
+ # FIX 3: Use round() for line coordinates too
451
+ lx1, ly1, lx2, ly2 = map(round, lines_data[idx].tolist())
452
 
453
+ line_width = lx2 - lx1
454
+ line_height = ly2 - ly1
455
 
456
+ add_log(f" Line {line_idx + 1} (local coords): ({lx1}, {ly1}) → ({lx2}, {ly2}), size: {line_width}x{line_height}")
457
 
458
+ # FIX 4: Translate line coordinates back to original image space
459
+ # Account for padding offset
460
+ global_lx1 = padded_rx1 + lx1
461
+ global_ly1 = padded_ry1 + ly1
462
+ global_lx2 = padded_rx1 + lx2
463
+ global_ly2 = padded_ry1 + ly2
464
 
465
+ # FIX 5: Validate coordinates are within image bounds
466
+ global_lx1 = max(0, min(width, global_lx1))
467
+ global_ly1 = max(0, min(height, global_ly1))
468
+ global_lx2 = max(0, min(width, global_lx2))
469
+ global_ly2 = max(0, min(height, global_ly2))
470
 
471
+ add_log(f" Line {line_idx + 1} (global coords): ({global_lx1}, {global_ly1}) → ({global_lx2}, {global_ly2})")
472
 
473
+ # Draw RED box for Line
474
+ draw.rectangle([global_lx1, global_ly1, global_lx2, global_ly2], outline="red", width=3)
475
 
476
+ # OCR on the line crop from region_crop
477
+ line_crop = region_crop.crop((lx1, ly1, lx2, ly2))
478
 
479
+ if enable_debug_crops:
480
+ line_crop.save(f"{debug_crops_dir}/region_{region_idx:02d}_line_{line_idx:02d}.png")
481
 
482
+ text = run_trocr(line_crop, processor, trocr_model, device)
483
+ add_log(f" Line {line_idx + 1} OCR: '{text}'")
484
+ all_lines.append(text)
485
+
486
+ # --- STRATEGY 2: Fallback to Full Page ---
487
+ if not found_lines_in_regions:
488
+ add_log("⚠️ Region detection yielded no lines. Switching to Fallback Strategy...", "WARNING")
489
+ add_log("Strategy 2: Running line detection on full page")
490
 
491
+ line_results = line_model(img, conf=0.2, imgsz=1024, verbose=False)
492
+ lines_data = line_results[0].boxes.xyxy.cpu().numpy()
493
+ num_lines = len(lines_data)
494
+ add_log(f"✓ Fallback found {num_lines} line(s) on full page")
495
 
496
+ if num_lines > 0:
497
+ sorted_indices = lines_data[:, 1].argsort()
498
 
499
+ for line_idx, idx in enumerate(sorted_indices):
500
+ # FIX 6: Use round() consistently
501
+ lx1, ly1, lx2, ly2 = map(round, lines_data[idx].tolist())
502
 
503
+ line_width = lx2 - lx1
504
+ line_height = ly2 - ly1
505
 
506
+ add_log(f" Fallback Line {line_idx + 1}: ({lx1}, {ly1}) → ({lx2}, {ly2}), size: {line_width}x{line_height}")
507
 
508
+ # FIX 7: Validate coordinates
509
+ lx1 = max(0, min(width, lx1))
510
+ ly1 = max(0, min(height, ly1))
511
+ lx2 = max(0, min(width, lx2))
512
+ ly2 = max(0, min(height, ly2))
513
 
514
+ # Draw RED box for Line (on full image)
515
+ draw.rectangle([lx1, ly1, lx2, ly2], outline="red", width=3)
516
 
517
+ line_crop = img.crop((lx1, ly1, lx2, ly2))
518
 
519
+ if enable_debug_crops:
520
+ line_crop.save(f"{debug_crops_dir}/fullpage_line_{line_idx:02d}.png")
521
 
522
+ text = run_trocr(line_crop, processor, trocr_model, device)
523
+ add_log(f" Fallback Line {line_idx + 1} OCR: '{text}'")
524
+ all_lines.append(text)
525
 
526
+ if not all_lines:
527
+ add_log("Failed to detect any text lines in both strategies", "ERROR")
528
+ return debug_img, "No text could be extracted.", "\n".join(log_output)
529
 
530
+ add_log(f"✓ Success! Extracted {len(all_lines)} total line(s)")
531
 
532
+ if enable_debug_crops:
533
+ add_log(f"✓ Debug crops saved to {debug_crops_dir}/")
534
 
535
+ final_text = '\n'.join(all_lines)
536
 
537
+ return debug_img, final_text, "\n".join(log_output)
538
 
539
+ except Exception as e:
540
+ error_msg = f"Error processing image: {str(e)}"
541
+ add_log(error_msg, "ERROR")
542
+ logger.exception("Full error traceback:")
543
+ return image, f"Error: {str(e)}", "\n".join(log_output)
544
 
545
# Create the Gradio front-end: wires process_document to one image upload
# plus a debug-crops toggle, and exposes three diagnostic outputs
# (annotated image, extracted text, and the processing log trace).
_doc_input = gr.Image(type="pil", label="Upload Handwritten Document")
_debug_toggle = gr.Checkbox(label="Save debug crops to disk", value=False)

_viz_output = gr.Image(type="pil", label="Debug Visualization (Green=Region, Red=Lines)")
_text_output = gr.Textbox(label="Extracted Text", lines=10)
_log_output = gr.Textbox(label="Processing Logs", lines=15)

demo = gr.Interface(
    fn=process_document,
    inputs=[_doc_input, _debug_toggle],
    outputs=[_viz_output, _text_output, _log_output],
    title="📝 Handwritten Text Recognition (HTR) with Enhanced Debugging",
    description="""
    Upload an image of a handwritten document.

    **Visualization Key:**
    - 🟩 **Green Box:** The broad region identified as containing text (original bounds).
    - 🟥 **Red Box:** The specific line of text sent to the OCR engine (with coordinate validation).

    **Improvements:**
    - Fixed coordinate rounding (eliminates truncation errors)
    - Added 10px padding to region crops (reduces edge effects)
    - Coordinate validation (ensures all boxes are within image bounds)
    - Enhanced logging with detailed coordinate tracking
    - Optional debug crop saving
    """,
    flagging_mode="never",
    theme=gr.themes.Soft(),
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
 
576
def _main() -> None:
    """Entry point: log startup and serve the Gradio app."""
    logger.info("Launching Gradio interface...")
    demo.launch()


if __name__ == "__main__":
    _main()
 
579
 
 
 
580
 
 
 
581
 
 
582
 
 
 
 
583
 
 
 
 
 
584
 
 
585
 
 
586
 
 
 
587
 
 
 
 
 
588
 
 
 
 
 
 
589
 
 
 
 
590
 
 
591
 
 
 
 
 
592
 
593
 
594
 
 
596
 
597
 
598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599