# surya/app.py — Hybrid OCR demo: Surya text-line detection + TrOCR handwriting recognition.
import gradio as gr
import logging
import os
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# --- SURYA IMPORTS ---
try:
from surya.detection import batch_text_detection
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
except ImportError:
from surya.detection import batch_inference as batch_text_detection
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
# ==========================================
# 1. SETUP MODELS
# ==========================================
# Force CPU: this app is written for a GPU-less environment; both models are
# loaded once at import time and shared by every request.
device = "cpu"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("⏳ Loading Models...")
# A. SURYA DETECTION
# Detection processor/model come from whichever surya API version imported above.
det_processor = load_det_processor()
det_model = load_det_model().to(device)
# B. TROCR RECOGNITION
# NOTE: We do NOT use quantization here. It destroys the attention mechanism in ViT
# encoders on CPU, leading to "mode collapse" (hallucinations).
trocr_processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)
logger.info("✅ All Models Loaded.")
# ==========================================
# 2. HELPER FUNCTIONS
# ==========================================
def recognize_batch(crops):
    """Run TrOCR recognition over a list of PIL line crops.

    Zero-area crops are silently dropped; returns one decoded string per
    remaining crop, in order.
    """
    if not crops:
        return []
    usable = [img for img in crops if img.size[0] > 0 and img.size[1] > 0]
    if not usable:
        return []
    inputs = trocr_processor(images=usable, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)
    with torch.no_grad():
        # Using a slightly lower max_length prevents it from rambling if it gets confused
        generated_ids = trocr_model.generate(pixel_values, max_length=64)
    return trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)
def draw_boxes(image, prediction_objects):
    """Outline every detection on *image* in red (mutates and returns it).

    Each entry may be an object exposing a ``.bbox`` attribute or a raw
    ``[x1, y1, x2, y2]`` sequence.
    """
    canvas = ImageDraw.Draw(image)
    for det in prediction_objects:
        box = det.bbox if hasattr(det, "bbox") else det
        canvas.rectangle(box, outline="red", width=2)
    return image
# ==========================================
# 3. MAIN WORKFLOW
# ==========================================
def _get_bbox(obj):
    """Return [x1, y1, x2, y2] from a detection object or a raw sequence."""
    return obj.bbox if hasattr(obj, "bbox") else obj


def hybrid_ocr_workflow(image):
    """Full pipeline: Surya line detection, then TrOCR per-line recognition.

    Args:
        image: PIL image (or None, from the Gradio input).

    Returns:
        (annotated image with red detection boxes, recognized text), or
        (None, message) when no image was supplied.
    """
    if image is None:
        return None, "Please upload an image."
    # CRITICAL FIX: Ensure image is RGB (TrOCR fails on RGBA/P modes silently)
    if image.mode != "RGB":
        image = image.convert("RGB")

    # 1. DETECT — Surya expects a list of images.
    logger.info("Step 1: Detecting Lines with Surya...")
    predictions = batch_text_detection([image], det_model, det_processor)
    result = predictions[0]

    # Different surya versions expose detections under different attribute names.
    lines_objects = []
    if hasattr(result, "bboxes"):
        lines_objects = result.bboxes
    elif hasattr(result, "text_lines"):
        lines_objects = result.text_lines

    # Sort top-to-bottom. _get_bbox also accepts raw [x1,y1,x2,y2] entries,
    # matching the fallback draw_boxes already supports (the original sort
    # crashed with AttributeError on plain sequences).
    lines_objects.sort(key=lambda obj: _get_bbox(obj)[1])

    # 2. CROP & RECOGNIZE
    logger.info(f"Step 2: Recognizing {len(lines_objects)} lines with TrOCR...")
    w, h = image.size
    pad = 6  # small margin so ascenders/descenders are not clipped
    line_crops = []
    for obj in lines_objects:
        x1, y1, x2, y2 = _get_bbox(obj)
        line_crops.append(image.crop((
            max(0, int(x1) - pad),
            max(0, int(y1) - pad),
            min(w, int(x2) + pad),
            min(h, int(y2) + pad),
        )))

    # Recognize in small batches to bound CPU memory use; a failing batch
    # logs the error and degrades to a placeholder line instead of aborting.
    full_text_lines = []
    batch_size = 4
    for i in range(0, len(line_crops), batch_size):
        batch = line_crops[i:i + batch_size]
        try:
            full_text_lines.extend(recognize_batch(batch))
        except Exception as e:
            logger.error(f"Batch failed: {e}")
            full_text_lines.append("[Error processing line]")

    final_text = "\n".join(full_text_lines)
    # Visualize on a copy so the caller's image is not mutated.
    vis_img = draw_boxes(image.copy(), lines_objects)
    return vis_img, final_text
# ==========================================
# 4. GRADIO UI
# ==========================================
custom_css = """
.gen-button { background-color: #ff4081 !important; color: white !important; font-weight: bold !important; }
"""

# FIX: theme/css belong on the gr.Blocks constructor — Blocks.launch() does
# not accept `theme` or `css` kwargs and raises TypeError on current Gradio.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Hybrid OCR: Surya (Raw) + TrOCR (Corrected)")
    with gr.Row():
        ocr_input = gr.Image(type="pil", label="Upload Image")
        ocr_output_img = gr.Image(type="pil", label="Surya Detections")
    ocr_text = gr.Textbox(label="Recognized Text", lines=20)
    ocr_button = gr.Button("Run Hybrid OCR", elem_classes="gen-button")
    ocr_button.click(hybrid_ocr_workflow, inputs=[ocr_input], outputs=[ocr_output_img, ocr_text])

if __name__ == "__main__":
    demo.launch()