# crnn / app.py
# Author: Chung-Fan — commit 9415ed0 ("update app")
import torch
from src.model import CRNN
from PIL import Image
import torchvision.transforms as transforms
import gradio as gr
import os
# ----------------------------
# 1️⃣ Load CRNN model
# ----------------------------
MODEL_PATH = "crnn_gpu.pt"
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"{MODEL_PATH} not found! Make sure it's in the Space root.")

# These hyperparameters must match the ones the checkpoint was trained with.
# num_class=37 corresponds to the 36-character `alphabet` used for decoding
# plus one CTC blank at class index 0.
model = CRNN(img_height=32, img_width=100, img_channel=1, num_class=37, rnn_hidden=256)
# The checkpoint is consumed as a state_dict, so it only needs tensors and plain
# containers; weights_only=True stops a tampered file from running arbitrary
# code through pickle. map_location="cpu" lets a GPU-saved checkpoint load on CPU.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu", weights_only=True))
model.eval()  # inference mode for layers such as dropout / batch-norm
# ----------------------------
# 2️⃣ Characters and CTC decoding
# ----------------------------
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'


def ctc_decode(preds):
    """Greedy (best-path) CTC decoder.

    For each time step, take the highest-scoring class. Then collapse runs of
    the same class and drop the blank (class 0). Class ``i > 0`` maps to
    ``alphabet[i - 1]``.

    Args:
        preds: Model scores indexed as ``preds[t, b, c]``. This layout is
            implied by ``argmax(2)`` followed by swapping dims 0/1 into
            (batch, time). Presumably it is the CRNN's (seq_len, batch,
            num_class) output — confirm against ``src.model``.

    Returns:
        The decoded text. Each batch element is decoded independently, so a
        repeated character is never collapsed across two samples. The results
        are concatenated in batch order; for the usual batch of 1 this is
        simply that sample's text.
    """
    # (T, B, C) -> (T, B) best classes -> (B, T) plain Python ints.
    best_paths = preds.argmax(2).transpose(1, 0).tolist()
    decoded = []
    for path in best_paths:
        prev_idx = -1  # reset per sample so repeats don't merge across samples
        for idx in path:
            if idx != prev_idx and idx != 0:  # skip duplicates & blank
                decoded.append(alphabet[idx - 1])
            prev_idx = idx  # blanks update prev too, so "a<blank>a" -> "aa"
    return ''.join(decoded)
# ----------------------------
# 3️⃣ Preprocessing
# ----------------------------
def to_grayscale(img: Image.Image):
    """Return *img* as a single-channel 8-bit ("L" mode) image.

    An image that is already in "L" mode is handed back unchanged (same
    object); any other mode is converted with ``Image.convert("L")``.
    """
    return img if img.mode == "L" else img.convert("L")
# Preprocessing pipeline: any PIL image -> normalized 1x32x100 float tensor.
transform = transforms.Compose(
    [
        transforms.Lambda(lambd=to_grayscale),          # force a single channel
        transforms.Resize(size=(32, 100)),              # (height, width) the CRNN was built for
        transforms.ToTensor(),                          # PIL [0, 255] -> float tensor [0.0, 1.0]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),  # [0.0, 1.0] -> [-1.0, 1.0]
    ]
)
# ----------------------------
# 4️⃣ OCR function
# ----------------------------
def ocr(image: Image.Image):
    """Run CRNN OCR on one uploaded image and return the recognized text.

    Args:
        image: The uploaded PIL image. It may be ``None`` when the form is
            submitted without an image.

    Returns:
        The decoded text. If there is no input, or inference fails, a
        human-readable message is returned instead of raising, so it shows up
        in the text output box.
    """
    import logging

    if image is None:
        return "Please upload an image."
    try:
        img_tensor = transform(image).unsqueeze(0)  # add batch dimension -> (1, 1, 32, 100)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            preds = model(img_tensor)
        return ctc_decode(preds)
    except Exception as e:
        # UI boundary: keep the app responsive, but don't lose the traceback.
        logging.getLogger(__name__).exception("OCR inference failed")
        return f"Error during inference: {e}"
# ----------------------------
# 5️⃣ Gradio interface
# ----------------------------
# Single-input / single-output web UI around `ocr`.
iface = gr.Interface(
    fn=ocr,
    inputs=gr.Image(type="pil", label="Upload any image (RGB, RGBA, etc.)"),
    outputs="text",
    title="CRNN OCR",
    description="Upload an image and get the OCR text prediction.",
)

# Launch only when run as a script (`python app.py`), so importing this module,
# e.g. to reuse `ocr`, does not start and block on a web server.
if __name__ == "__main__":
    # share=True asks Gradio for a temporary public link when run locally.
    # Gradio documents share links as unsupported on Hugging Face Spaces.
    iface.launch(share=True)