# Streamlit app: upload an image and extract its text with the TrOCR
# ("microsoft/trocr-base-handwritten") vision encoder-decoder model.

import numpy as np
import requests
import streamlit as st
import torch
from PIL import Image, ImageEnhance, ImageOps
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

MODEL_NAME = "microsoft/trocr-base-handwritten"
BINARIZE_THRESHOLD = 200  # grayscale values above this become white
OCR_INPUT_SIZE = (384, 384)  # (width, height) the image is resized to


@st.cache_resource
def _load_trocr():
    """Load the TrOCR processor and model once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    caching the loaded objects avoids re-instantiating the model each time.

    Returns:
        A ``(processor, model)`` tuple; the model is in evaluation mode.
    """
    trocr_processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
    trocr_model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
    trocr_model.eval()  # inference only: make sure dropout etc. are disabled
    return trocr_processor, trocr_model


# Module-level handles used by extract_text_from_image.
processor, model = _load_trocr()


def extract_text_from_image(image: Image.Image) -> str:
    """Run OCR on ``image`` and return the recognised text.

    Args:
        image: The PIL image to read text from.

    Returns:
        The text decoded from the first generated sequence, with special
        tokens removed.
    """
    image = preprocess_image(image)

    # Convert the preprocessed image into the tensor the model consumes.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    # No gradients are needed for generation.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)

    return processor.decode(generated_ids[0], skip_special_tokens=True)


def preprocess_image(image: Image.Image) -> Image.Image:
    """Binarise and resize ``image`` before it is handed to the processor.

    Steps: normalise to RGB, convert to grayscale, double the contrast,
    threshold to pure black/white, resize to ``OCR_INPUT_SIZE``, and convert
    back to RGB.

    Args:
        image: The PIL image to preprocess (any mode).

    Returns:
        A new RGB image of size ``OCR_INPUT_SIZE`` containing only black and
        white pixels plus whatever intermediate values the resize introduces.
    """
    # Normalise the mode first so palette/alpha images convert predictably.
    image = image.convert("RGB")
    image = ImageOps.grayscale(image)

    # Boost contrast so the threshold below separates ink from background.
    image = ImageEnhance.Contrast(image).enhance(2.0)

    # Simple global threshold: bright pixels -> white, everything else -> black.
    image = image.point(lambda p: 255 if p > BINARIZE_THRESHOLD else 0)

    image = image.resize(OCR_INPUT_SIZE)

    # The image processor is fed 3-channel input; a single-channel image
    # yields a 2-D array whose channel layout cannot be inferred.
    return image.convert("RGB")


# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
st.title("OCR Text Extraction from Image")
st.write("""
Upload an image containing text, and this app will extract and display the text from the image using the powerful TrOCR model!
""")

# File uploader restricted to common raster formats.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)

    # Show the user what was uploaded before running OCR.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Extract Text"):
        with st.spinner('Extracting text...'):
            extracted_text = extract_text_from_image(image)
        st.subheader("Extracted Text:")
        st.write(extracted_text)