# ocr / app.py
# iammraat's picture
# Update app.py
# 55d6595 verified
# import gradio as gr
# from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# import torch
# from PIL import Image
# # --- Model Setup ---
# # We load the model outside the inference function to cache it on startup
# MODEL_ID = "microsoft/trocr-base-handwritten"
# print(f"Loading {MODEL_ID}...")
# processor = TrOCRProcessor.from_pretrained(MODEL_ID)
# model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
# # Check for GPU (Free Spaces are usually CPU-only, but this handles upgrades)
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
# print(f"Model loaded on device: {device}")
# # --- Inference Function ---
# def process_image(image):
# if image is None:
# return "Please upload an image."
# try:
# # 1. Convert to RGB (standardizes input)
# image = image.convert("RGB")
# # 2. Preprocess
# pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
# # 3. Generate text
# generated_ids = model.generate(pixel_values)
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# return generated_text
# except Exception as e:
# return f"Error: {str(e)}"
# # --- Gradio Interface ---
# # Using the Blocks API for a clean layout
# with gr.Blocks(theme=gr.themes.Soft()) as demo:
# gr.Markdown(
# """
# # ✍️ Handwritten Text Recognition
# Using Microsoft's **TrOCR Small** model. Upload a handwritten note to transcribe it.
# """
# )
# with gr.Row():
# with gr.Column():
# input_img = gr.Image(type="pil", label="Upload Image")
# submit_btn = gr.Button("Transcribe", variant="primary")
# with gr.Column():
# output_text = gr.Textbox(label="Result", interactive=False)
# # Examples help users test it immediately without uploading their own file
# # (Uncomment the list below if you upload example images to your repo)
# # gr.Examples(["sample1.jpg"], inputs=input_img)
# submit_btn.click(fn=process_image, inputs=input_img, outputs=output_text)
# # Launch for Spaces
# if __name__ == "__main__":
# demo.launch()
# import gradio as gr
# import torch
# import numpy as np
# import cv2
# from PIL import Image
# from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# from craft_text_detector import Craft
# # ==========================================
# # 🔧 PATCH 1: Fix Torchvision Compatibility
# # ==========================================
# import torchvision.models.vgg
# if not hasattr(torchvision.models.vgg, 'model_urls'):
# torchvision.models.vgg.model_urls = {
# 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
# }
# # ==========================================
# # 🔧 PATCH 2: The "Ratio Net" Logic Fix
# # ==========================================
# import craft_text_detector.craft_utils as craft_utils_module
# def fixed_adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
# if not polys:
# return []
# adjusted = []
# for poly in polys:
# if poly is None or len(poly) == 0:
# continue
# # Convert to numpy and reshape
# p = np.array(poly).reshape(-1, 2)
# # Scale correctly using ratio_net
# p[:, 0] *= (ratio_w * ratio_net)
# p[:, 1] *= (ratio_h * ratio_net)
# adjusted.append(p)
# return adjusted
# craft_utils_module.adjustResultCoordinates = fixed_adjustResultCoordinates
# # ==========================================
# # --- 1. SETUP MODEL (Switched to BASE for stability) ---
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Loading TrOCR-Base on {device}...")
# # We use the 'base' model because 'small' hallucinates Wikipedia text on tight crops
# MODEL_ID = "microsoft/trocr-base-handwritten"
# processor = TrOCRProcessor.from_pretrained(MODEL_ID)
# model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID).to(device).eval()
# print("Loading CRAFT...")
# craft = Craft(output_dir=None, crop_type="box", cuda=(device == "cuda"))
# # --- 2. HELPER FUNCTIONS ---
# def get_sorted_boxes(boxes):
# """Sorts boxes top-to-bottom (lines), then left-to-right."""
# if not boxes: return []
# items = []
# for box in boxes:
# cy = np.mean(box[:, 1])
# cx = np.mean(box[:, 0])
# items.append((cy, cx, box))
# # Sort by line (approx 20px tolerance) then by column
# items.sort(key=lambda x: (int(x[0] // 20), x[1]))
# return [x[2] for x in items]
# def process_image(image):
# if image is None:
# return None, [], "Please upload an image."
# # Convert to standard RGB Numpy array
# # We use the FULL resolution image (no resizing) to keep text sharp
# image_np = np.array(image.convert("RGB"))
# # 1. DETECT
# # The patch ensures coordinates map perfectly to this full-res image
# prediction = craft.detect_text(image_np)
# boxes = prediction.get("boxes", [])
# if not boxes:
# return image, [], "No text detected."
# sorted_boxes = get_sorted_boxes(boxes)
# annotated_img = image_np.copy()
# results = []
# debug_crops = []
# # 2. PROCESS BOXES
# for box in sorted_boxes:
# box_int = box.astype(np.int32)
# # Draw the box (Visual verification)
# cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 3)
# # --- CROP WITH PADDING (Crucial Fix) ---
# # TrOCR needs 'breathing room' or it hallucinates.
# PADDING = 10
# x_min = max(0, np.min(box_int[:, 0]) - PADDING)
# x_max = min(image_np.shape[1], np.max(box_int[:, 0]) + PADDING)
# y_min = max(0, np.min(box_int[:, 1]) - PADDING)
# y_max = min(image_np.shape[0], np.max(box_int[:, 1]) + PADDING)
# # Skip noise
# if (x_max - x_min) < 20 or (y_max - y_min) < 10:
# continue
# crop = image_np[y_min:y_max, x_min:x_max]
# # Convert to PIL for Model
# pil_crop = Image.fromarray(crop)
# # Add to debug gallery so user can see what the model sees
# debug_crops.append(pil_crop)
# # 3. RECOGNIZE
# with torch.no_grad():
# pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
# generated_ids = model.generate(pixel_values)
# text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# if text.strip():
# results.append(text)
# full_text = "\n".join(results)
# return Image.fromarray(annotated_img), debug_crops, full_text
# # --- 3. GRADIO UI ---
# with gr.Blocks(theme=gr.themes.Soft()) as demo:
# gr.Markdown("# 📝 Robust Handwritten OCR (Base Model)")
# gr.Markdown("Includes padding and a stronger model to prevent hallucinations.")
# with gr.Row():
# with gr.Column(scale=1):
# input_img = gr.Image(type="pil", label="Upload Image")
# btn = gr.Button("Transcribe", variant="primary")
# with gr.Column(scale=1):
# output_img = gr.Image(label="Detections")
# output_txt = gr.Textbox(label="Extracted Text", lines=15, show_copy_button=True)
# with gr.Row():
# # Gallery to check if crops are valid or empty
# crop_gallery = gr.Gallery(label="Debug: See what the model sees (Crops)", columns=6, height=200)
# btn.click(process_image, input_img, [output_img, crop_gallery, output_txt])
# if __name__ == "__main__":
# demo.launch()
# import gradio as gr
# import torch
# import numpy as np
# import cv2
# from PIL import Image
# from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# from paddleocr import PaddleOCR
# # --- 1. SETUP TR-OCR (Recognition) ---
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Loading TrOCR on {device}...")
# processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
# model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device).eval()
# # --- 2. SETUP PADDLEOCR (Detection Only) ---
# print("Loading PaddleOCR (DBNet)...")
# # We load the detector but we will bypass the main .ocr() method to avoid bugs
# detector = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
# def get_sorted_boxes(boxes):
# """Sorts boxes top-to-bottom (lines), then left-to-right."""
# if boxes is None or len(boxes) == 0:
# return []
# items = []
# for box in boxes:
# # Paddle returns boxes as numpy arrays or lists
# box = np.array(box).astype(np.float32)
# cy = np.mean(box[:, 1])
# cx = np.mean(box[:, 0])
# items.append((cy, cx, box))
# # Sort by Y (line tolerance 20px) then X
# items.sort(key=lambda x: (int(x[0] // 20), x[1]))
# return [x[2] for x in items]
# def process_image(image):
# if image is None:
# return None, [], "Please upload an image."
# # Convert to standard RGB Numpy array
# image_np = np.array(image.convert("RGB"))
# # ============================================================
# # 🔴 FIX: Direct Detection Bypass
# # ============================================================
# # The standard 'detector.ocr()' method has a bug in the current
# # version that crashes when checking "if not boxes".
# # We call the internal 'text_detector' directly to skip that check.
# try:
# dt_boxes, _ = detector.text_detector(image_np)
# except Exception as e:
# return image, [], f"Detection Error: {str(e)}"
# if dt_boxes is None or len(dt_boxes) == 0:
# return image, [], "No text detected."
# # dt_boxes is already a numpy array of coordinates
# sorted_boxes = get_sorted_boxes(dt_boxes)
# annotated_img = image_np.copy()
# results = []
# debug_crops = []
# # Process Boxes
# for box in sorted_boxes:
# box_int = box.astype(np.int32)
# # Draw Box (Red, thickness 2)
# cv2.polylines(annotated_img, [box_int], True, (255, 0, 0), 2)
# # Crop with Padding (Prevents TrOCR Hallucinations)
# PADDING = 10
# x_min = max(0, np.min(box_int[:, 0]) - PADDING)
# x_max = min(image_np.shape[1], np.max(box_int[:, 0]) + PADDING)
# y_min = max(0, np.min(box_int[:, 1]) - PADDING)
# y_max = min(image_np.shape[0], np.max(box_int[:, 1]) + PADDING)
# # Skip noise
# if (x_max - x_min) < 15 or (y_max - y_min) < 10:
# continue
# crop = image_np[y_min:y_max, x_min:x_max]
# pil_crop = Image.fromarray(crop)
# debug_crops.append(pil_crop)
# # Recognition (TrOCR)
# with torch.no_grad():
# pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
# generated_ids = model.generate(pixel_values)
# text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# if text.strip():
# results.append(text)
# full_text = "\n".join(results)
# return Image.fromarray(annotated_img), debug_crops, full_text
# # --- UI ---
# with gr.Blocks(theme=gr.themes.Soft()) as demo:
# gr.Markdown("# ⚡ PaddleOCR + TrOCR (Robust)")
# gr.Markdown("Using direct DBNet inference to avoid library bugs.")
# with gr.Row():
# with gr.Column(scale=1):
# input_img = gr.Image(type="pil", label="Upload Image")
# btn = gr.Button("Transcribe", variant="primary")
# with gr.Column(scale=1):
# output_img = gr.Image(label="Detections (Paddle)")
# output_txt = gr.Textbox(label="Extracted Text", lines=15, show_copy_button=True)
# with gr.Row():
# gallery = gr.Gallery(label="Line Crops (Debug)", columns=6, height=200)
# btn.click(process_image, input_img, [output_img, gallery, output_txt])
# if __name__ == "__main__":
# demo.launch()
# import gradio as gr
# import torch
# import numpy as np
# import cv2
# from PIL import Image
# from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# from paddleocr import PaddleOCR
# # --- 1. SETUP TR-OCR ---
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Loading TrOCR on {device}...")
# processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
# model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device).eval()
# # --- 2. SETUP PADDLEOCR ---
# print("Loading PaddleOCR...")
# # High resolution to catch faint text
# detector = PaddleOCR(use_angle_cls=True, lang='en', show_log=False,
# det_limit_side_len=2500, det_db_thresh=0.1, det_db_box_thresh=0.3)
# # ==========================================
# # 🧠 LOGIC FIX 1: REMOVE NESTED BOXES
# # ==========================================
# def calculate_overlap_area(box1, box2):
# """Calculates the intersection area between two boxes."""
# x1 = max(box1[0], box2[0])
# y1 = max(box1[1], box2[1])
# x2 = min(box1[2], box2[2])
# y2 = min(box1[3], box2[3])
# if x2 < x1 or y2 < y1:
# return 0.0
# return (x2 - x1) * (y2 - y1)
# def filter_nested_boxes(boxes, containment_thresh=0.80):
# """
# Removes boxes that are mostly contained within other larger boxes.
# """
# if not boxes: return []
# # Convert all to [x1, y1, x2, y2, area]
# active = []
# for b in boxes:
# area = (b[2] - b[0]) * (b[3] - b[1])
# active.append(list(b) + [area])
# # Sort by area (Largest to Smallest) - Crucial!
# # We want to keep the big 'parent' box and delete the small 'child' box.
# active.sort(key=lambda x: x[4], reverse=True)
# final_boxes = []
# for i, current in enumerate(active):
# is_nested = False
# curr_area = current[4]
# # Check against all boxes we've already accepted (which are bigger/same size)
# for kept in final_boxes:
# overlap = calculate_overlap_area(current, kept)
# # Check if 'current' is inside 'kept'
# # If >80% of current box is covered by kept box, it's a duplicate/nested box
# if (overlap / curr_area) > containment_thresh:
# is_nested = True
# break
# if not is_nested:
# final_boxes.append(current[:4]) # Store only coord, drop area
# return final_boxes
# # ==========================================
# # 🧠 LOGIC FIX 2: MERGE WORDS INTO LINES
# # ==========================================
# def merge_boxes_into_lines(raw_boxes, y_thresh=30):
# if raw_boxes is None or len(raw_boxes) == 0:
# return []
# # 1. Convert raw polygons to Axis-Aligned Rectangles
# rects = []
# for box in raw_boxes:
# box = np.array(box).astype(np.float32)
# x1 = np.min(box[:, 0])
# y1 = np.min(box[:, 1])
# x2 = np.max(box[:, 0])
# y2 = np.max(box[:, 1])
# rects.append([x1, y1, x2, y2])
# # 🔴 STEP 2: Filter Nested Boxes (Remove the 'child' boxes)
# rects = filter_nested_boxes(rects)
# # 3. Sort by Y center
# rects.sort(key=lambda r: (r[1] + r[3]) / 2)
# merged_lines = []
# while rects:
# current_line = [rects.pop(0)]
# line_y_center = (current_line[0][1] + current_line[0][3]) / 2
# remaining = []
# for r in rects:
# r_y_center = (r[1] + r[3]) / 2
# # If Y-center is close (same horizontal line)
# if abs(r_y_center - line_y_center) < y_thresh:
# current_line.append(r)
# else:
# remaining.append(r)
# rects = remaining
# # 4. Create Line Box
# lx1 = min(r[0] for r in current_line)
# ly1 = min(r[1] for r in current_line)
# lx2 = max(r[2] for r in current_line)
# ly2 = max(r[3] for r in current_line)
# merged_lines.append([lx1, ly1, lx2, ly2])
# # Final Sort by Y
# merged_lines.sort(key=lambda r: r[1])
# return merged_lines
# def process_image(image):
# if image is None: return None, [], "Please upload an image."
# image_np = np.array(image.convert("RGB"))
# # DETECT
# try:
# dt_boxes, _ = detector.text_detector(image_np)
# except Exception as e:
# return image, [], f"Detection Error: {str(e)}"
# if dt_boxes is None or len(dt_boxes) == 0:
# return image, [], "No text detected."
# # PROCESS (Filter Nested -> Merge Lines)
# line_boxes = merge_boxes_into_lines(dt_boxes)
# annotated_img = image_np.copy()
# results = []
# debug_crops = []
# for box in line_boxes:
# x1, y1, x2, y2 = map(int, box)
# # Filter Noise
# if (x2 - x1) < 20 or (y2 - y1) < 15:
# continue
# # Draw (Green)
# cv2.rectangle(annotated_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
# # PADDING
# PAD = 10
# h, w, _ = image_np.shape
# x1 = max(0, x1 - PAD)
# y1 = max(0, y1 - PAD)
# x2 = min(w, x2 + PAD)
# y2 = min(h, y2 + PAD)
# crop = image_np[y1:y2, x1:x2]
# pil_crop = Image.fromarray(crop)
# debug_crops.append(pil_crop)
# # RECOGNIZE
# with torch.no_grad():
# pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
# generated_ids = model.generate(pixel_values)
# text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# if text.strip():
# results.append(text)
# full_text = "\n".join(results)
# return Image.fromarray(annotated_img), debug_crops, full_text
# # --- UI ---
# with gr.Blocks(theme=gr.themes.Soft()) as demo:
# gr.Markdown("# ⚡ Smart Line-Level OCR (Cleaned)")
# with gr.Row():
# with gr.Column(scale=1):
# input_img = gr.Image(type="pil", label="Upload Image")
# btn = gr.Button("Transcribe", variant="primary")
# with gr.Column(scale=1):
# output_img = gr.Image(label="Cleaned Lines (Green Boxes)")
# output_txt = gr.Textbox(label="Extracted Text", lines=15, show_copy_button=True)
# with gr.Row():
# gallery = gr.Gallery(label="Final Line Crops", columns=4, height=200)
# btn.click(process_image, input_img, [output_img, gallery, output_txt])
# if __name__ == "__main__":
# demo.launch()
# import gradio as gr
# import torch
# import numpy as np
# import cv2
# from PIL import Image
# from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# from paddleocr import PaddleOCR
# # Setup
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Loading TrOCR on {device}...")
# processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
# model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device).eval()
# print("Loading PaddleOCR...")
# detector = PaddleOCR(use_angle_cls=True, lang='en', show_log=False,
# det_limit_side_len=2500, det_db_thresh=0.1, det_db_box_thresh=0.3)
# def calculate_iou(box1, box2):
# """Calculate Intersection over Union"""
# x1 = max(box1[0], box2[0])
# y1 = max(box1[1], box2[1])
# x2 = min(box1[2], box2[2])
# y2 = min(box1[3], box2[3])
# if x2 < x1 or y2 < y1:
# return 0.0
# intersection = (x2 - x1) * (y2 - y1)
# area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
# area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
# return intersection / min(area1, area2)
# def remove_nested_boxes(boxes, iou_thresh=0.7):
# """Remove boxes that are nested inside others"""
# if len(boxes) == 0:
# return []
# # Add area to each box
# boxes_with_area = []
# for b in boxes:
# area = (b[2] - b[0]) * (b[3] - b[1])
# boxes_with_area.append((*b, area))
# # Sort by area descending (keep larger boxes)
# boxes_with_area.sort(key=lambda x: x[4], reverse=True)
# keep = []
# for i, current in enumerate(boxes_with_area):
# should_keep = True
# curr_box = current[:4]
# for kept in keep:
# iou = calculate_iou(curr_box, kept)
# if iou > iou_thresh:
# should_keep = False
# break
# if should_keep:
# keep.append(curr_box)
# return keep
# def merge_boxes_into_lines(raw_boxes, y_overlap_thresh=0.5, x_gap_thresh=100):
# """Merge boxes into lines with better horizontal merging"""
# if raw_boxes is None or len(raw_boxes) == 0:
# return []
# # Convert polygons to rectangles
# rects = []
# for box in raw_boxes:
# box = np.array(box).astype(np.float32)
# x1, y1 = np.min(box[:, 0]), np.min(box[:, 1])
# x2, y2 = np.max(box[:, 0]), np.max(box[:, 1])
# rects.append([x1, y1, x2, y2])
# # Remove nested boxes
# rects = remove_nested_boxes(rects)
# if len(rects) == 0:
# return []
# # Sort by Y position
# rects.sort(key=lambda r: r[1])
# # Group into lines based on Y overlap
# lines = []
# current_line = [rects[0]]
# for rect in rects[1:]:
# # Check if rect belongs to current line
# line_y1 = min(r[1] for r in current_line)
# line_y2 = max(r[3] for r in current_line)
# line_height = line_y2 - line_y1
# rect_y1, rect_y2 = rect[1], rect[3]
# rect_height = rect_y2 - rect_y1
# # Calculate vertical overlap
# overlap_y1 = max(line_y1, rect_y1)
# overlap_y2 = min(line_y2, rect_y2)
# overlap = max(0, overlap_y2 - overlap_y1)
# # If significant vertical overlap, it's the same line
# if overlap > y_overlap_thresh * min(line_height, rect_height):
# current_line.append(rect)
# else:
# # Save current line and start new one
# lines.append(current_line)
# current_line = [rect]
# lines.append(current_line)
# # Merge boxes in each line
# merged = []
# for line in lines:
# # Sort line boxes left to right
# line.sort(key=lambda r: r[0])
# # Merge horizontally close boxes
# merged_line = [line[0]]
# for rect in line[1:]:
# last = merged_line[-1]
# # If close horizontally, merge
# if rect[0] - last[2] < x_gap_thresh:
# merged_line[-1] = [
# min(last[0], rect[0]),
# min(last[1], rect[1]),
# max(last[2], rect[2]),
# max(last[3], rect[3])
# ]
# else:
# merged_line.append(rect)
# # Final merge: combine all boxes in line into one
# x1 = min(r[0] for r in merged_line)
# y1 = min(r[1] for r in merged_line)
# x2 = max(r[2] for r in merged_line)
# y2 = max(r[3] for r in merged_line)
# merged.append([x1, y1, x2, y2])
# # Sort by Y
# merged.sort(key=lambda r: r[1])
# return merged
# def process_image(image):
# if image is None:
# return None, [], "Please upload an image."
# image_np = np.array(image.convert("RGB"))
# try:
# dt_boxes, _ = detector.text_detector(image_np)
# except Exception as e:
# return image, [], f"Detection Error: {str(e)}"
# if dt_boxes is None or len(dt_boxes) == 0:
# return image, [], "No text detected."
# line_boxes = merge_boxes_into_lines(dt_boxes)
# annotated_img = image_np.copy()
# results = []
# debug_crops = []
# for box in line_boxes:
# x1, y1, x2, y2 = map(int, box)
# if (x2 - x1) < 20 or (y2 - y1) < 15:
# continue
# cv2.rectangle(annotated_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
# PAD = 10
# h, w, _ = image_np.shape
# x1 = max(0, x1 - PAD)
# y1 = max(0, y1 - PAD)
# x2 = min(w, x2 + PAD)
# y2 = min(h, y2 + PAD)
# crop = image_np[y1:y2, x1:x2]
# pil_crop = Image.fromarray(crop)
# debug_crops.append(pil_crop)
# with torch.no_grad():
# pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
# generated_ids = model.generate(pixel_values)
# text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# if text.strip():
# results.append(text)
# full_text = "\n".join(results)
# return Image.fromarray(annotated_img), debug_crops, full_text
# with gr.Blocks(theme=gr.themes.Soft()) as demo:
# gr.Markdown("# ⚡ Smart Line-Level OCR (Fixed)")
# with gr.Row():
# with gr.Column(scale=1):
# input_img = gr.Image(type="pil", label="Upload Image")
# btn = gr.Button("Transcribe", variant="primary")
# with gr.Column(scale=1):
# output_img = gr.Image(label="Detected Lines")
# output_txt = gr.Textbox(label="Extracted Text", lines=15, show_copy_button=True)
# with gr.Row():
# gallery = gr.Gallery(label="Line Crops", columns=4, height=200)
# btn.click(process_image, input_img, [output_img, gallery, output_txt])
# if __name__ == "__main__":
# demo.launch()
#https://github.com/czczup/FAST
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from paddleocr import PaddleOCR
import pandas as pd
# --- 1. SETUP TR-OCR (handwriting recognition model) ---
# Pick GPU when available; Spaces free tier is usually CPU-only.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading TrOCR on {device}...")
# Processor handles image preprocessing + token decoding for the model.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
# .eval() disables dropout; inference below runs under torch.no_grad().
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device).eval()
# --- 2. SETUP PADDLEOCR (text detection only; recognition is done by TrOCR) ---
print("Loading PaddleOCR...")
# High resolution settings to detect faint text:
# det_limit_side_len=2500 keeps large images sharp; low det_db_thresh /
# det_db_box_thresh make the DBNet detector more sensitive (more boxes).
detector = PaddleOCR(use_angle_cls=True, lang='en', show_log=False,
                     det_limit_side_len=2500, det_db_thresh=0.1, det_db_box_thresh=0.3)
