# NOTE: removed non-code extraction artifacts (file-size banner, git blob
# hashes, and a scraped line-number gutter) that preceded the actual script.
import os
import time
import torch
import numpy as np
import onnxruntime
from torchvision import transforms
from PIL import Image
from pathlib import Path
import re
#import editdistance
#from collections import Counter
#from functools import lru_cache
import gradio as gr

from huggingface_hub import hf_hub_download

# --- Configuration ---
# Download the exported ONNX CRNN model from the Hugging Face Hub.
# HF_TOKEN is optional (None means anonymous access to a public repo).
hf_token = os.environ.get("HF_TOKEN")
ONNX_MODEL_PATH = hf_hub_download(
    repo_id="vanh99/GRU-model",
    filename="crnntiny_best.onnx",
    token=hf_token,  # `use_auth_token` is deprecated in huggingface_hub; `token` is the supported kwarg
)
# Input geometry the model was trained with.
IMG_HEIGHT = 50
IMG_WIDTH = 160
# Label alphabet; index 0 is reserved for the CTC blank, so character ids start at 1.
CHARSET = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
IDX2CHAR = {i + 1: c for i, c in enumerate(CHARSET)}
BLANK_LABEL = 0

# --- Transform ---
def get_transform():
    """Build the preprocessing pipeline: grayscale -> resize -> tensor -> normalize to [-1, 1]."""
    steps = [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
    return transforms.Compose(steps)

# --- Load Image ---
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array, detaching from autograd first if needed."""
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

# --- Decode ---
def ctc_decode(output_np):
    pred_indices = np.argmax(output_np, axis=2)
    decoded_strings = []
    for indices in pred_indices:
        collapsed = [indices[0]] if len(indices) > 0 else []
        for i in range(1, len(indices)):
            if indices[i] != indices[i - 1]:
                collapsed.append(indices[i])
        final = [idx for idx in collapsed if idx != BLANK_LABEL]
        decoded_strings.append("".join([IDX2CHAR.get(idx, '?') for idx in final]))
    return decoded_strings

# --- Load model ---
def load_model():
    """Create the preprocessing transform and an ONNX Runtime inference session."""
    session = onnxruntime.InferenceSession(ONNX_MODEL_PATH)
    return get_transform(), session

# Module-level model state, initialized once at import so Gradio requests reuse it.
transform, session = load_model()
# Cache the ONNX graph's input/output tensor names for session.run calls.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# --- Predict ---
def predict_image(image):
    """Run OCR on a PIL image and return the decoded text for it.

    The image is converted to RGB, preprocessed, batched (size 1), pushed
    through the ONNX session, and greedily CTC-decoded.
    """
    batch = transform(image.convert("RGB")).unsqueeze(0)
    outputs = session.run([output_name], {input_name: to_numpy(batch)})
    return ctc_decode(outputs[0])[0]

# Gradio UI: single image in, predicted text out.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Textbox(label="Predicted Text"),
    title="OCR for CAPTCHA",
    description="Solve captchas from images.",
    # NOTE(review): example paths are relative to the working directory —
    # confirm these files ship alongside the app.
    examples=["1.png","2.jfif","3.jpg"]
)

if __name__ == "__main__":
    iface.launch()