| import streamlit as st |
| import torch |
| from PIL import Image |
| import pytesseract |
| from torchvision import transforms |
| from model import UTRNet |
|
|
| |
@st.cache_resource
def load_model():
    """Load the pretrained UTRNet OCR model and return it in eval mode.

    Cached with ``st.cache_resource`` so the checkpoint is read from disk
    only once per Streamlit session instead of on every script rerun.

    Returns:
        UTRNet: the model with weights loaded, switched to inference mode.
    """
    model = UTRNet()
    # map_location='cpu' lets the app start on machines without a GPU even
    # if the checkpoint was saved from a CUDA device.
    state_dict = torch.load(
        'saved_models/UTRNet-Large/best_norm_ED.pth',
        map_location='cpu',
    )
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout / freeze batch-norm statistics for inference
    return model
|
|
| |
def preprocess_image(image):
    """Turn a PIL image into a (1, C, 320, 320) float batch tensor.

    ``ToTensor`` converts HWC pixel data to a CHW float tensor scaled to
    [0, 1]; the result is then resized and given a leading batch axis.
    """
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320)),
    ])
    tensor = pipeline(image)
    return tensor.unsqueeze(0)  # add batch dimension expected by the model
|
|
| |
def predict_ocr(image, model):
    """Run the OCR model on a PIL image and return the forward-pass output.

    NOTE(review): this returns the model's raw output (presumably logits),
    not decoded text — a decoding step does not appear in this file; verify
    against the caller / the UTRNet repo's converter.
    """
    batch = preprocess_image(image)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        result = model(batch)
    return result
|
|
| |
def main():
    """Streamlit entry point: upload an image and run UTRNet OCR on it."""
    st.title("Urdu Text Extraction Using UTRNet")
    st.write("Upload an image containing Urdu text for OCR extraction.")

    uploaded_image = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])

    # Guard clause: nothing to do until the user uploads a file.
    if uploaded_image is None:
        return

    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Extract Text"):
        # Load the model only when the user actually asks for extraction,
        # so a plain upload does not pay the checkpoint-loading cost on
        # every Streamlit rerun.
        model = load_model()
        output = predict_ocr(image, model)
        st.write("Extracted Text:")
        # NOTE(review): `output` is the model's raw tensor output, not
        # decoded text — a CTC/attention decode step appears to be missing.
        st.write(output)
|
|
# Run the app only when this file is executed directly
# (e.g. `streamlit run <this_file>.py`), not when imported.
if __name__ == "__main__":
    main()
|
|