# deepseek / app.py
# (Hugging Face Space page header, converted to comments so the file parses:
#  uploaded by iammraat, commit 8d8f4ec "Update app.py", verified)
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image, ImageOps
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from paddleocr import PaddleOCR
from scipy.signal import find_peaks
# ==========================================
# ⚙️ CONFIGURATION & MODEL LOADING
# ==========================================
print("--- SYSTEM STARTUP ---")
# Force CPU to avoid CUDA overhead on CPU-only Spaces
DEVICE = "cpu"
print(f"-> Hardware Device: {DEVICE}")
# 1. LOAD TR-OCR (Recognition)
# We use the 'stage1' model which is often more robust for general handwriting
# NOTE(review): the checkpoint loaded below is 'trocr-base-handwritten', not a
# stage1 checkpoint — the comment above and the code disagree; confirm intent.
print("-> Loading TrOCR Model...")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
# .eval() puts the model in inference mode (disables dropout etc.).
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(DEVICE).eval()
# 2. LOAD PADDLEOCR (Detection)
# 'structure_version' and generic settings tuned for recall (catch everything, filter later)
print("-> Loading PaddleOCR Detector...")
detector = PaddleOCR(
use_angle_cls=True,   # correct rotated text via the angle classifier
lang='en',
show_log=False,
use_gpu=False,        # match DEVICE above: CPU-only Space
det_limit_side_len=2500, # High res for small text
det_db_thresh=0.1, # Low threshold to catch faint ink
det_db_box_thresh=0.3,
det_db_unclip_ratio=1.6
)
print("--- SYSTEMS READY ---")
# ==========================================
# 🧠 CORE LOGIC: GEOMETRY UTILS
# ==========================================
def calculate_iou_containment(box1, box2):
    """Return the fraction of box1's area that lies inside box2.

    Boxes are [x1, y1, x2, y2] rectangles. Returns 0.0 when the two
    rectangles do not intersect.
    """
    # Corners of the intersection rectangle.
    ix1, iy1 = max(box1[0], box2[0]), max(box1[1], box2[1])
    ix2, iy2 = min(box1[2], box2[2]), min(box1[3], box2[3])
    if ix2 < ix1 or iy2 < iy1:
        return 0.0
    overlap = (ix2 - ix1) * (iy2 - iy1)
    # Epsilon guards against division by zero for degenerate boxes.
    base_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + 1e-6
    return overlap / base_area
def get_vertical_overlap_ratio(box1, box2):
    """Return the vertical overlap of two boxes relative to the shorter one.

    Boxes are [x1, y1, x2, y2]; a high ratio suggests the boxes sit on
    the same text line. Returns 0.0 when the vertical spans are disjoint.
    """
    top_a, bot_a = box1[1], box1[3]
    top_b, bot_b = box2[1], box2[3]
    shared_top = max(top_a, top_b)
    shared_bot = min(bot_a, bot_b)
    if shared_bot < shared_top:
        return 0.0
    shared_height = shared_bot - shared_top
    # Normalise by the shorter box; epsilon avoids division by zero.
    shorter = min(bot_a - top_a, bot_b - top_b) + 1e-6
    return shared_height / shorter
def filter_nested_boxes(boxes, containment_thresh=0.9):
    """Drop boxes that are almost entirely contained in a larger kept box.

    Keeps the largest boxes and discards any candidate whose area is more
    than `containment_thresh` inside an already-kept box (detector noise).
    Returns a list of [x1, y1, x2, y2] lists.
    """
    if not boxes:
        return []
    # Largest first, so containers are accepted before their contents.
    ordered = sorted(
        boxes,
        key=lambda b: (b[2] - b[0]) * (b[3] - b[1]),
        reverse=True,
    )
    kept = []
    for cand in ordered:
        rect = list(cand[:4])
        nested = any(
            calculate_iou_containment(rect, big) > containment_thresh
            for big in kept
        )
        if not nested:
            kept.append(rect)
    return kept
