|
|
import torch |
|
|
from src.model import CRNN |
|
|
from PIL import Image |
|
|
import torchvision.transforms as transforms |
|
|
import gradio as gr |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint file, resolved relative to the process working directory.
MODEL_PATH = "crnn_gpu.pt"

# Fail at startup with an actionable message rather than a torch.load error.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"{MODEL_PATH} not found! Make sure it's in the Space root.")

# Build the network with the hyper-parameters below, then restore its weights.
# num_class=37 matches the 36-character alphabet plus one CTC blank class.
model = CRNN(
    img_height=32,
    img_width=100,
    img_channel=1,
    num_class=37,
    rnn_hidden=256,
)

# NOTE(review): torch.load unpickles the file; only ship checkpoints you trust.
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host.
_checkpoint = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(_checkpoint)

# Switch to inference mode (calls the module's eval()).
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Characters the model can emit; class index 0 is the CTC blank, so class
# index k (k >= 1) maps to alphabet[k - 1].
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'


def ctc_decode(preds):
    """Greedy (best-path) CTC decoder.

    Takes the highest-scoring class at every time step, collapses consecutive
    repeats, and drops blanks (class 0).

    Args:
        preds: 3-D score tensor whose last axis (dim 2) indexes classes.
            Assumed layout is (seq_len, batch, num_class) -- TODO confirm
            against the CRNN model's output.

    Returns:
        The decoded text as a single string. When the batch holds more than
        one sequence, each is decoded independently and the results are
        concatenated in batch order.
    """
    # argmax over classes, then put batch first so each row is one sequence.
    # .tolist() converts once to plain Python ints instead of iterating
    # 0-dim tensors element by element.
    best_paths = preds.argmax(2).transpose(1, 0).tolist()

    decoded = []
    for path in best_paths:
        # Reset per sequence so a repeat spanning two batch items is not
        # collapsed as if it were one run.
        prev_idx = -1
        for idx in path:
            if idx != prev_idx and idx != 0:
                decoded.append(alphabet[idx - 1])
            prev_idx = idx
    return ''.join(decoded)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_grayscale(img: Image.Image):
    """Return *img* in single-channel "L" mode.

    An image already in "L" mode is returned as-is; any other mode is
    converted with ``img.convert("L")``.
    """
    return img if img.mode == "L" else img.convert("L")
|
|
|
|
|
# Preprocessing pipeline applied to every uploaded image before inference:
# grayscale -> fixed 32x100 (height x width) -> tensor -> normalize with
# mean 0.5 / std 0.5 on the single channel.
transform = transforms.Compose(
    [
        transforms.Lambda(to_grayscale),
        transforms.Resize(size=(32, 100)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ocr(image: Image.Image):
    """Run the CRNN on one uploaded image and return the decoded text.

    Args:
        image: PIL image supplied by the Gradio input component. May be
            ``None`` when the form is submitted without an image.

    Returns:
        The recognized text, or a message beginning with
        "Error during inference:" if no image was given or anything in
        preprocessing, the forward pass, or decoding raised.
    """
    # Without this guard a missing upload surfaces as an opaque
    # "'NoneType' object has no attribute 'mode'" message.
    if image is None:
        return "Error during inference: no image provided"

    try:
        # Preprocess, then add a leading batch dimension of size 1.
        img_tensor = transform(image).unsqueeze(0)
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            preds = model(img_tensor)
        return ctc_decode(preds)
    except Exception as e:
        # UI boundary: report the failure in the output box instead of
        # crashing the request.
        return f"Error during inference: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Input widget: hands the uploaded picture to `ocr` as a PIL image.
_image_input = gr.Image(
    type="pil",
    label="Upload any image (RGB, RGBA, etc.)",
)

# Single-function web UI: one image in, one text box out.
iface = gr.Interface(
    fn=ocr,
    inputs=_image_input,
    outputs="text",
    title="CRNN OCR",
    description="Upload an image and get the OCR text prediction.",
)

# Start the web server at import/run time; share=True requests a public link.
iface.launch(share=True)
|
|
|