File size: 1,912 Bytes
aacc8da
 
 
 
 
 
962c956
24f7ea5
962c956
 
24f7ea5
aacc8da
24f7ea5
962c956
24f7ea5
 
 
 
aacc8da
 
 
 
 
24f7ea5
 
 
 
 
aacc8da
24f7ea5
aacc8da
 
24f7ea5
 
aacc8da
 
24f7ea5
aacc8da
24f7ea5
aacc8da
962c956
aacc8da
 
24f7ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import onnx
import onnxruntime as rt
from torchvision import transforms as T
from pathlib import Path
from PIL import Image
from utils.tokenizer_base import Tokenizer
import gradio as gr
import io
import base64
import os

# =====================
# MODEL SETUP
# =====================
# Location of the exported ONNX model, resolved relative to this source file.
model_file = Path(__file__).parent.joinpath("models/model.onnx")
if not model_file.exists():
    # Fail at import time with a clear message rather than inside onnxruntime.
    raise RuntimeError(f"Model not found at {model_file}")

# Target size passed to T.Resize below.
img_size = (32, 128)

# Character set handed to the tokenizer.
# NOTE(review): this literal is a *raw* string, so `\"` and `\\` keep their
# backslashes: the value contains an extra `\` right after `!` and two `\`
# between `[` and `]`. Digits and letters come earlier in the string and are
# unaffected. Confirm this matches the charset the model was exported with
# before changing it.
vocab = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
tokenizer = Tokenizer(vocab)

# Preprocessing pipeline: bicubic resize, conversion to a tensor, then
# normalization with mean 0.5 and std 0.5.
transform = T.Compose(
    [
        T.Resize(img_size, interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=0.5, std=0.5),
    ]
)

# A single inference session is created at import time and reused by infer().
session = rt.InferenceSession(os.fspath(model_file))


def to_numpy(t):
    """Return the contents of tensor *t* as a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before the conversion.
    """
    cpu_tensor = t.detach().cpu()
    return cpu_tensor.numpy()


def infer(img: Image.Image):
    """Run the ONNX model on a single image and return the decoded text.

    The image is converted to RGB, passed through the module-level
    ``transform`` and fed to ``session`` as a batch of one. The first model
    output is softmaxed over its last axis and handed to
    ``tokenizer.decode``; the first decoded entry is returned.
    """
    rgb_image = img.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add a leading batch axis

    # The model input is addressed by the name the session reports for it.
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: to_numpy(batch)})

    probabilities = torch.tensor(outputs[0]).softmax(-1)
    texts, _ = tokenizer.decode(probabilities)
    return texts[0]


# =====================
# GRADIO FUNCTIONS
# =====================
def predict_image(img):
    """Gradio handler for the "Image Upload" tab.

    Args:
        img: The PIL image supplied by the ``gr.Image`` input, or ``None``
            when the button is pressed before any image has been uploaded.

    Returns:
        The text produced by ``infer`` for the image.

    Raises:
        gr.Error: If no image was provided. Without this guard the call
            would fail inside ``infer`` with an opaque ``AttributeError``.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    return infer(img)


def predict_base64(b64: str):
    """Gradio handler for the "Base64 API" tab.

    Args:
        b64: A base64-encoded image. Either the bare payload or a full data
            URL (``data:image/png;base64,<payload>``) is accepted;
            surrounding whitespace is ignored.

    Returns:
        The text produced by ``infer`` for the decoded image.

    Raises:
        gr.Error: If the input is empty, cannot be base64-decoded, or the
            decoded bytes cannot be opened as an image.
    """
    payload = (b64 or "").strip()

    # A data-URL header consists of characters that are themselves valid
    # base64 symbols, so it would be decoded as garbage in front of the
    # image bytes. Plain base64 never contains ":", so this check cannot
    # misfire on a bare payload.
    if payload.startswith("data:"):
        payload = payload.partition(",")[2]

    if not payload:
        raise gr.Error("Please provide a base64-encoded image.")

    try:
        img_bytes = base64.b64decode(payload)
        img = Image.open(io.BytesIO(img_bytes))
    except (ValueError, OSError) as err:
        # binascii.Error (bad padding) is a ValueError; PIL reports
        # unreadable image data as an OSError subclass.
        raise gr.Error("Input is not a valid base64-encoded image.") from err

    return infer(img)


# =====================
# GRADIO APP (REQUIRED)
# =====================
with gr.Blocks(title="Captcha OCR") as demo:
    gr.Markdown("# Captcha OCR")
    gr.Markdown("OCR for captcha images (letters & numbers)")

    # Tab 1: upload an image; it is passed to predict_image as a PIL image.
    with gr.Tab("Image Upload"):
        img = gr.Image(type="pil")
        out = gr.Textbox()
        upload_button = gr.Button("Predict")
        upload_button.click(fn=predict_image, inputs=img, outputs=out)

    # Tab 2: paste a base64 string; it is passed to predict_base64 as text.
    with gr.Tab("Base64 API"):
        b64 = gr.Textbox(label="Base64 Image")
        out2 = gr.Textbox()
        base64_button = gr.Button("Predict")
        base64_button.click(fn=predict_base64, inputs=b64, outputs=out2)

# Enable request queuing, then start the server when this module is executed.
demo.queue()
demo.launch()