# Hugging Face Space status banner ("Spaces: Running") captured from the page scrape — not part of the program.
| import torch | |
| import onnx | |
| import onnxruntime as rt | |
| from torchvision import transforms as T | |
| from pathlib import Path | |
| from PIL import Image | |
| from utils.tokenizer_base import Tokenizer | |
| import gradio as gr | |
| import io | |
| import base64 | |
| import os | |
# =====================
# MODEL SETUP
# =====================

# Locate the exported ONNX model next to this file; fail fast if it is missing.
model_file = Path(__file__).parent / "models/model.onnx"
if not model_file.exists():
    raise RuntimeError(f"Model not found at {model_file}")

# Input geometry the recognizer expects: height 32, width 128.
img_size = (32, 128)

# Character set the model was trained on; passed verbatim to the tokenizer.
vocab = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
tokenizer = Tokenizer(vocab)

# Preprocessing: resize to model input, to tensor, then scale pixels to [-1, 1].
transform = T.Compose([
    T.Resize(img_size, T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(0.5, 0.5),
])

# Explicit CPU provider: onnxruntime >= 1.9 warns (and newer versions raise)
# when InferenceSession is created without a providers list.
session = rt.InferenceSession(str(model_file), providers=["CPUExecutionProvider"])
def to_numpy(t):
    """Return *t* as a NumPy array, detached from autograd and moved to CPU."""
    detached = t.detach()
    return detached.cpu().numpy()
def infer(img: Image.Image):
    """Run the ONNX recognizer on a PIL image and return the decoded text."""
    # Force 3-channel RGB, preprocess, and add the batch dimension.
    batch = transform(img.convert("RGB")).unsqueeze(0)
    feed = {session.get_inputs()[0].name: to_numpy(batch)}
    logits = session.run(None, feed)[0]
    # Softmax over the vocabulary axis, then greedy-decode via the tokenizer.
    probs = torch.tensor(logits).softmax(-1)
    texts, _ = tokenizer.decode(probs)
    return texts[0]
# =====================
# GRADIO FUNCTIONS
# =====================
def predict_image(img):
    """Gradio handler: OCR an uploaded PIL image.

    Returns the decoded text, or "" when no image was provided — Gradio
    passes None if the user clicks Predict with an empty image input,
    which would otherwise crash in `img.convert`.
    """
    if img is None:
        return ""
    return infer(img)
def predict_base64(b64: str):
    """Gradio handler: decode a base64-encoded image and OCR it.

    Accepts either a bare base64 payload or a full data URL
    ("data:image/png;base64,...."). Returns "" for empty/None input
    instead of raising from the decoder.
    """
    if not b64:
        return ""
    # Tolerate data-URL input by stripping everything up to the comma.
    if b64.lstrip().startswith("data:") and "," in b64:
        b64 = b64.split(",", 1)[1]
    img_bytes = base64.b64decode(b64)
    img = Image.open(io.BytesIO(img_bytes))
    return infer(img)
# =====================
# GRADIO APP (REQUIRED)
# =====================
with gr.Blocks(title="Captcha OCR") as demo:
    gr.Markdown("# Captcha OCR")
    gr.Markdown("OCR for captcha images (letters & numbers)")

    # Tab 1: interactive image upload -> predicted text.
    with gr.Tab("Image Upload"):
        upload = gr.Image(type="pil")
        upload_out = gr.Textbox()
        gr.Button("Predict").click(predict_image, upload, upload_out)

    # Tab 2: programmatic use with base64-encoded images.
    with gr.Tab("Base64 API"):
        encoded = gr.Textbox(label="Base64 Image")
        encoded_out = gr.Textbox()
        gr.Button("Predict").click(predict_base64, encoded, encoded_out)

# Enable request queuing, then start the server.
demo.queue()
demo.launch()