import base64
import io
import os
import string
from itertools import groupby
from typing import Any, Dict, List

import torch
from PIL import Image
from torchvision import transforms


class EndpointHandler:
    """Inference handler that reads text from an image with a TorchScript model.

    The handler loads a TorchScript model once at construction time and, for
    each call, decodes the incoming image, preprocesses it, runs the model and
    greedily decodes the per-step class predictions (collapsing repeats and
    dropping the blank class) into a string.
    """

    MODEL_FILENAME = "master_brain_2026_final.pt"

    def __init__(self, path: str = ""):
        """Load the model and build the preprocessing pipeline.

        Args:
            path: Directory that contains the TorchScript model file. An empty
                string resolves the file relative to the working directory.
        """
        # Inputs are always built on CPU in __call__, so pin the model to CPU
        # as well; otherwise a model exported from a GPU would either fail to
        # load on a CPU-only host or fail at inference with a device mismatch.
        self.model = torch.jit.load(
            os.path.join(path, self.MODEL_FILENAME), map_location="cpu"
        )
        self.model.eval()

        # Class index -> character table; the blank class is the index right
        # after the last real character.
        self.chars = (
            string.ascii_uppercase + string.ascii_lowercase + string.digits + "@#$%&="
        )
        self.blank_index = len(self.chars)

        # Resize takes (height, width); Normalize with mean=std=0.5 maps the
        # [0, 1] output of ToTensor to [-1, 1].
        self.preprocess = transforms.Compose([
            transforms.Resize((35, 142)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run the model on one image and return the decoded text.

        Args:
            data: Request payload. The image is taken from ``data["inputs"]``
                (the key is popped from the dict); if the key is missing the
                payload itself is used. Supported image values are a base64
                string (optionally a ``data:`` URI), raw encoded image bytes,
                or an already-decoded ``PIL.Image.Image``.

        Returns:
            ``{"prediction": <decoded text>}``.

        Raises:
            TypeError: If the image value is of an unsupported type.
        """
        image = self._load_image(data.pop("inputs", data))
        img_tensor = self.preprocess(image).unsqueeze(0)

        with torch.no_grad():
            logits = self.model(img_tensor)

        # Assumes the class scores are on dim 2 and the batch holds a single
        # image. reshape(-1) yields a 1-D sequence even when there is only one
        # time step (squeeze() would produce a 0-d tensor that cannot be
        # iterated).
        indices = torch.argmax(logits, dim=2).reshape(-1).tolist()
        return {"prediction": self._greedy_decode(indices).strip()}

    @staticmethod
    def _load_image(inputs: Any) -> Image.Image:
        """Convert a supported request value into an RGB PIL image."""
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")

        if isinstance(inputs, str):
            # Accept "data:image/png;base64,<payload>" as well as bare base64.
            if inputs.startswith("data:"):
                _, sep, payload = inputs.partition(",")
                if sep:
                    inputs = payload
            raw = base64.b64decode(inputs)
        elif isinstance(inputs, (bytes, bytearray, memoryview)):
            raw = bytes(inputs)
        else:
            raise TypeError(
                "Unsupported 'inputs' type: expected base64 str, bytes or "
                f"PIL.Image.Image, got {type(inputs).__name__}"
            )

        return Image.open(io.BytesIO(raw)).convert("RGB")

    def _greedy_decode(self, indices: List[int]) -> str:
        """Collapse consecutive repeats, then drop blanks and map to characters."""
        return "".join(
            self.chars[idx]
            for idx, _ in groupby(indices)
            if idx != self.blank_index
        )