# ==========================================
# 🧠 LOGIC: INTERSECTION OVER UNION (IOU)
# ==========================================
def calculate_iou_containment(box1, box2):
    """Return the fraction of box1's area that lies inside box2.

    Parameters
    ----------
    box1, box2 : sequence of 4 numbers
        Axis-aligned boxes as [x1, y1, x2, y2].

    Returns
    -------
    float
        Containment ratio in [0.0, 1.0]. Returns 0.0 when the boxes do not
        overlap, and also when box1 is degenerate (zero width or height) —
        the original implementation raised ZeroDivisionError in that case.
    """
    ix1 = max(box1[0], box2[0])
    iy1 = max(box1[1], box2[1])
    ix2 = min(box1[2], box2[2])
    iy2 = min(box1[3], box2[3])
    if ix2 < ix1 or iy2 < iy1:
        return 0.0
    intersection = (ix2 - ix1) * (iy2 - iy1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    # Guard: a zero-area detection box would otherwise divide by zero.
    if area1 <= 0:
        return 0.0
    return intersection / area1
def filter_nested_boxes(boxes, containment_thresh=0.85):
    """Drop boxes that are mostly contained within other, larger boxes.

    Parameters
    ----------
    boxes : list of [x1, y1, x2, y2]
        Candidate rectangles.
    containment_thresh : float
        A box is discarded when more than this fraction of its area lies
        inside an already-kept (larger) box.

    Returns
    -------
    list of [x1, y1, x2, y2]
        Surviving boxes, largest-area first.
    """
    if not boxes:
        return []
    # Attach the area to each rect so we can process biggest-first: the big
    # "parent" box must be accepted before its small "child" is examined.
    sized = sorted(
        (list(rect) + [(rect[2] - rect[0]) * (rect[3] - rect[1])] for rect in boxes),
        key=lambda item: item[4],
        reverse=True,
    )
    kept = []
    for candidate in sized:
        rect = candidate[:4]
        # Accept unless an already-kept box covers most of this candidate.
        if all(
            calculate_iou_containment(rect, parent) <= containment_thresh
            for parent in kept
        ):
            kept.append(rect)
    return kept
# ==========================================
# 🧠 LOGIC: STRICT LINE MERGING
# ==========================================
def merge_boxes_into_lines(raw_boxes, log_data):
    """Collapse word-level detections into one axis-aligned box per text line.

    Boxes are merged horizontally only; vertical merging is prevented by a
    strict per-line tolerance. Progress messages are appended to ``log_data``
    (mutated in place).

    Parameters
    ----------
    raw_boxes : sequence of polygons (Nx2 point arrays) or None
        Raw detector output.
    log_data : list of str
        Debug log accumulator.

    Returns
    -------
    list of [x1, y1, x2, y2]
        One merged box per detected line, sorted top to bottom.
    """
    if raw_boxes is None or len(raw_boxes) == 0:
        return []

    # 1. Reduce every detected polygon to its axis-aligned bounding rectangle.
    rects = []
    for poly in raw_boxes:
        pts = np.array(poly).astype(np.float32)
        rects.append([
            np.min(pts[:, 0]), np.min(pts[:, 1]),
            np.max(pts[:, 0]), np.max(pts[:, 1]),
        ])
    log_data.append(f"Raw Detections: {len(rects)} boxes found.")

    # 2. Discard duplicate boxes nested inside bigger ones.
    rects = filter_nested_boxes(rects)
    log_data.append(f"After Cleaning Nested: {len(rects)} boxes remain.")

    # 3. Work top-to-bottom by vertical center.
    rects.sort(key=lambda r: (r[1] + r[3]) / 2)

    lines = []
    while rects:
        # Seed a new line with the highest remaining box.
        seed = rects.pop(0)
        seed_center = (seed[1] + seed[3]) / 2
        # STRICT RULE: a box joins this line only when its vertical center
        # lies within half of the seed box's own height.
        tolerance = (seed[3] - seed[1]) * 0.5

        members, leftover = [seed], []
        for rect in rects:
            center = (rect[1] + rect[3]) / 2
            if abs(center - seed_center) < tolerance:
                members.append(rect)
            else:
                leftover.append(rect)
        rects = leftover

        # Order the line's words left-to-right (reading order).
        members.sort(key=lambda r: r[0])

        # 4. The union of every member box becomes the line box.
        lines.append([
            min(r[0] for r in members),
            min(r[1] for r in members),
            max(r[2] for r in members),
            max(r[3] for r in members),
        ])

    # Final top-to-bottom ordering by the line's top edge.
    lines.sort(key=lambda r: r[1])
    log_data.append(f"Final Merged Lines: {len(lines)} lines created.")
    return lines
def process_image(image):
    """Detect text lines in an uploaded image and transcribe each with TrOCR.

    Uses the module-level ``detector`` (PaddleOCR DBNet) for detection and the
    module-level ``processor``/``model`` (TrOCR) for recognition on ``device``.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Image from the Gradio upload widget.

    Returns
    -------
    tuple
        (annotated image or the input image or None,
         list of PIL line crops for the debug gallery,
         transcribed text (one line per detected text line) or an error/status
         message, newline-joined debug log string)
    """
    logs = []  # Store debug messages here
    if image is None:
        return None, [], "Please upload an image.", "No logs."
    # Work on an RGB numpy copy at full resolution.
    image_np = np.array(image.convert("RGB"))
    # DETECT: call Paddle's internal text_detector directly, bypassing the
    # high-level .ocr() pipeline (see earlier revisions above for the reason).
    # NOTE(review): image_np is RGB; PaddleOCR's own pipeline reads images
    # with cv2 (BGR) — confirm the channel order does not degrade detection.
    try:
        dt_boxes, _ = detector.text_detector(image_np)
    except Exception as e:
        return image, [], f"Detection Error: {str(e)}", "\n".join(logs)
    if dt_boxes is None or len(dt_boxes) == 0:
        return image, [], "No text detected.", "\n".join(logs)
    # PROCESS: de-duplicate nested boxes and merge words into line boxes.
    line_boxes = merge_boxes_into_lines(dt_boxes, logs)
    annotated_img = image_np.copy()
    results = []
    debug_crops = []
    # Log the final box coordinates for inspection
    logs.append("\n--- Final Box Coordinates ---")
    for i, box in enumerate(line_boxes):
        x1, y1, x2, y2 = map(int, box)
        logs.append(f"Line {i+1}: x={x1}, y={y1}, w={x2-x1}, h={y2-y1}")
        # Filter Noise: skip boxes too small to hold readable text.
        if (x2 - x1) < 20 or (y2 - y1) < 15:
            logs.append(f"-> Skipped Line {i+1} (Too Small/Noise)")
            continue
        # Draw (Green) the un-padded line box on the annotated preview.
        cv2.rectangle(annotated_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # PADDING: expand the crop (clamped to image bounds) — TrOCR needs
        # breathing room around the text or it tends to hallucinate.
        PAD = 10
        h, w, _ = image_np.shape
        x1 = max(0, x1 - PAD)
        y1 = max(0, y1 - PAD)
        x2 = min(w, x2 + PAD)
        y2 = min(h, y2 + PAD)
        crop = image_np[y1:y2, x1:x2]
        pil_crop = Image.fromarray(crop)
        debug_crops.append(pil_crop)
        # RECOGNIZE: run TrOCR on the single line crop (no gradients needed).
        with torch.no_grad():
            pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(device)
            generated_ids = model.generate(pixel_values)
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            # Keep only non-blank transcriptions.
            if text.strip():
                results.append(text)
    full_text = "\n".join(results)
    return Image.fromarray(annotated_img), debug_crops, full_text, "\n".join(logs)
# --- UI ---
# Gradio layout: image input on the left; tabbed visualization / text / logs
# on the right; a gallery of the final line crops below.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ⚡ Smart Line-Level OCR (Debug Mode)")
    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Upload Image")
            btn = gr.Button("Transcribe", variant="primary")
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.Tab("Visualization"):
                    output_img = gr.Image(label="Detected Lines")
                with gr.Tab("Extracted Text"):
                    output_txt = gr.Textbox(label="Result", lines=15, show_copy_button=True)
                with gr.Tab("Debug Logs"):
                    # CHANGED HERE: Uses Textbox instead of Code to avoid version errors
                    log_output = gr.Textbox(label="Processing Logs", lines=20, interactive=False)
    with gr.Row():
        gallery = gr.Gallery(label="Final Line Crops", columns=4, height=200)
    # Wire the button to process_image; outputs match its 4-tuple return.
    btn.click(process_image, input_img, [output_img, gallery, output_txt, log_output])
# Launch for Spaces when run as a script.
if __name__ == "__main__":
    demo.launch()