|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from PIL import Image, ImageOps |
|
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
|
from paddleocr import PaddleOCR |
|
|
from scipy.signal import find_peaks |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("--- SYSTEM STARTUP ---") |
|
|
|
|
|
|
|
|
DEVICE = "cpu" |
|
|
print(f"-> Hardware Device: {DEVICE}") |
|
|
|
|
|
|
|
|
|
|
|
print("-> Loading TrOCR Model...") |
|
|
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten') |
|
|
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(DEVICE).eval() |
|
|
|
|
|
|
|
|
|
|
|
print("-> Loading PaddleOCR Detector...") |
|
|
detector = PaddleOCR( |
|
|
use_angle_cls=True, |
|
|
lang='en', |
|
|
show_log=False, |
|
|
use_gpu=False, |
|
|
det_limit_side_len=2500, |
|
|
det_db_thresh=0.1, |
|
|
det_db_box_thresh=0.3, |
|
|
det_db_unclip_ratio=1.6 |
|
|
) |
|
|
print("--- SYSTEMS READY ---") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_iou_containment(box1, box2):
    """
    Return the fraction of box1's area that lies within box2.

    Boxes are (x1, y1, x2, y2) axis-aligned rectangles. The result is in
    [0, 1]: 0.0 when the boxes do not intersect, ~1.0 when box1 sits
    entirely inside box2.
    """
    # Corners of the intersection rectangle.
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # No overlap at all.
    if right < left or bottom < top:
        return 0.0

    overlap_area = (right - left) * (bottom - top)
    # Epsilon guards against division by zero for degenerate boxes.
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + 1e-6
    return overlap_area / box1_area
|
|
|
|
|
def get_vertical_overlap_ratio(box1, box2):
    """
    Return the vertical overlap of two boxes, relative to the shorter one.

    Words on the same text line share most of their vertical extent, so a
    ratio close to 1 suggests the two boxes belong to the same line.
    """
    top_a, bottom_a = box1[1], box1[3]
    top_b, bottom_b = box2[1], box2[3]

    shared_top = max(top_a, top_b)
    shared_bottom = min(bottom_a, bottom_b)

    # Vertical spans do not intersect.
    if shared_bottom < shared_top:
        return 0.0

    shared_height = shared_bottom - shared_top
    # Normalize by the shorter box; epsilon avoids division by zero.
    shorter_height = min(bottom_a - top_a, bottom_b - top_b) + 1e-6
    return shared_height / shorter_height
