iammraat committed on
Commit
b02cd5a
·
verified ·
1 Parent(s): 80cf5fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -39
app.py CHANGED
@@ -310,6 +310,164 @@
310
  # demo.launch(server_name="0.0.0.0", server_port=7860)
311
 
312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  import gradio as gr
314
  import torch
315
  import numpy as np
@@ -319,7 +477,7 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
319
  from craft_text_detector import Craft
320
 
321
  # ==========================================
322
- # 🔧 PATCH 1: Fix Torchvision (From your code)
323
  # ==========================================
324
  import torchvision.models.vgg
325
  if not hasattr(torchvision.models.vgg, 'model_urls'):
@@ -328,7 +486,7 @@ if not hasattr(torchvision.models.vgg, 'model_urls'):
328
  }
329
 
330
  # ==========================================
331
- # 🔧 PATCH 2: The Logic Fix (Ratio Net)
332
  # ==========================================
333
  import craft_text_detector.craft_utils as craft_utils_module
334
 
@@ -341,11 +499,10 @@ def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
341
  if poly is None or len(poly) == 0:
342
  continue
343
 
344
- # Safe numpy conversion
345
  p = np.array(poly).reshape(-1, 2)
346
 
347
- # CRITICAL FIX: Multiply by ratio_net (defaults to 2)
348
- # This scales the 1/2 size heatmap output back to full image size
349
  p[:, 0] *= (ratio_w * ratio_net)
350
  p[:, 1] *= (ratio_h * ratio_net)
351
 
@@ -353,21 +510,24 @@ def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
353
 
354
  return adjusted
355
 
356
- # Apply the patch
357
  craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
358
  # ==========================================
359
 
360
- # --- Load TrOCR (Recognition) ---
 
361
  device = "cuda" if torch.cuda.is_available() else "cpu"
362
- print(f"Loading TrOCR on {device}...")
363
- processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-handwritten')
364
- model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-handwritten').to(device).eval()
 
 
 
365
 
366
- # --- Load CRAFT (Detection) ---
367
  print("Loading CRAFT...")
368
- # crop_type="box" ensures we get clean rectangles
369
  craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
370
 
 
 
371
  def get_sorted_boxes(boxes):
372
  """Sorts boxes top-to-bottom (lines), then left-to-right."""
373
  if not boxes: return []
@@ -377,54 +537,60 @@ def get_sorted_boxes(boxes):
377
  cx = np.mean(box[:, 0])
378
  items.append((cy, cx, box))
379
 
