import gradio as gr import torch import numpy as np import cv2 from PIL import Image, ImageOps from transformers import TrOCRProcessor, VisionEncoderDecoderModel from paddleocr import PaddleOCR from scipy.signal import find_peaks # ========================================== # ⚙️ CONFIGURATION & MODEL LOADING # ========================================== print("--- SYSTEM STARTUP ---") # Force CPU to avoid CUDA overhead on CPU-only Spaces DEVICE = "cpu" print(f"-> Hardware Device: {DEVICE}") # 1. LOAD TR-OCR (Recognition) # We use the 'stage1' model which is often more robust for general handwriting print("-> Loading TrOCR Model...") processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten') model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(DEVICE).eval() # 2. LOAD PADDLEOCR (Detection) # 'structure_version' and generic settings tuned for recall (catch everything, filter later) print("-> Loading PaddleOCR Detector...") detector = PaddleOCR( use_angle_cls=True, lang='en', show_log=False, use_gpu=False, det_limit_side_len=2500, # High res for small text det_db_thresh=0.1, # Low threshold to catch faint ink det_db_box_thresh=0.3, det_db_unclip_ratio=1.6 ) print("--- SYSTEMS READY ---") # ========================================== # 🧠 CORE LOGIC: GEOMETRY UTILS # ========================================== def calculate_iou_containment(box1, box2): """ Calculates how much of box1 is inside box2. """ x1 = max(box1[0], box2[0]) y1 = max(box1[1], box2[1]) x2 = min(box1[2], box2[2]) y2 = min(box1[3], box2[3]) if x2 < x1 or y2 < y1: return 0.0 intersection = (x2 - x1) * (y2 - y1) area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) + 1e-6 return intersection / area1 def get_vertical_overlap_ratio(box1, box2): """ Calculates vertical overlap between two boxes. Used to determine if words are on the same line. """ # y1, y2 are top, bottom y1_a, y2_a = box1[1], box1[3] y1_b, y2_b = box2[1], box2[3] intersection_start = max(y1_a, y1_b) intersection_end = min(y2_a, y2_b) if intersection_end < intersection_start: return 0.0 overlap_height = intersection_end - intersection_start min_height = min(y2_a - y1_a, y2_b - y1_b) + 1e-6 return overlap_height / min_height def filter_nested_boxes(boxes, containment_thresh=0.9): """ Removes small noise boxes inside larger real boxes. """ if not boxes: return [] # Add area to list: [x1, y1, x2, y2, area] active = [] for b in boxes: area = (b[2] - b[0]) * (b[3] - b[1]) active.append(list(b) + [area]) # Sort largest first active.sort(key=lambda x: x[4], reverse=True) final_boxes = [] for current in active: is_nested = False curr_box = current[:4] for kept in final_boxes: if calculate_iou_containment(curr_box, kept) > containment_thresh: is_nested = True break if not is_nested: final_boxes.append(curr_box) return final_boxes # ========================================== # 🔬 SCIENTIFIC LOGIC: PROJECTION PROFILES # ========================================== def split_double_lines(crop_img, logs): """ Analyzes a crop to see if it accidentally contains TWO lines of text. Includes 'Descender Protection' to prevent cutting off tails (y, g, p). """ # 1. Binarize gray = cv2.cvtColor(crop_img, cv2.COLOR_RGB2GRAY) _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) # 2. Horizontal Projection h_proj = np.sum(thresh, axis=1) # Normalize max_val = np.max(h_proj) if max_val == 0: return [crop_img] h_proj = h_proj / max_val # 3. Find Peaks # Increased distance to 25 to prevent finding peaks within the same letter height peaks, _ = find_peaks(h_proj, height=0.2, distance=25) if len(peaks) < 2: return [crop_img] # 4. Analyze the Valley p1, p2 = peaks[0], peaks[1] valley_region = h_proj[p1:p2] if len(valley_region) == 0: return [crop_img] min_val = np.min(valley_region) min_idx = np.argmin(valley_region) + p1 # ========================== # 🛡️ SAFETY GUARDRAILS # ========================== # CHECK 1: Valley Depth # Handwriting lines often touch. If the valley isn't DEEP (very low ink), don't split. # We lowered the threshold to 0.15 (15% ink density) if min_val > 0.15: return [crop_img] # CHECK 2: Edge Protection (The "Descender" Check) # If the split point is too close to the bottom (e.g., > 75% down the image), # it's almost certainly chopping off tails (y, g, p, q), not separating a new line. total_height = crop_img.shape[0] split_ratio = min_idx / total_height if split_ratio < 0.20 or split_ratio > 0.75: logs.append(f" -> ⚠️ Refinement: Prevented split at {int(split_ratio*100)}% (Likely descenders)") return [crop_img] # If we pass checks, perform the split logs.append(f" -> ✂️ Refinement: Split double line at Y={min_idx}") top_crop = crop_img[0:min_idx, :] bot_crop = crop_img[min_idx:, :] return [top_crop, bot_crop] # ========================================== # ⛓️ PIPELINE STEP: MERGING & ORDERING # ========================================== # ========================================== # ⛓️ PIPELINE STEP: MERGING & ORDERING # ========================================== def smart_line_merger(raw_boxes, logs): """ Groups words into lines using Centroid Clustering & Vertical Overlap. """ # FIX: Use len() check because raw_boxes is a NumPy array, not a list if raw_boxes is None or len(raw_boxes) == 0: return [] # 1. Clean & Format rects = [] for box in raw_boxes: box = np.array(box).astype(np.float32) x1, y1 = np.min(box[:, 0]), np.min(box[:, 1]) x2, y2 = np.max(box[:, 0]), np.max(box[:, 1]) rects.append([x1, y1, x2, y2]) rects = filter_nested_boxes(rects) logs.append(f"Valid Word Boxes: {len(rects)}") # 2. Sort by Y-Center (approximate top-down) rects.sort(key=lambda r: (r[1] + r[3]) / 2) lines = [] while rects: # Start new line with the highest remaining box curr_line = [rects.pop(0)] # Find all other boxes that belong to this line remaining = [] for r in rects: # Check overlap against the *average* vertical span of the current line overlap = get_vertical_overlap_ratio(curr_line[0], r) # 0.4 means they share 40% of their vertical height if overlap > 0.4: curr_line.append(r) else: remaining.append(r) rects = remaining # Sort the collected line horizontally (Left to Right) curr_line.sort(key=lambda r: r[0]) # Merge coordinates lx1 = min(r[0] for r in curr_line) ly1 = min(r[1] for r in curr_line) lx2 = max(r[2] for r in curr_line) ly2 = max(r[3] for r in curr_line) lines.append([lx1, ly1, lx2, ly2]) # Final Sort of Lines (Top to Bottom) lines.sort(key=lambda r: r[1]) return lines # ========================================== # 🚀 MAIN EXECUTION # ========================================== def process_handwriting(image): logs = ["--- STARTING PIPELINE ---"] if image is None: return None, [], "Please upload an image.", "Error" # 1. PRE-PROCESS # Convert to RGB array orig_np = np.array(image.convert("RGB")) # 2. DETECT (PaddleOCR) try: dt_boxes, _ = detector.text_detector(orig_np) if dt_boxes is None: dt_boxes = [] except Exception as e: return image, [], f"Detector Failed: {e}", "\n".join(logs) if len(dt_boxes) == 0: return image, [], "No text detected.", "Logs end." # 3. MERGE WORDS -> LINES line_boxes = smart_line_merger(dt_boxes, logs) logs.append(f"Merged into {len(line_boxes)} lines.") # 4. RECOGNITION + REFINEMENT LOOP annotated_img = orig_np.copy() final_text_lines = [] gallery_crops = [] # Padding for crops (gives TrOCR context) PAD = 8 h_img, w_img, _ = orig_np.shape for i, box in enumerate(line_boxes): x1, y1, x2, y2 = map(int, box) # Add padding safely x1 = max(0, x1 - PAD); y1 = max(0, y1 - PAD) x2 = min(w_img, x2 + PAD); y2 = min(h_img, y2 + PAD) # Crop line_crop = orig_np[y1:y2, x1:x2] # --- REFINEMENT LOOP --- # Check if we accidentally merged two lines sub_crops = split_double_lines(line_crop, logs) for sub_crop in sub_crops: if sub_crop.shape[0] < 10 or sub_crop.shape[1] < 10: continue # Convert for TrOCR pil_crop = Image.fromarray(sub_crop) gallery_crops.append(pil_crop) # Inference with torch.no_grad(): pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(DEVICE) generated_ids = model.generate(pixel_values) text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] if text.strip(): final_text_lines.append(text) # Visualization (Draw the *original* merged box in Green) cv2.rectangle(annotated_img, (x1, y1), (x2, y2), (0, 200, 0), 2) cv2.putText(annotated_img, str(i+1), (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,200,0), 1) full_text = "\n".join(final_text_lines) logs.append("--- PROCESSING COMPLETE ---") return Image.fromarray(annotated_img), gallery_crops, full_text, "\n".join(logs) # ========================================== # 🖥️ GRADIO INTERFACE # ========================================== css = """ #gallery { height: 300px; overflow-y: scroll; } """ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo: gr.Markdown("## 📝 Scientific Handwriting OCR (Line-Level Refinement)") gr.Markdown("Uses PaddleOCR for detection, Geometry for merging, Projection Profiles for refinement, and TrOCR for reading.") with gr.Row(): with gr.Column(scale=1): input_img = gr.Image(type="pil", label="Input Document") run_btn = gr.Button("Analyze & Transcribe", variant="primary") with gr.Column(scale=1): with gr.Tabs(): with gr.Tab("Transcribed Text"): output_txt = gr.Textbox(label="Result", lines=15, show_copy_button=True) with gr.Tab("Segmentation Map"): output_img = gr.Image(label="Line Detection Map") with gr.Tab("System Logs"): log_output = gr.Textbox(label="Process Logs", lines=15) gr.Markdown("### Line Segments (Input for TrOCR)") gallery = gr.Gallery(label="Refined Crops", columns=4, elem_id="gallery") run_btn.click( process_handwriting, input_img, [output_img, gallery, output_txt, log_output] ) if __name__ == "__main__": demo.launch()