|
|
|
|
|
def filter_nested_boxes(boxes, containment_thresh=0.9):
    """
    Drop boxes that sit (almost) entirely inside a larger, already-kept box.

    The detector sometimes emits a tiny noise detection inside a real word
    box; processing largest-first and discarding any candidate whose
    containment in a kept box exceeds `containment_thresh` removes them.
    Returns the surviving boxes as [x1, y1, x2, y2] lists, largest first.
    """
    if not boxes:
        return []

    # Pair each box with its area so the largest are considered first.
    with_area = [list(b) + [(b[2] - b[0]) * (b[3] - b[1])] for b in boxes]
    with_area.sort(key=lambda entry: entry[4], reverse=True)

    kept = []
    for entry in with_area:
        candidate = entry[:4]
        # Keep the candidate only if no larger kept box already swallows it.
        swallowed = any(
            calculate_iou_containment(candidate, other) > containment_thresh
            for other in kept
        )
        if not swallowed:
            kept.append(candidate)

    return kept
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_double_lines(crop_img, logs):
    """
    Check whether a line crop actually holds TWO stacked lines of text and,
    if so, split it at the quietest row between the first two ink peaks.

    Includes 'Descender Protection': a cut too close to the top or bottom
    edge is rejected, since it is more likely to sever y/g/p tails than to
    separate two genuine lines. Returns a list of one or two crops; split
    decisions are reported via `logs`.
    """
    # Binarize: Otsu on the inverted grayscale so ink pixels become white.
    gray = cv2.cvtColor(crop_img, cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Horizontal projection profile: total ink per image row.
    h_proj = np.sum(thresh, axis=1)

    peak_ink = np.max(h_proj)
    if peak_ink == 0:
        # Completely blank crop — nothing to split.
        return [crop_img]
    h_proj = h_proj / peak_ink  # normalize profile to [0, 1]

    # Two distinct text lines appear as two well-separated projection peaks.
    peaks, _ = find_peaks(h_proj, height=0.2, distance=25)
    if len(peaks) < 2:
        return [crop_img]

    # Find the quietest row between the first two peaks.
    p1, p2 = peaks[0], peaks[1]
    valley_region = h_proj[p1:p2]
    if len(valley_region) == 0:
        return [crop_img]

    min_val = np.min(valley_region)
    min_idx = np.argmin(valley_region) + p1

    # The valley still carries noticeable ink: likely one connected line.
    if min_val > 0.15:
        return [crop_img]

    # Descender Protection: refuse cuts too near either edge of the crop.
    total_height = crop_img.shape[0]
    split_ratio = min_idx / total_height
    if split_ratio < 0.20 or split_ratio > 0.75:
        logs.append(f"   -> ⚠️ Refinement: Prevented split at {int(split_ratio*100)}% (Likely descenders)")
        return [crop_img]

    logs.append(f"   -> ✂️ Refinement: Split double line at Y={min_idx}")
    top_crop = crop_img[0:min_idx, :]
    bot_crop = crop_img[min_idx:, :]
    return [top_crop, bot_crop]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def smart_line_merger(raw_boxes, logs):
    """
    Cluster word-level detection polygons into full text-line boxes.

    Each detector polygon is flattened to an axis-aligned rectangle, noise
    boxes nested inside larger ones are removed, and the remainder are
    grouped greedily: the topmost unassigned box seeds a line, and every
    box with >40% vertical overlap against that seed joins it. Returns
    line boxes [x1, y1, x2, y2] sorted top-to-bottom; progress is reported
    via `logs`.
    """
    if raw_boxes is None or len(raw_boxes) == 0:
        return []

    # Flatten each 4-point polygon into its axis-aligned bounding rectangle.
    rects = []
    for poly in raw_boxes:
        pts = np.array(poly).astype(np.float32)
        xs, ys = pts[:, 0], pts[:, 1]
        rects.append([np.min(xs), np.min(ys), np.max(xs), np.max(ys)])

    rects = filter_nested_boxes(rects)
    logs.append(f"Valid Word Boxes: {len(rects)}")

    # Process boxes top-to-bottom by vertical center.
    rects.sort(key=lambda r: (r[1] + r[3]) / 2)

    lines = []
    while rects:
        # Seed a new line with the topmost unassigned box.
        curr_line = [rects.pop(0)]

        # Pull in every remaining box that vertically overlaps the seed.
        remaining = []
        for r in rects:
            if get_vertical_overlap_ratio(curr_line[0], r) > 0.4:
                curr_line.append(r)
            else:
                remaining.append(r)
        rects = remaining

        # Left-to-right reading order within the line.
        curr_line.sort(key=lambda r: r[0])

        # The union of all member boxes becomes the line box.
        lines.append([
            min(r[0] for r in curr_line),
            min(r[1] for r in curr_line),
            max(r[2] for r in curr_line),
            max(r[3] for r in curr_line),
        ])

    # Final top-to-bottom ordering of the merged lines.
    lines.sort(key=lambda r: r[1])
    return lines
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_handwriting(image):
    """
    Run the full OCR pipeline on one uploaded page.

    Stages: PaddleOCR word detection -> geometric line merging ->
    double-line refinement -> TrOCR transcription of each line crop.

    Parameters
    ----------
    image : PIL.Image.Image or None
        The uploaded document image (None if nothing was uploaded).

    Returns
    -------
    tuple
        (annotated PIL image, list of line-crop PIL images,
         transcribed text, log string) — matching the Gradio outputs.
    """
    logs = ["--- STARTING PIPELINE ---"]

    # Guard: nothing uploaded.
    if image is None: return None, [], "Please upload an image.", "Error"

    # Work on an RGB numpy copy of the input.
    orig_np = np.array(image.convert("RGB"))

    # Stage 1: word-level detection via PaddleOCR's raw detector head.
    try:
        dt_boxes, _ = detector.text_detector(orig_np)
        if dt_boxes is None: dt_boxes = []
    except Exception as e:
        return image, [], f"Detector Failed: {e}", "\n".join(logs)

    if len(dt_boxes) == 0:
        return image, [], "No text detected.", "Logs end."

    # Stage 2: merge word boxes into full text lines.
    line_boxes = smart_line_merger(dt_boxes, logs)
    logs.append(f"Merged into {len(line_boxes)} lines.")

    annotated_img = orig_np.copy()
    final_text_lines = []
    gallery_crops = []

    # Padding (pixels) added around each line box before cropping, to avoid
    # clipping ascenders/descenders at the crop edges.
    PAD = 8
    h_img, w_img, _ = orig_np.shape

    for i, box in enumerate(line_boxes):
        x1, y1, x2, y2 = map(int, box)

        # Expand the box by PAD, clamped to the image bounds.
        x1 = max(0, x1 - PAD); y1 = max(0, y1 - PAD)
        x2 = min(w_img, x2 + PAD); y2 = min(h_img, y2 + PAD)

        line_crop = orig_np[y1:y2, x1:x2]

        # Stage 3: split crops that accidentally contain two stacked lines.
        sub_crops = split_double_lines(line_crop, logs)

        for sub_crop in sub_crops:
            # Skip slivers too small to hold legible text.
            if sub_crop.shape[0] < 10 or sub_crop.shape[1] < 10: continue

            pil_crop = Image.fromarray(sub_crop)
            gallery_crops.append(pil_crop)

            # Stage 4: TrOCR transcription of the refined line crop.
            with torch.no_grad():
                pixel_values = processor(images=pil_crop, return_tensors="pt").pixel_values.to(DEVICE)
                generated_ids = model.generate(pixel_values)
                text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            if text.strip():
                final_text_lines.append(text)

        # Draw the padded line box and its 1-based index on the overview map.
        cv2.rectangle(annotated_img, (x1, y1), (x2, y2), (0, 200, 0), 2)
        cv2.putText(annotated_img, str(i+1), (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,200,0), 1)

    full_text = "\n".join(final_text_lines)
    logs.append("--- PROCESSING COMPLETE ---")

    return Image.fromarray(annotated_img), gallery_crops, full_text, "\n".join(logs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
css = """ |
|
|
#gallery { height: 300px; overflow-y: scroll; } |
|
|
""" |
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo: |
|
|
gr.Markdown("## 📝 Scientific Handwriting OCR (Line-Level Refinement)") |
|
|
gr.Markdown("Uses PaddleOCR for detection, Geometry for merging, Projection Profiles for refinement, and TrOCR for reading.") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
input_img = gr.Image(type="pil", label="Input Document") |
|
|
run_btn = gr.Button("Analyze & Transcribe", variant="primary") |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
with gr.Tabs(): |
|
|
with gr.Tab("Transcribed Text"): |
|
|
output_txt = gr.Textbox(label="Result", lines=15, show_copy_button=True) |
|
|
with gr.Tab("Segmentation Map"): |
|
|
output_img = gr.Image(label="Line Detection Map") |
|
|
with gr.Tab("System Logs"): |
|
|
log_output = gr.Textbox(label="Process Logs", lines=15) |
|
|
|
|
|
gr.Markdown("### Line Segments (Input for TrOCR)") |
|
|
gallery = gr.Gallery(label="Refined Crops", columns=4, elem_id="gallery") |
|
|
|
|
|
run_btn.click( |
|
|
process_handwriting, |
|
|
input_img, |
|
|
[output_img, gallery, output_txt, log_output] |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch() |