# app.py — captcha OCR demo (author: vanh99, commit ec432b9)
import os
import time
import torch
import numpy as np
import onnxruntime
from torchvision import transforms
from PIL import Image
from pathlib import Path
import re
#import editdistance
#from collections import Counter
#from functools import lru_cache
import gradio as gr
from huggingface_hub import hf_hub_download
# --- Configuration ---
# Optional Hub token; only needed if the model repo is private.
hf_token = os.environ.get("HF_TOKEN")
# Download the ONNX checkpoint from the Hub (cached locally after the first run).
# NOTE: `use_auth_token` is deprecated in huggingface_hub — `token` is the
# supported parameter name and behaves identically here.
ONNX_MODEL_PATH = hf_hub_download(
    repo_id="vanh99/GRU-model",
    filename="crnntiny_best.onnx",
    token=hf_token
)
IMG_HEIGHT = 50   # model input height in pixels
IMG_WIDTH = 160   # model input width in pixels
# Characters the model can emit; class index 0 is reserved for the CTC blank,
# so characters map to indices 1..len(CHARSET).
CHARSET = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
IDX2CHAR = {i + 1: c for i, c in enumerate(CHARSET)}
BLANK_LABEL = 0
# --- Transform ---
# --- Transform ---
def get_transform():
    """Build the image preprocessing pipeline.

    Converts to single-channel grayscale, resizes to the model's fixed
    (IMG_HEIGHT, IMG_WIDTH) input size, and normalizes pixel values to [-1, 1].
    """
    steps = [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
    return transforms.Compose(steps)
# --- Load Image ---
def to_numpy(tensor):
    """Convert a torch tensor to a numpy array, detaching from autograd first if needed."""
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
# --- Decode ---
def ctc_decode(output_np, idx2char=None, blank=None):
    """Greedy CTC decoding of model output scores.

    Args:
        output_np: numpy array of shape (batch, time, num_classes) holding
            per-timestep class scores (argmax is taken over the last axis).
        idx2char: optional mapping from class index to character; defaults to
            the module-level IDX2CHAR table.
        blank: optional CTC blank index; defaults to the module-level
            BLANK_LABEL.

    Returns:
        A list of decoded strings, one per batch element. Indices missing from
        the mapping decode to '?'.
    """
    if idx2char is None:
        idx2char = IDX2CHAR
    if blank is None:
        blank = BLANK_LABEL
    pred_indices = np.argmax(output_np, axis=2)
    decoded_strings = []
    for indices in pred_indices:
        # Standard CTC rule: collapse consecutive repeats, then drop blanks.
        collapsed = [indices[0]] if len(indices) > 0 else []
        for i in range(1, len(indices)):
            if indices[i] != indices[i - 1]:
                collapsed.append(indices[i])
        final = [idx for idx in collapsed if idx != blank]
        decoded_strings.append("".join(idx2char.get(idx, '?') for idx in final))
    return decoded_strings
# --- Load model ---
def load_model():
    """Return the preprocessing transform and an ONNX Runtime session for the model."""
    session = onnxruntime.InferenceSession(ONNX_MODEL_PATH)
    return get_transform(), session
# Instantiate the transform and ONNX session once at import time so every
# Gradio request reuses the same loaded model.
transform, session = load_model()
# Cache the graph's input/output tensor names for use in inference calls.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# --- Predict ---
def predict_image(image):
    """Run OCR on a PIL image and return the decoded captcha text."""
    # Force RGB first so the grayscale transform sees a consistent mode,
    # then add a leading batch dimension for the model.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    logits = session.run([output_name], {input_name: to_numpy(batch)})[0]
    return ctc_decode(logits)[0]
# Gradio UI: a single image input mapped to the predicted text output.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Textbox(label="Predicted Text"),
    title="OCR for CAPTCHA",
    description="Solve captchas from images.",
    # Example files are expected to live alongside this script.
    examples=["1.png","2.jfif","3.jpg"]
)
# Start the web server only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()