# Author: PrabhatGupta786
# Last commit: 47f1551 "Update app.py" (verified)
import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# 1. Setup
# Single checkpoint id so the processor (image preprocessing + tokenizer) and
# the model weights are always loaded from the same source.
MODEL_ID = "microsoft/trocr-large-handwritten"
# Prefer a GPU when one is present; otherwise run on CPU as before.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = TrOCRProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID).to(device)
def get_lines_from_image(
    img_array,
    threshold=180,
    kernel_size=(5, 100),
    min_size=15,
    padding=10,
):
    """Split a page image into text-line crops ordered top to bottom.

    The image is binarised with a fixed inverse threshold, so dark ink becomes
    foreground. It is then dilated with a short, wide kernel so that the letters
    of one line merge into a single blob. Each external blob's bounding box,
    padded and clipped to the image, becomes one crop.

    Args:
        img_array: Image as a numpy array (Gradio ``gr.Image`` with
            ``type="numpy"``). RGB ``(H, W, 3)`` is the expected case.
            Grayscale ``(H, W)`` / ``(H, W, 1)`` and RGBA ``(H, W, 4)`` are
            converted to RGB first.
        threshold: Gray level above which a pixel is treated as background.
        kernel_size: ``(rows, cols)`` of the dilation kernel. The large column
            count joins neighbouring letters horizontally.
        min_size: Boxes whose height or width is not greater than this are
            discarded as noise.
        padding: Margin in pixels added on every side of each line box.

    Returns:
        A list of RGB ``PIL.Image.Image`` crops, sorted by the top edge of
        their bounding box.
    """
    rgb = np.asarray(img_array)
    # Normalise to 3-channel RGB so the grayscale conversion and the crops
    # handed to the OCR processor are consistent regardless of input mode.
    if rgb.ndim == 2 or rgb.shape[2] == 1:
        rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
    elif rgb.shape[2] == 4:
        rgb = cv2.cvtColor(rgb, cv2.COLOR_RGBA2RGB)

    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    # Inverse binary threshold: ink (dark) -> 255, paper (light) -> 0.
    _, thresh = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY_INV)
    kernel = np.ones(kernel_size, np.uint8)
    dilation = cv2.dilate(thresh, kernel, iterations=1)

    # OpenCV 3.x returns (image, contours, hierarchy) and 4.x returns
    # (contours, hierarchy); the contours are second-to-last in both.
    contours = cv2.findContours(dilation, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Compute each bounding box once, then sort by its top y coordinate.
    boxes = sorted((cv2.boundingRect(ctr) for ctr in contours), key=lambda box: box[1])

    height, width = rgb.shape[:2]
    line_images = []
    for x, y, w, h in boxes:
        if h <= min_size or w <= min_size:
            continue
        y_start, y_end = max(0, y - padding), min(height, y + h + padding)
        x_start, x_end = max(0, x - padding), min(width, x + w + padding)
        line_images.append(Image.fromarray(rgb[y_start:y_end, x_start:x_end]))
    return line_images
# Upper bound on tokens generated per text line. This is set explicitly so
# generation does not fall back to a short generic default length.
_MAX_LINE_TOKENS = 128


def full_pipeline(input_img):
    """Transcribe a handwritten paragraph image into one string.

    The image is split into lines with ``get_lines_from_image``. TrOCR is run
    on each crop from top to bottom, and the non-empty results are joined with
    single spaces.

    Args:
        input_img: Image array from the Gradio ``gr.Image`` input, or ``None``
            when nothing was uploaded.

    Returns:
        The transcript, or a user-facing message when no image was given or
        no text lines were detected.
    """
    if input_img is None:
        return "Please upload an image."
    lines = get_lines_from_image(input_img)
    if not lines:
        return "No text lines detected. Try a clearer image."

    final_transcript = []
    with torch.no_grad():
        for line_img in lines:
            pixel_values = processor(images=line_img, return_tensors="pt").pixel_values.to(device)
            # Without an explicit limit, generate() may use the library's
            # generic default max length, which can cut off long lines.
            generated_ids = model.generate(pixel_values, max_new_tokens=_MAX_LINE_TOKENS)
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            # Trim whitespace and any trailing run of '.'/' ' characters.
            # NOTE: this also removes genuine sentence-final periods at line ends.
            text = text.strip().rstrip(". ")
            # Skip lines that decode to nothing so the join has no double spaces.
            if text:
                final_transcript.append(text)
    return " ".join(final_transcript)
# Gradio UI Setup
demo = gr.Interface(
    fn=full_pipeline,
    # get_lines_from_image works on an RGB numpy array; request exactly that
    # rather than relying on the installed Gradio version's defaults.
    inputs=gr.Image(type="numpy", image_mode="RGB"),
    outputs="text",
    title="Handwritten Paragraph to Typed Text",
    description="Upload a handwritten paragraph.",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()