380
- # Sort by Y (grouping by 40px lines) then X
381
- items.sort(key=lambda x: (int(x[0] // 40), x[1]))
382
  return [x[2] for x in items]
383
 
384
  def process_image(image):
385
  if image is None:
386
- return None, "Please upload an image."
387
 
388
- # Convert to numpy
 
389
  image_np = np.array(image.convert("RGB"))
390
 
391
  # 1. DETECT
392
- # The patch we added above will now auto-multiply coordinates by 2 * ratio
393
- # fixing the "tiny box" issue.
394
  prediction = craft.detect_text(image_np)
395
  boxes = prediction.get("boxes", [])
396
 
397
  if not boxes:
398
- return image, "No text detected."
399
 
400
- # 2. VISUALIZE & CROP
401
  sorted_boxes = get_sorted_boxes(boxes)
402
  annotated_img = image_np.copy()
403
  results = []
 
404
 
 
405
  for box in sorted_boxes:
406
- # Cast to int for drawing
407
  box_int = box.astype(np.int32)
408
 
409
- # Draw on image (Blue, thickness 3)
410
  cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 3)
411
 
412
- # Get Crop Coordinates
413
- x_min = max(0, np.min(box_int[:, 0]))
414
- x_max = min(image_np.shape[1], np.max(box_int[:, 0]))
415
- y_min = max(0, np.min(box_int[:, 1]))
416
- y_max = min(image_np.shape[0], np.max(box_int[:, 1]))
 
 
 
417
 
418
- # Filter noise
419
- if (x_max - x_min) < 10 or (y_max - y_min) < 10:
420
  continue
421
 
422
  crop = image_np[y_min:y_max, x_min:x_max]
423
- if crop.size == 0: continue
424
 
 
425
  pil_crop = Image.fromarray(crop)
426
 
427
- # 3. RECOGNIZE (TrOCR)
 
 
 
428
  with torch.no_grad():
429
  pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
430
  generated_ids = model.generate(pixel_values)
@@ -434,22 +600,28 @@ def process_image(image):
434
  results.append(text)
435
 
436
  full_text = "\n".join(results)
437
- return Image.fromarray(annotated_img), full_text
 
438
 
439
- # --- Gradio UI ---
440
- with gr.Blocks(title="Handwritten OCR Fixed") as demo:
441
- gr.Markdown("# 📝 Handwritten OCR (Fixed Pipeline)")
 
442
 
443
  with gr.Row():
444
- with gr.Column():
445
  input_img = gr.Image(type="pil", label="Upload Image")
446
  btn = gr.Button("Transcribe", variant="primary")
447
 
448
- with gr.Column():
449
  output_img = gr.Image(label="Detections")
450
- output_txt = gr.Textbox(label="Result", lines=20)
 
 
 
 
451
 
452
- btn.click(process_image, input_img, [output_img, output_txt])
453
 
454
  if __name__ == "__main__":
455
  demo.launch()
 
310
  # demo.launch(server_name="0.0.0.0", server_port=7860)
311
 
312
 
313
+
314
+
315
+
316
+
317
+
318
+
319
+
320
+ # import gradio as gr
321
+ # import torch
322
+ # import numpy as np
323
+ # import cv2
324
+ # from PIL import Image
325
+ # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
326
+ # from craft_text_detector import Craft
327
+
328
+ # # ==========================================
329
+ # # 🔧 PATCH 1: Fix Torchvision (From your code)
330
+ # # ==========================================
331
+ # import torchvision.models.vgg
332
+ # if not hasattr(torchvision.models.vgg, 'model_urls'):
333
+ # torchvision.models.vgg.model_urls = {
334
+ # 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
335
+ # }
336
+
337
+ # # ==========================================
338
+ # # 🔧 PATCH 2: The Logic Fix (Ratio Net)
339
+ # # ==========================================
340
+ # import craft_text_detector.craft_utils as craft_utils_module
341
+
342
+ # def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
343
+ # if not polys:
344
+ # return []
345
+
346
+ # adjusted = []
347
+ # for poly in polys:
348
+ # if poly is None or len(poly) == 0:
349
+ # continue
350
+
351
+ # # Safe numpy conversion
352
+ # p = np.array(poly).reshape(-1, 2)
353
+
354
+ # # CRITICAL FIX: Multiply by ratio_net (defaults to 2)
355
+ # # This scales the 1/2 size heatmap output back to full image size
356
+ # p[:, 0] *= (ratio_w * ratio_net)
357
+ # p[:, 1] *= (ratio_h * ratio_net)
358
+
359
+ # adjusted.append(p)
360
+
361
+ # return adjusted
362
+
363
+ # # Apply the patch
364
+ # craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
365
+ # # ==========================================
366
+
367
+ # # --- Load TrOCR (Recognition) ---
368
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
369
+ # print(f"Loading TrOCR on {device}...")
370
+ # processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-handwritten')
371
+ # model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-handwritten').to(device).eval()
372
+
373
+ # # --- Load CRAFT (Detection) ---
374
+ # print("Loading CRAFT...")
375
+ # # crop_type="box" ensures we get clean rectangles
376
+ # craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
377
+
378
+ # def get_sorted_boxes(boxes):
379
+ # """Sorts boxes top-to-bottom (lines), then left-to-right."""
380
+ # if not boxes: return []
381
+ # items = []
382
+ # for box in boxes:
383
+ # cy = np.mean(box[:, 1])
384
+ # cx = np.mean(box[:, 0])
385
+ # items.append((cy, cx, box))
386
+
387
+ # # Sort by Y (grouping by 40px lines) then X
388
+ # items.sort(key=lambda x: (int(x[0] // 40), x[1]))
389
+ # return [x[2] for x in items]
390
+
391
+ # def process_image(image):
392
+ # if image is None:
393
+ # return None, "Please upload an image."
394
+
395
+ # # Convert to numpy
396
+ # image_np = np.array(image.convert("RGB"))
397
+
398
+ # # 1. DETECT
399
+ # # The patch we added above will now auto-multiply coordinates by 2 * ratio
400
+ # # fixing the "tiny box" issue.
401
+ # prediction = craft.detect_text(image_np)
402
+ # boxes = prediction.get("boxes", [])
403
+
404
+ # if not boxes:
405
+ # return image, "No text detected."
406
+
407
+ # # 2. VISUALIZE & CROP
408
+ # sorted_boxes = get_sorted_boxes(boxes)
409
+ # annotated_img = image_np.copy()
410
+ # results = []
411
+
412
+ # for box in sorted_boxes:
413
+ # # Cast to int for drawing
414
+ # box_int = box.astype(np.int32)
415
+
416
+ # # Draw on image (Blue, thickness 3)
417
+ # cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 3)
418
+
419
+ # # Get Crop Coordinates
420
+ # x_min = max(0, np.min(box_int[:, 0]))
421
+ # x_max = min(image_np.shape[1], np.max(box_int[:, 0]))
422
+ # y_min = max(0, np.min(box_int[:, 1]))
423
+ # y_max = min(image_np.shape[0], np.max(box_int[:, 1]))
424
+
425
+ # # Filter noise
426
+ # if (x_max - x_min) < 10 or (y_max - y_min) < 10:
427
+ # continue
428
+
429
+ # crop = image_np[y_min:y_max, x_min:x_max]
430
+ # if crop.size == 0: continue
431
+
432
+ # pil_crop = Image.fromarray(crop)
433
+
434
+ # # 3. RECOGNIZE (TrOCR)
435
+ # with torch.no_grad():
436
+ # pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
437
+ # generated_ids = model.generate(pixel_values)
438
+ # text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
439
+
440
+ # if text.strip():
441
+ # results.append(text)
442
+
443
+ # full_text = "\n".join(results)
444
+ # return Image.fromarray(annotated_img), full_text
445
+
446
+ # # --- Gradio UI ---
447
+ # with gr.Blocks(title="Handwritten OCR Fixed") as demo:
448
+ # gr.Markdown("# 📝 Handwritten OCR (Fixed Pipeline)")
449
+
450
+ # with gr.Row():
451
+ # with gr.Column():
452
+ # input_img = gr.Image(type="pil", label="Upload Image")
453
+ # btn = gr.Button("Transcribe", variant="primary")
454
+
455
+ # with gr.Column():
456
+ # output_img = gr.Image(label="Detections")
457
+ # output_txt = gr.Textbox(label="Result", lines=20)
458
+
459
+ # btn.click(process_image, input_img, [output_img, output_txt])
460
+
461
+ # if __name__ == "__main__":
462
+ # demo.launch()
463
+
464
+
465
+
466
+
467
+
468
+
469
+
470
+
471
  import gradio as gr
472
  import torch
473
  import numpy as np
 
477
  from craft_text_detector import Craft
478
 
479
  # ==========================================
480
+ # 🔧 PATCH 1: Fix Torchvision Compatibility
481
  # ==========================================
482
  import torchvision.models.vgg
483
  if not hasattr(torchvision.models.vgg, 'model_urls'):
 
486
  }
487
 
488
  # ==========================================
489
+ # 🔧 PATCH 2: The "Ratio Net" Logic Fix
490
  # ==========================================
491
  import craft_text_detector.craft_utils as craft_utils_module
492
 
 
499
  if poly is None or len(poly) == 0:
500
  continue
501
 
502
+ # Convert to numpy and reshape
503
  p = np.array(poly).reshape(-1, 2)
504
 
505
+ # Scale correctly using ratio_net
 
506
  p[:, 0] *= (ratio_w * ratio_net)
507
  p[:, 1] *= (ratio_h * ratio_net)
508
 
 
510
 
511
  return adjusted
512
 
 
513
  craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
514
  # ==========================================
515
 
516
+
517
+ # --- 1. SETUP MODEL (Switched to BASE for stability) ---
518
  device = "cuda" if torch.cuda.is_available() else "cpu"
519
+ print(f"Loading TrOCR-Base on {device}...")
520
+
521
+ # We use the 'base' model because 'small' hallucinates Wikipedia text on tight crops
522
+ MODEL_ID = "microsoft/trocr-base-handwritten"
523
+ processor = TrOCRProcessor.from_pretrained(MODEL_ID)
524
+ model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID).to(device).eval()
525
 
 
526
  print("Loading CRAFT...")
 