# ==========================================
# 🔬 SCIENTIFIC LOGIC: PROJECTION PROFILES
# ==========================================
def split_double_lines(crop_img, logs):
    """Split a line crop in two if it actually holds two stacked text lines.

    Builds a horizontal ink-projection profile: two peaks separated by a
    deep valley indicate two lines. Guardrails avoid cutting through the
    descenders (y, g, p, q) of a single line. Appends refinement notes to
    `logs` and returns a list of one or two crops.
    """
    # Binarize (inverted Otsu: ink -> white) and sum ink per row.
    gray = cv2.cvtColor(crop_img, cv2.COLOR_RGB2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    row_ink = np.sum(binary, axis=1)
    # Normalize the profile; a blank crop has nothing to split.
    densest = np.max(row_ink)
    if densest == 0:
        return [crop_img]
    row_ink = row_ink / densest
    # distance=25 stops two peaks being found within a single letter's height.
    peaks, _ = find_peaks(row_ink, height=0.2, distance=25)
    if len(peaks) < 2:
        return [crop_img]
    # Examine the valley between the first two peaks.
    first, second = peaks[0], peaks[1]
    valley = row_ink[first:second]
    if len(valley) == 0:
        return [crop_img]
    valley_depth = np.min(valley)
    cut_row = np.argmin(valley) + first
    # Guardrail 1: only split on a genuinely ink-free valley (<= 15% density);
    # touching handwritten lines rarely produce one.
    if valley_depth > 0.15:
        return [crop_img]
    # Guardrail 2: a cut near the top or bottom edge is almost certainly
    # slicing off ascenders/descenders rather than separating two lines.
    split_ratio = cut_row / crop_img.shape[0]
    if split_ratio < 0.20 or split_ratio > 0.75:
        logs.append(f" -> ⚠️ Refinement: Prevented split at {int(split_ratio*100)}% (Likely descenders)")
        return [crop_img]
    # Checks passed: cut the crop at the valley floor.
    logs.append(f" -> ✂️ Refinement: Split double line at Y={cut_row}")
    return [crop_img[:cut_row, :], crop_img[cut_row:, :]]
# ==========================================
# ⛓️ PIPELINE STEP: MERGING & ORDERING
# ==========================================
def smart_line_merger(raw_boxes, logs):
    """Cluster detector word boxes into whole text-line boxes.

    `raw_boxes` is an iterable of 4-point polygons from the detector.
    Returns a list of [x1, y1, x2, y2] line boxes ordered top to bottom,
    and appends a summary entry to `logs`.
    """
    # raw_boxes may be a NumPy array, so test emptiness via len(), not truthiness.
    if raw_boxes is None or len(raw_boxes) == 0:
        return []
    # 1. Convert each polygon to an axis-aligned [x1, y1, x2, y2] rect.
    rects = []
    for poly in raw_boxes:
        pts = np.array(poly).astype(np.float32)
        rects.append([
            np.min(pts[:, 0]), np.min(pts[:, 1]),
            np.max(pts[:, 0]), np.max(pts[:, 1]),
        ])
    rects = filter_nested_boxes(rects)
    logs.append(f"Valid Word Boxes: {len(rects)}")
    # 2. Process boxes top-down by vertical centre.
    rects.sort(key=lambda r: (r[1] + r[3]) / 2)
    lines = []
    while rects:
        # Seed a new line with the highest remaining box.
        seed = rects.pop(0)
        members = [seed]
        leftovers = []
        for r in rects:
            # A word joins the line when it shares >40% of its vertical
            # span with the line's seed box.
            if get_vertical_overlap_ratio(seed, r) > 0.4:
                members.append(r)
            else:
                leftovers.append(r)
        rects = leftovers
        # Bounding box of the collected line.
        lines.append([
            min(r[0] for r in members),
            min(r[1] for r in members),
            max(r[2] for r in members),
            max(r[3] for r in members),
        ])
    # Final ordering: top to bottom.
    lines.sort(key=lambda r: r[1])
    return lines
# ==========================================
# 🚀 MAIN EXECUTION
# ==========================================
def process_handwriting(image: Image.Image):
    """Run the full handwriting-OCR pipeline on one uploaded image.

    Pipeline: PaddleOCR word detection -> geometric merge into line boxes
    -> projection-profile splitting of accidentally merged lines -> TrOCR
    recognition per line crop.

    Returns a 4-tuple matching the Gradio outputs:
    (annotated image, list of line-crop images, transcribed text, log text).
    """
    logs = ["--- STARTING PIPELINE ---"]
    if image is None: return None, [], "Please upload an image.", "Error"
    # 1. PRE-PROCESS
    # Convert to RGB array
    orig_np = np.array(image.convert("RGB"))
    # 2. DETECT (PaddleOCR)
    try:
        # NOTE(review): text_detector is an internal PaddleOCR attribute
        # (the public API is detector.ocr(img, rec=False)) — confirm it is
        # stable across paddleocr versions.
        dt_boxes, _ = detector.text_detector(orig_np)
        if dt_boxes is None: dt_boxes = []
    except Exception as e:
        # Surface detector failure in the UI instead of crashing the app.
        return image, [], f"Detector Failed: {e}", "\n".join(logs)
    if len(dt_boxes) == 0:
        return image, [], "No text detected.", "Logs end."
    # 3. MERGE WORDS -> LINES
    line_boxes = smart_line_merger(dt_boxes, logs)
    logs.append(f"Merged into {len(line_boxes)} lines.")
    # 4. RECOGNITION + REFINEMENT LOOP
    annotated_img = orig_np.copy()
    final_text_lines = []
    gallery_crops = []
    # Padding for crops (gives TrOCR context)
    PAD = 8
    h_img, w_img, _ = orig_np.shape
    for i, box in enumerate(line_boxes):
        x1, y1, x2, y2 = map(int, box)
        # Add padding safely (clamped to the image bounds)
        x1 = max(0, x1 - PAD); y1 = max(0, y1 - PAD)
        x2 = min(w_img, x2 + PAD); y2 = min(h_img, y2 + PAD)
        # Crop
        line_crop = orig_np[y1:y2, x1:x2]
        # --- REFINEMENT LOOP ---
        # Check if we accidentally merged two lines
        sub_crops = split_double_lines(line_crop, logs)
        for sub_crop in sub_crops:
            # Skip slivers too small to contain readable text.
            if sub_crop.shape[0] < 10 or sub_crop.shape[1] < 10: continue
            # Convert for TrOCR
            pil_crop = Image.fromarray(sub_crop)
            gallery_crops.append(pil_crop)
            # Inference (no_grad: inference only, no autograd bookkeeping)
            with torch.no_grad():
                pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(DEVICE)
                generated_ids = model.generate(pixel_values)
                text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            if text.strip():
                final_text_lines.append(text)
        # Visualization (Draw the *original* merged box in Green)
        cv2.rectangle(annotated_img, (x1, y1), (x2, y2), (0, 200, 0), 2)
        cv2.putText(annotated_img, str(i+1), (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,200,0), 1)
    full_text = "\n".join(final_text_lines)
    logs.append("--- PROCESSING COMPLETE ---")
    return Image.fromarray(annotated_img), gallery_crops, full_text, "\n".join(logs)
# ==========================================
# 🖥️ GRADIO INTERFACE
# ==========================================
css = """
#gallery { height: 300px; overflow-y: scroll; }
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
gr.Markdown("## 📝 Scientific Handwriting OCR (Line-Level Refinement)")
gr.Markdown("Uses PaddleOCR for detection, Geometry for merging, Projection Profiles for refinement, and TrOCR for reading.")
with gr.Row():
with gr.Column(scale=1):
input_img = gr.Image(type="pil", label="Input Document")
run_btn = gr.Button("Analyze & Transcribe", variant="primary")
with gr.Column(scale=1):
with gr.Tabs():
with gr.Tab("Transcribed Text"):
output_txt = gr.Textbox(label="Result", lines=15, show_copy_button=True)
with gr.Tab("Segmentation Map"):
output_img = gr.Image(label="Line Detection Map")
with gr.Tab("System Logs"):
log_output = gr.Textbox(label="Process Logs", lines=15)
gr.Markdown("### Line Segments (Input for TrOCR)")
gallery = gr.Gallery(label="Refined Crops", columns=4, elem_id="gallery")
run_btn.click(
process_handwriting,
input_img,
[output_img, gallery, output_txt, log_output]
)
if __name__ == "__main__":
demo.launch()