"""Gradio demo app: run a CRNN text-recognition model on an uploaded image."""

import os
from itertools import groupby

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from src.model import CRNN

# ----------------------------
# 0️⃣ Shared constants
# ----------------------------
MODEL_PATH = "crnn_gpu.pt"

# Input size fed to the network; used both to build the model and to resize images,
# so the two can never disagree.
IMG_HEIGHT = 32
IMG_WIDTH = 100

# Output characters. Class 0 is reserved for the CTC blank, so class i (> 0)
# maps to alphabet[i - 1] and the model needs len(alphabet) + 1 classes (37).
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
BLANK = 0
NUM_CLASS = len(alphabet) + 1

# ----------------------------
# 1️⃣ Load CRNN model
# ----------------------------
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"{MODEL_PATH} not found! Make sure it's in the Space root.")

model = CRNN(
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    img_channel=1,
    num_class=NUM_CLASS,
    rnn_hidden=256,
)
# NOTE(review): torch.load unpickles the file; this is fine for a checkpoint
# shipped with the app, but never point MODEL_PATH at untrusted input.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()


# ----------------------------
# 2️⃣ CTC decoding
# ----------------------------
def ctc_decode(preds):
    """Greedy (best-path) CTC decoder.

    At every time step the highest-scoring class is taken. Runs of the same
    class are then collapsed, and blanks (class 0) are dropped. Each sample
    in the batch is decoded on its own, so repeats never merge across samples.

    Args:
        preds: Raw model scores. Assumed to be indexed ``(time, batch, class)``,
            matching the ``argmax(2).transpose(1, 0)`` axis order used here.
            TODO confirm against ``CRNN.forward``.

    Returns:
        The decoded text. For a batch of several samples, the per-sample
        strings are concatenated in batch order.
    """
    # (time, batch) -> (batch, time) -> nested Python int lists, one per sample
    best_paths = preds.argmax(2).transpose(1, 0).tolist()

    pieces = []
    for path in best_paths:
        # groupby collapses consecutive duplicates; a blank between two equal
        # labels keeps them as separate characters, as CTC requires.
        pieces.extend(alphabet[label - 1] for label, _ in groupby(path) if label != BLANK)
    return "".join(pieces)


# ----------------------------
# 3️⃣ Preprocessing
# ----------------------------
def to_grayscale(img: Image.Image) -> Image.Image:
    """Return ``img`` in single-channel "L" mode, converting only if needed."""
    if img.mode != "L":
        return img.convert("L")
    return img


transform = transforms.Compose([
    transforms.Lambda(to_grayscale),           # the model takes 1 input channel
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),  # match CRNN input size
    transforms.ToTensor(),                     # [0, 255] -> [0.0, 1.0]
    transforms.Normalize((0.5,), (0.5,)),      # [0.0, 1.0] -> [-1.0, 1.0]
])


# ----------------------------
# 4️⃣ OCR function
# ----------------------------
def ocr(image: Image.Image) -> str:
    """Recognise the text in ``image`` and return it.

    Args:
        image: A PIL image in any mode. ``None`` (submit clicked with no
            upload) yields an empty string instead of an error.

    Returns:
        The predicted text. If inference fails, a human-readable
        "Error during inference: ..." message is returned so the UI keeps running.
    """
    if image is None:
        return ""
    try:
        img_tensor = transform(image).unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            preds = model(img_tensor)
        return ctc_decode(preds)
    except Exception as e:  # UI boundary: show the failure instead of crashing
        return f"Error during inference: {e}"


# ----------------------------
# 5️⃣ Gradio interface
# ----------------------------
iface = gr.Interface(
    fn=ocr,
    inputs=gr.Image(type="pil", label="Upload any image (RGB, RGBA, etc.)"),
    outputs="text",
    title="CRNN OCR",
    description="Upload an image and get the OCR text prediction.",
)

if __name__ == "__main__":
    # Launch only when run as a script (as Spaces do), not on import.
    iface.launch(share=True)