# Hugging Face Spaces page residue (status: Paused) — kept as a comment so the file parses.
import itertools
import os
import re
import time
from pathlib import Path

#import editdistance
#from collections import Counter
#from functools import lru_cache

import gradio as gr
import numpy as np
import onnxruntime
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
# --- Configuration ---
# Optional Hugging Face token for private-repo access; None is fine for
# public repositories.
hf_token = os.environ.get("HF_TOKEN")

# Download the exported CRNN ONNX model once at startup.
# NOTE: `use_auth_token` is deprecated in huggingface_hub; `token` is the
# supported replacement and behaves identically.
ONNX_MODEL_PATH = hf_hub_download(
    repo_id="vanh99/GRU-model",
    filename="crnntiny_best.onnx",
    token=hf_token,
)

# Fixed input geometry the model expects (grayscale, height x width).
IMG_HEIGHT = 50
IMG_WIDTH = 160

# Recognized alphabet. Index 0 is reserved for the CTC blank symbol, so
# characters map to indices 1..len(CHARSET).
CHARSET = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
IDX2CHAR = {i + 1: c for i, c in enumerate(CHARSET)}
BLANK_LABEL = 0
# --- Transform ---
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    Converts to single-channel grayscale, resizes to the fixed model input
    size, converts to a tensor, and normalizes pixel values to [-1, 1].
    """
    steps = [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
    return transforms.Compose(steps)
# --- Load Image ---
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on host memory.

    `detach()` is a no-op view for tensors that do not require gradients,
    so the same call chain is safe for both grad and non-grad tensors.
    """
    return tensor.detach().cpu().numpy()
# --- Decode ---
def ctc_decode(output_np, idx2char=None, blank_label=None):
    """Greedy (best-path) CTC decoding of network output.

    Args:
        output_np: array of shape (batch, time, num_classes) holding
            per-timestep class scores/logits.
        idx2char: optional index -> character mapping; defaults to the
            module-level IDX2CHAR so existing callers are unaffected.
        blank_label: optional CTC blank index; defaults to BLANK_LABEL.

    Returns:
        A list of decoded strings, one per batch element. Indices missing
        from the mapping decode to '?'.
    """
    if idx2char is None:
        idx2char = IDX2CHAR
    if blank_label is None:
        blank_label = BLANK_LABEL
    pred_indices = np.argmax(output_np, axis=2)
    decoded_strings = []
    for indices in pred_indices:
        # Standard greedy CTC: collapse runs of repeated indices first,
        # then drop blanks (a blank separates genuine repeats).
        collapsed = [idx for idx, _ in itertools.groupby(indices)]
        final = [idx for idx in collapsed if idx != blank_label]
        decoded_strings.append("".join(idx2char.get(idx, '?') for idx in final))
    return decoded_strings
# --- Load model ---
def load_model():
    """Create the preprocessing transform and the ONNX inference session."""
    return get_transform(), onnxruntime.InferenceSession(ONNX_MODEL_PATH)

# Initialize once at import time; the session and its input/output tensor
# names are reused for every prediction.
transform, session = load_model()
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# --- Predict ---
def predict_image(image):
    """Run OCR on a single PIL image and return the decoded text."""
    # Force a 3-channel image; the transform converts back to grayscale.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    outputs = session.run([output_name], {input_name: to_numpy(batch)})
    return ctc_decode(outputs[0])[0]
# Gradio UI: single image in, predicted text out.
demo_input = gr.Image(type="pil", label="Input Image")
demo_output = gr.Textbox(label="Predicted Text")
iface = gr.Interface(
    fn=predict_image,
    inputs=demo_input,
    outputs=demo_output,
    title="OCR for CAPTCHA",
    description="Solve captchas from images.",
    examples=["1.png", "2.jfif", "3.jpg"],
)

if __name__ == "__main__":
    iface.launch()