mmsocr / app.py
badru's picture
Update app.py
dbfe32a verified
import streamlit as st
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
import cv2
import numpy as np
import tempfile
# Load the TrOCR processor and model once and reuse them across Streamlit reruns.
@st.cache_resource
def load_model(model_name="microsoft/trocr-base-handwritten"):
    """Load a TrOCR processor/model pair.

    Decorated with ``st.cache_resource`` so the download and initialisation
    happen once and the same objects are reused on every rerun of the script.
    The cache is keyed on ``model_name``.

    Args:
        model_name: Checkpoint id (or local path) passed to both
            ``TrOCRProcessor.from_pretrained`` and
            ``VisionEncoderDecoderModel.from_pretrained``. Defaults to the
            handwritten base checkpoint this app was built around.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = TrOCRProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    return processor, model


processor, model = load_model()
# Helper function to binarise the image and crop out candidate text lines
def detect_lines(image, min_height=20, min_width=100, kernel_size=(5, 5), iterations=1):
    """Split a page image into candidate text-line crops using OpenCV.

    The image is converted to grayscale and binarised with inverted Otsu
    thresholding, so dark ink on a light background becomes foreground. It
    is then dilated so nearby strokes join into blobs. The bounding box of
    each external contour becomes one crop. Boxes smaller than
    ``min_height`` x ``min_width`` are discarded. Neighbouring boxes are NOT
    merged, so whether a crop covers a whole line depends on the dilation
    kernel.

    Args:
        image: A PIL image or array-like. 3-channel RGB, 4-channel RGBA and
            single-channel (2-D or HxWx1) inputs are accepted.
        min_height: Minimum box height, in pixels, for a crop to be kept.
        min_width: Minimum box width, in pixels, for a crop to be kept.
        kernel_size: ``(width, height)`` of the rectangular dilation kernel.
            A wider kernel joins horizontally separated words more readily.
        iterations: Number of dilation passes.

    Returns:
        A list of NumPy arrays, each a crop of the (copied) input image,
        ordered top-to-bottom by the top edge of its box. The list is empty
        if no box passes the size filter.
    """
    # np.array copies, so returned crops never alias the caller's buffer.
    image_np = np.array(image)

    # Pick the grayscale conversion that matches the channel layout;
    # COLOR_RGB2GRAY alone would raise on 1- or 4-channel input.
    if image_np.ndim == 2:
        gray = image_np
    elif image_np.shape[2] == 1:
        gray = image_np[:, :, 0]
    elif image_np.shape[2] == 4:
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    # With THRESH_OTSU the threshold argument is ignored (Otsu picks it),
    # so 0 is passed explicitly instead of a misleading constant.
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Dilate so strokes that are close together form a single contour.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, tuple(kernel_size))
    dilated = cv2.dilate(binary, kernel, iterations=iterations)

    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Sort by top edge (y), then drop boxes too small to be a text line.
    boxes = sorted((cv2.boundingRect(c) for c in contours), key=lambda box: box[1])
    kept = [(x, y, w, h) for x, y, w, h in boxes if h >= min_height and w >= min_width]

    return [image_np[y:y + h, x:x + w] for x, y, w, h in kept]
# Streamlit app
st.title("Multiline Handwritten OCR with Hugging Face and OpenCV")
uploaded_file = st.file_uploader("Upload an Image (JPG, JPEG, PNG)", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Decode the upload and normalise to 3-channel RGB for detection and OCR.
        image = Image.open(uploaded_file).convert("RGB")
        # NOTE(review): `use_column_width` is deprecated in recent Streamlit
        # releases in favour of `use_container_width`; kept here so older
        # Streamlit versions still work. Confirm the deployed version before switching.
        st.image(image, caption="Uploaded Image", use_column_width=True)
        st.write("Processing the image...")

        # Crop candidate text lines out of the page.
        line_images = detect_lines(image, min_height=30, min_width=100)
        st.write(f"Detected {len(line_images)} lines in the image.")

        if not line_images:
            # Avoid presenting an empty result box and an empty download file.
            st.warning("No text lines were detected. Try a clearer or higher-resolution image.")
        else:
            # Run TrOCR on each crop. Collect per-line results and join once
            # at the end, rather than repeatedly concatenating strings.
            output_lines = []
            for idx, line_img in enumerate(line_images, start=1):
                line_pil = Image.fromarray(line_img)
                pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
                generated_ids = model.generate(pixel_values)
                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                output_lines.append(f"Line {idx}: {generated_text}\n")
            extracted_text = "".join(output_lines)

            # Display the extracted text
            st.subheader("Extracted Text:")
            st.text_area("Output Text", extracted_text, height=200)

            # Offer the extracted text as a plain-text download
            st.download_button(
                label="Download Text",
                data=extracted_text,
                file_name="extracted_text.txt",
                mime="text/plain",
            )
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user instead of
        # letting the script crash mid-render.
        st.error(f"An error occurred while processing the image: {e}")
else:
    st.info("Please upload an image to start the OCR process.")