| import gradio as gr |
| import logging |
| import os |
| import numpy as np |
| import torch |
| from PIL import Image, ImageDraw |
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
|
| |
| try: |
| from surya.detection import batch_text_detection |
| from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor |
| except ImportError: |
| from surya.detection import batch_inference as batch_text_detection |
| from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor |
|
|
| |
| |
| |
# Inference device. Hard-coded to CPU; switch to "cuda" manually if a GPU
# is available — every model below is moved to this device.
device = "cpu"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


logger.info("⏳ Loading Models...")


# Surya text-line detection model + processor (loaded once at import time;
# the loader names come from the version-compat try/except above).
det_processor = load_det_processor()
det_model = load_det_model().to(device)


# TrOCR handwritten-text recognition (vision encoder-decoder) + processor.
# Weights are downloaded from the HuggingFace hub on first run.
trocr_processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)


logger.info("✅ All Models Loaded.")
|
|
| |
| |
| |
def recognize_batch(crops):
    """Run TrOCR over a list of PIL line crops and return decoded strings.

    Crops with a zero-width or zero-height dimension are silently dropped;
    an empty input (or all-degenerate input) yields an empty list.
    """
    if not crops:
        return []

    usable = [img for img in crops if img.size[0] > 0 and img.size[1] > 0]
    if not usable:
        return []

    # Processor handles resizing/normalization; tensors go to the module device.
    inputs = trocr_processor(images=usable, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        token_ids = trocr_model.generate(pixel_values, max_length=64)

    return trocr_processor.batch_decode(token_ids, skip_special_tokens=True)
|
|
def draw_boxes(image, prediction_objects):
    """Outline each detected region on *image* in red and return the image.

    Accepts either detection objects exposing a ``.bbox`` attribute or raw
    ``[x1, y1, x2, y2]`` sequences.
    """
    canvas = ImageDraw.Draw(image)
    for detection in prediction_objects:
        region = detection.bbox if hasattr(detection, "bbox") else detection
        canvas.rectangle(region, outline="red", width=2)
    return image
|
|
| |
| |
| |
def hybrid_ocr_workflow(image):
    """Detect text lines with Surya, then recognize each line with TrOCR.

    Args:
        image: PIL image uploaded by the user, or None.

    Returns:
        Tuple of (visualization image with red line boxes, recognized text
        joined with newlines), or (None, error message) when no image was
        provided.
    """
    if image is None:
        return None, "Please upload an image."

    # Both models expect 3-channel input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    logger.info("Step 1: Detecting Lines with Surya...")
    predictions = batch_text_detection([image], det_model, det_processor)
    result = predictions[0]

    # Surya's result schema differs across versions: newer builds expose
    # `bboxes`, older ones `text_lines`.
    lines_objects = []
    if hasattr(result, "bboxes"):
        lines_objects = result.bboxes
    elif hasattr(result, "text_lines"):
        lines_objects = result.text_lines

    def _bbox_of(obj):
        # Consistent with draw_boxes(): tolerate raw [x1, y1, x2, y2] lists
        # as well as detection objects with a `.bbox` attribute.
        # FIX: the original sort/crop code assumed `.bbox` always exists.
        return obj.bbox if hasattr(obj, "bbox") else obj

    # Read top-to-bottom.
    lines_objects.sort(key=lambda obj: _bbox_of(obj)[1])

    logger.info(f"Step 2: Recognizing {len(lines_objects)} lines with TrOCR...")

    line_crops = []
    w, h = image.size
    for obj in lines_objects:
        bbox = _bbox_of(obj)
        # Pad each box slightly so ascenders/descenders aren't clipped,
        # clamping to the image bounds.
        pad = 6
        x1 = max(0, int(bbox[0]) - pad)
        y1 = max(0, int(bbox[1]) - pad)
        x2 = min(w, int(bbox[2]) + pad)
        y2 = min(h, int(bbox[3]) + pad)
        line_crops.append(image.crop((x1, y1, x2, y2)))

    full_text_lines = []
    batch_size = 4
    for i in range(0, len(line_crops), batch_size):
        batch = line_crops[i:i + batch_size]
        try:
            full_text_lines.extend(recognize_batch(batch))
        except Exception as e:
            logger.error(f"Batch failed: {e}")
            # FIX: the original appended a single placeholder for the whole
            # failed batch, desyncing the output line count from the
            # detections. Emit one placeholder per line in the batch.
            full_text_lines.extend(["[Error processing line]"] * len(batch))

    final_text = "\n".join(full_text_lines)
    vis_img = draw_boxes(image.copy(), lines_objects)
    return vis_img, final_text
|
|
| |
| |
| |
# Stylesheet for the run button (applied via elem_classes below).
custom_css = """
.gen-button { background-color: #ff4081 !important; color: white !important; font-weight: bold !important; }
"""


# FIX: `theme` was previously passed to `demo.launch()`, which accepts
# neither `theme` nor `css` — both belong on the Blocks constructor.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown("# 🚀 Hybrid OCR: Surya (Raw) + TrOCR (Corrected)")

    with gr.Row():
        ocr_input = gr.Image(type="pil", label="Upload Image")
        ocr_output_img = gr.Image(type="pil", label="Surya Detections")

    ocr_text = gr.Textbox(label="Recognized Text", lines=20)
    ocr_button = gr.Button("Run Hybrid OCR", elem_classes="gen-button")

    ocr_button.click(hybrid_ocr_workflow, inputs=[ocr_input], outputs=[ocr_output_img, ocr_text])
|
|
if __name__ == "__main__":
    # FIX: Blocks.launch() accepts neither `theme` nor `css` kwargs, so the
    # original call raised TypeError. Theme/CSS are configured on the
    # gr.Blocks() constructor instead.
    demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|