| import base64 |
| import io |
| import sys |
| import torch |
| import onnx |
| import onnxruntime as rt |
| from torchvision import transforms as T |
| from tokenizer_base import Tokenizer |
| from PIL import Image |
| from huggingface_hub import hf_hub_download, try_to_load_from_cache |
|
|
|
|
class DocumentParserModel:
    """Read the text in a captcha-style image with an ONNX model.

    ``captcha.onnx`` is taken from the Hugging Face Hub repo
    ``stevenchang/captcha`` (reusing the local cache when present) and run with
    ONNX Runtime. Input images arrive as base64 strings, are resized to 32x128
    and normalised, and the model's output is decoded to a string by
    ``Tokenizer``.
    """

    def __init__(self):
        # NOTE(review): this charset presumably has to match the one the ONNX
        # model was trained with (the tokenizer maps characters to output
        # positions), so keep the literal byte-for-byte, backslashes included.
        charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
        img_size = (32, 128)  # (height, width) handed to T.Resize

        self.tokenizer_base = Tokenizer(charset)
        self.transform = self.create_transform_pipeline(img_size)
        self.ort_session = self.initialize_onnx_model()

    def create_transform_pipeline(self, img_size):
        """Build the preprocessing applied to every input image.

        Args:
            img_size: Size passed to ``T.Resize``; a ``(height, width)`` tuple
                resizes to exactly that shape.

        Returns:
            A ``T.Compose`` that resizes with bicubic interpolation, converts
            the PIL image to a float tensor in [0, 1], then normalises with
            mean 0.5 / std 0.5 (mapping values to [-1, 1]).
        """
        transforms = [
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ]
        return T.Compose(transforms)

    def initialize_onnx_model(self):
        """Locate the ONNX model, validate it and open an inference session.

        The file is read from the local Hugging Face cache when present and
        downloaded from the Hub otherwise.

        Returns:
            An ``onnxruntime.InferenceSession`` for the model file.

        Raises:
            onnx.checker.ValidationError: if the model fails ONNX validation.
        """
        repo_id = "stevenchang/captcha"
        filename = "captcha.onnx"

        # try_to_load_from_cache yields a path string on a cache hit; any other
        # value (None, or the "known to be missing" sentinel) means download.
        model_file = try_to_load_from_cache(repo_id, filename)
        if not isinstance(model_file, str):
            model_file = hf_hub_download(repo_id, filename)

        # Validate the graph before handing the file to ONNX Runtime.
        onnx_model = onnx.load(model_file)
        onnx.checker.check_model(onnx_model)
        return rt.InferenceSession(model_file)

    def load_image_from_base64(self, base64_string):
        """Decode base64-encoded image file bytes into a PIL image.

        Args:
            base64_string: Base64 text (ASCII ``str`` or bytes-like) holding
                an encoded image file.

        Returns:
            The opened ``PIL.Image.Image``, or ``None`` when the input is not
            valid base64 or cannot be identified as an image. A short error
            message is printed to stderr in the failure cases.
        """
        try:
            img_data = base64.b64decode(base64_string)
        except ValueError as err:  # binascii.Error is a ValueError subclass
            print(f"Error: Cannot decode base64 image data: {err}", file=sys.stderr)
            return None

        try:
            # Image.open reads lazily; the BytesIO object it holds keeps the
            # decoded bytes alive for later access.
            return Image.open(io.BytesIO(img_data))
        except OSError as err:  # covers PIL.UnidentifiedImageError
            print(f"Error: Cannot open image: {err}", file=sys.stderr)
            return None

    def predict_text(self, image_blob):
        """Return the text the model reads from a base64-encoded image.

        Args:
            image_blob: Base64-encoded image data.

        Returns:
            The decoded string, or ``None`` if the image could not be decoded
            or opened.
        """
        img_org = self.load_image_from_base64(image_blob)
        if img_org is None:
            # Entering `with None` would raise; report the failure instead.
            return None

        with img_org:
            # unsqueeze(0) adds a leading batch axis: a batch of one image.
            x = self.transform(img_org.convert("RGB")).unsqueeze(0)

        ort_inputs = {self.ort_session.get_inputs()[0].name: x.cpu().numpy()}
        logits = self.ort_session.run(None, ort_inputs)[0]
        # Softmax over the last axis (presumably the character-class axis)
        # before the tokenizer decodes the sequence.
        probs = torch.tensor(logits).softmax(-1)
        preds, _ = self.tokenizer_base.decode(probs)
        return preds[0]  # single-image batch -> its only prediction
|
|
|
|
def main(argv=None):
    """Command-line entry point: print the text read from a base64 image.

    Args:
        argv: Argument vector in ``sys.argv`` form; defaults to ``sys.argv``.
            ``argv[1]`` must be the base64-encoded image.

    Returns:
        Process exit status: 0 on success, 1 when no image was supplied or the
        image could not be decoded.
    """
    if argv is None:
        argv = sys.argv

    # Check arguments before constructing the model, which may download and
    # load the ONNX file.
    if len(argv) < 2:
        print("Please provide an image blob.", file=sys.stderr)
        return 1

    doc_parser = DocumentParserModel()
    result = doc_parser.predict_text(argv[1])
    if result is None:
        return 1
    print(result)
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
|