Spaces:
Runtime error
Runtime error
| import torch | |
| import onnx | |
| import onnxruntime as rt | |
| from torchvision import transforms as T | |
| from pathlib import Path | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| import gradio as gr | |
| from utils.tokenizer_base import Tokenizer | |
# Directory containing this script.
cwd = Path(__file__).parent.resolve()

# Download the ONNX model from the Hugging Face Hub (or reuse the cached copy).
# `hf_hub_download` already returns the local file path, so use it directly:
# joining it onto `cwd` was a no-op for an absolute cache path and wrong for a
# relative one (that would be relative to the process CWD, not to `cwd`).
model_file = hf_hub_download("toandev/OCR-for-Captcha", "model.onnx")

# Target size passed to torchvision's Resize, i.e. (height, width).
img_size = (32, 128)

# Character set used by the tokenizer to map model output indices to text.
# NOTE(review): because this is a raw string, `\"` and `\\` keep their
# backslashes, so the vocab contains '\' three times (96 characters instead of
# the 94 printable ASCII symbols). Indices after '!' depend on this exact
# string; confirm it matches the charset the ONNX model was trained with
# before "fixing" it.
vocab = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"

# Tokenizer used to decode the model's per-position probabilities into text.
tokenizer = Tokenizer(vocab)
def to_numpy(tensor):
    """Return *tensor* as a NumPy array on the CPU.

    Tensors that track gradients are detached first, because ``.numpy()``
    refuses to convert a tensor that requires grad.
    """
    source = tensor
    if tensor.requires_grad:
        source = tensor.detach()
    return source.cpu().numpy()
def get_transform(img_size):
    """Build the preprocessing pipeline applied to each input image.

    The pipeline resizes to ``img_size`` with bicubic interpolation, converts
    the image to a float tensor, and normalizes it with mean 0.5 and std 0.5.
    """
    resize = T.Resize(img_size, interpolation=T.InterpolationMode.BICUBIC)
    to_tensor = T.ToTensor()
    normalize = T.Normalize(0.5, 0.5)
    return T.Compose([resize, to_tensor, normalize])
def load_model(model_file):
    """Validate the ONNX model and create an inference session for it.

    Args:
        model_file: Path to the ``.onnx`` model file.

    Returns:
        A ``(transform, session)`` tuple: the preprocessing pipeline built
        from the module-level ``img_size``, and an
        ``onnxruntime.InferenceSession`` for ``model_file``.

    Raises:
        Whatever ``onnx.checker.check_model`` raises for an invalid model
        (a validation error), before any session is created.
    """
    transform = get_transform(img_size)
    onnx_model = onnx.load(model_file)
    onnx.checker.check_model(onnx_model)
    # Since onnxruntime 1.9, builds with several execution providers (e.g.
    # onnxruntime-gpu) raise ValueError unless `providers` is set explicitly.
    # Passing the available providers keeps CPU-only builds working unchanged.
    session = rt.InferenceSession(
        model_file, providers=rt.get_available_providers()
    )
    return transform, session
# Build the preprocessing pipeline and the ONNX Runtime session once at import
# time; `process` reuses both for every request.
transform, s = load_model(model_file)
def process(img: "Image.Image | None"):
    """Predict the text shown in a CAPTCHA image.

    Args:
        img: Image from the Gradio ``Image`` component (``type="pil"``), or
            ``None`` when the form is submitted without an image.

    Returns:
        The first decoded prediction for the image, or ``""`` when no image
        was given.
    """
    # Gradio passes None for an empty Image input; without this guard the
    # next line fails with AttributeError on `None.convert`.
    if img is None:
        return ""
    # Normalize to 3-channel RGB whatever the uploaded image's mode, then add
    # a leading batch dimension of size 1.
    x = transform(img.convert("RGB")).unsqueeze(0)
    ort_inputs = {s.get_inputs()[0].name: to_numpy(x)}
    logits = s.run(None, ort_inputs)[0]
    # Softmax over the last axis (presumably the character-class axis).
    probs = torch.tensor(logits).softmax(-1)
    # The second value returned by decode() is not needed here.
    preds, _ = tokenizer.decode(probs)
    return preds[0]
# Sample CAPTCHA images offered as clickable examples in the UI.
_example_images = [
    "examples/1.png",
    "examples/2.jpg",
    "examples/3.jpg",
    "examples/4.png",
    "examples/5.png",
]

# Gradio UI: a PIL image in, the predicted text out.
iface = gr.Interface(
    fn=process,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Textbox(label="Predicted Text"),
    title="OCR for CAPTCHA",
    description="Solve captchas from images including letters and numbers, success rate is about 80-90%.",
    examples=_example_images,
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    iface.launch()