import os
import time
import torch
import numpy as np
import onnxruntime
from torchvision import transforms
from PIL import Image
from pathlib import Path
import re
#import editdistance
#from collections import Counter
#from functools import lru_cache
import gradio as gr
from huggingface_hub import hf_hub_download

# --- Configuration ---
hf_token = os.environ.get("HF_TOKEN")
ONNX_MODEL_PATH = hf_hub_download(
    repo_id="vanh99/GRU-model",
    filename="crnntiny_best.onnx",
    token=hf_token  # `token` replaces the deprecated `use_auth_token` argument
)
IMG_HEIGHT = 50
IMG_WIDTH = 160
CHARSET = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
# Index 0 is reserved for the CTC blank, so characters start at 1.
IDX2CHAR = {i + 1: c for i, c in enumerate(CHARSET)}
BLANK_LABEL = 0

# --- Transform ---
def get_transform():
    # Grayscale, fixed size, then scale pixel values to [-1, 1].
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

# --- Tensor to NumPy ---
def to_numpy(tensor):
    # detach() is a no-op for tensors that do not require grad.
    return tensor.detach().cpu().numpy()

# --- Decode ---
def ctc_decode(output_np):
    # Greedy CTC decoding: argmax over the class axis, collapse repeats, drop blanks.
    # The argmax over axis 2 assumes an output of shape (batch, time, classes).
    pred_indices = np.argmax(output_np, axis=2)
    decoded_strings = []
    for indices in pred_indices:
        collapsed = [indices[0]] if len(indices) > 0 else []
        for i in range(1, len(indices)):
            if indices[i] != indices[i - 1]:
                collapsed.append(indices[i])
        final = [int(idx) for idx in collapsed if idx != BLANK_LABEL]
        decoded_strings.append("".join(IDX2CHAR.get(idx, '?') for idx in final))
    return decoded_strings

# --- Load model ---
def load_model():
    transform = get_transform()
    onnx_session = onnxruntime.InferenceSession(ONNX_MODEL_PATH)
    return transform, onnx_session

transform, session = load_model()
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# --- Predict ---
def predict_image(image):
    # Add a batch dimension: (1, 1, IMG_HEIGHT, IMG_WIDTH).
    x = transform(image.convert("RGB")).unsqueeze(0)
    ort_inputs = {input_name: to_numpy(x)}
    logits = session.run([output_name], ort_inputs)[0]
    preds = ctc_decode(logits)
    return preds[0]

# --- Interface ---
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Textbox(label="Predicted Text"),
    title="OCR for CAPTCHA",
    description="Solve captchas from images.",
    examples=["1.png", "2.jfif", "3.jpg"]
)

if __name__ == "__main__":
    iface.launch()