527
  craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
528
 
529
+
530
+ # --- 2. HELPER FUNCTIONS ---
531
  def get_sorted_boxes(boxes):
532
  """Sorts boxes top-to-bottom (lines), then left-to-right."""
533
  if not boxes: return []
 
537
  cx = np.mean(box[:, 0])
538
  items.append((cy, cx, box))
539
 
540
+ # Sort by line (approx 20px tolerance) then by column
541
+ items.sort(key=lambda x: (int(x[0] // 20), x[1]))
542
  return [x[2] for x in items]
543
 
544
  def process_image(image):
545
  if image is None:
546
+ return None, [], "Please upload an image."
547
 
548
+ # Convert to standard RGB Numpy array
549
+ # We use the FULL resolution image (no resizing) to keep text sharp
550
  image_np = np.array(image.convert("RGB"))
551
 
552
  # 1. DETECT
553
+ # The patch ensures coordinates map perfectly to this full-res image
 
554
  prediction = craft.detect_text(image_np)
555
  boxes = prediction.get("boxes", [])
556
 
557
  if not boxes:
558
+ return image, [], "No text detected."
559
 
 
560
  sorted_boxes = get_sorted_boxes(boxes)
561
  annotated_img = image_np.copy()
562
  results = []
563
+ debug_crops = []
564
 
565
+ # 2. PROCESS BOXES
566
  for box in sorted_boxes:
 
567
  box_int = box.astype(np.int32)
568
 
569
+ # Draw the box (Visual verification)
570
  cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 3)
571
 
572
+ # --- CROP WITH PADDING (Crucial Fix) ---
573
+ # TrOCR needs 'breathing room' or it hallucinates.
574
+ PADDING = 10
575
+
576
+ x_min = max(0, np.min(box_int[:, 0]) - PADDING)
577
+ x_max = min(image_np.shape[1], np.max(box_int[:, 0]) + PADDING)
578
+ y_min = max(0, np.min(box_int[:, 1]) - PADDING)
579
+ y_max = min(image_np.shape[0], np.max(box_int[:, 1]) + PADDING)
580
 
581
+ # Skip noise
582
+ if (x_max - x_min) < 20 or (y_max - y_min) < 10:
583
  continue
584
 
585
  crop = image_np[y_min:y_max, x_min:x_max]
 
586
 
587
+ # Convert to PIL for Model
588
  pil_crop = Image.fromarray(crop)
589
 
590
+ # Add to debug gallery so user can see what the model sees
591
+ debug_crops.append(pil_crop)
592
+
593
+ # 3. RECOGNIZE
594
  with torch.no_grad():
595
  pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
596
  generated_ids = model.generate(pixel_values)
 
600
  results.append(text)
601
 
602
  full_text = "\n".join(results)
603
+
604
+ return Image.fromarray(annotated_img), debug_crops, full_text
605
 
606
+ # --- 3. GRADIO UI ---
607
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
608
+ gr.Markdown("# 📝 Robust Handwritten OCR (Base Model)")
609
+ gr.Markdown("Includes padding and a stronger model to prevent hallucinations.")
610
 
611
  with gr.Row():
612
+ with gr.Column(scale=1):
613
  input_img = gr.Image(type="pil", label="Upload Image")
614
  btn = gr.Button("Transcribe", variant="primary")
615
 
616
+ with gr.Column(scale=1):
617
  output_img = gr.Image(label="Detections")
618
+ output_txt = gr.Textbox(label="Extracted Text", lines=15, show_copy_button=True)
619
+
620
+ with gr.Row():
621
+ # Gallery to check if crops are valid or empty
622
+ crop_gallery = gr.Gallery(label="Debug: See what the model sees (Crops)", columns=6, height=200)
623
 
624
+ btn.click(process_image, input_img, [output_img, crop_gallery, output_txt])
625
 
626
  if __name__ == "__main__":
627
  demo.launch()