import base64
import io
import itertools
import os
import string

import torch
from PIL import Image
from torchvision import transforms


class EndpointHandler:
    """Recognise a character sequence in an image with a TorchScript CTC model.

    The model's per-timestep class scores are greedily CTC-decoded over
    ``chars``; the extra index ``len(chars)`` is the CTC blank label.
    """

    # File name of the exported TorchScript model inside ``path``.
    MODEL_FILENAME = "master_brain_2026_final.pt"
    # Output vocabulary, in the model's class-index order.
    CHARSET = string.ascii_uppercase + string.ascii_lowercase + string.digits + "@#$%&="

    def __init__(self, path=""):
        """Load the model and build the preprocessing pipeline.

        Args:
            path: Directory containing ``MODEL_FILENAME``. The default ``""``
                resolves the file relative to the current working directory
                (an f-string join would have pointed at the filesystem root).
        """
        model_path = os.path.join(path, self.MODEL_FILENAME)
        # Inputs are created on the CPU in __call__, so load the weights there
        # too, even if the checkpoint was saved from a GPU.
        self.model = torch.jit.load(model_path, map_location="cpu")
        self.model.eval()
        self.chars = self.CHARSET
        self.blank_index = len(self.chars)
        self.preprocess = transforms.Compose([
            transforms.Resize((35, 142)),  # (height, width)
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
        ])

    def __call__(self, data):
        """Run recognition on a single request.

        Args:
            data: Either a dict whose ``"inputs"`` entry holds the image, or
                the image itself. The image may be a base64 string (optionally
                a ``data:`` URL), raw encoded image bytes, or a PIL image.
                The dict is not modified.

        Returns:
            ``{"prediction": <decoded string>}``.

        Raises:
            TypeError: If the image payload has an unsupported type. Errors
                from base64 or PIL decoding propagate unchanged.
        """
        inputs = data.get("inputs", data) if isinstance(data, dict) else data
        image = self._load_image(inputs)
        img_tensor = self.preprocess(image).unsqueeze(0)  # add batch axis
        with torch.no_grad():
            logits = self.model(img_tensor)
        return {"prediction": self._ctc_greedy_decode(logits)}

    def _load_image(self, inputs):
        """Decode a request payload into an RGB PIL image."""
        if isinstance(inputs, str):
            # Accept "data:image/...;base64,<payload>" as well as bare base64.
            if inputs.startswith("data:"):
                inputs = inputs.partition(",")[2]
            inputs = base64.b64decode(inputs)
        if isinstance(inputs, (bytes, bytearray, memoryview)):
            return Image.open(io.BytesIO(inputs)).convert("RGB")
        # NOTE(review): some serving layers hand over an already-decoded
        # image; the original code crashed on that case.
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        raise TypeError(
            f"Unsupported type for 'inputs': {type(inputs).__name__}; "
            "expected base64 str, image bytes, or PIL.Image.Image"
        )

    def _ctc_greedy_decode(self, logits):
        """Greedy CTC decoding: best class per step, merge repeats, drop blanks.

        Args:
            logits: Model scores with the class axis last. For one image,
                ``(1, T, C)``, ``(T, 1, C)`` and ``(T, C)`` all decode alike.
                NOTE(review): layout assumed from the original argmax over
                dim 2 of a 3-D output; confirm against the exported model.

        Returns:
            The decoded string.
        """
        # flatten() rather than squeeze(): with T == 1, squeeze() yields a
        # 0-d tensor, and iterating over it raises TypeError.
        best_path = torch.argmax(logits, dim=-1).flatten().tolist()
        # groupby collapses runs of one label, so a blank between two equal
        # labels keeps both ("a _ a" -> "aa") while "a a" -> "a".
        return "".join(
            self.chars[label]
            for label, _ in itertools.groupby(best_path)
            if label != self.blank_index
        )