| import json | |
| import torch | |
| from modeling import BigramLanguageModel | |
class EndpointHandler:
    """Inference handler for a character-level BigramLanguageModel.

    Loads the model config, character vocabulary, and weights from *path*
    and serves text-generation requests through ``__call__``.
    """

    def __init__(self, path=""):
        # Model hyperparameters saved alongside the weights.
        with open(f"{path}/config.json") as f:
            self.config = json.load(f)
        # Character -> index vocabulary; itos is the inverse mapping
        # (JSON keys are strings, so indices must be re-cast to int).
        with open(f"{path}/vocab.json") as f:
            self.stoi = json.load(f)
        self.itos = {int(v): k for k, v in self.stoi.items()}
        self.model = BigramLanguageModel(self.config)
        self.model.load_state_dict(
            torch.load(f"{path}/pytorch_model.bin", map_location="cpu")
        )
        # Inference only: disable dropout/batch-norm training behavior.
        self.model.eval()

    def __call__(self, data):
        """Generate a continuation of ``data["inputs"]``.

        Parameters
        ----------
        data : dict
            ``inputs`` (str): prompt text.
            ``max_new_tokens`` (int, optional): number of characters to
            generate, clamped to the range [0, 256]; defaults to 200.

        Returns
        -------
        dict
            ``{"outputs": <prompt + generated text>}`` on success, or
            ``{"error": <message>}`` if the request is malformed or
            generation fails.
        """
        try:
            text = data["inputs"]
            # Clamp to [0, 256]: the original only capped the upper bound,
            # so a negative client value reached the model unchecked.
            max_new_tokens = max(0, min(int(data.get("max_new_tokens", 200)), 256))
            # Encode the prompt as a (1, len(text)) index tensor; unknown
            # characters fall back to the space token (assumes " " is in
            # the vocab — TODO confirm against training data).
            idx = torch.tensor(
                [[self.stoi.get(c, self.stoi[" "]) for c in text]],
                dtype=torch.long,
            )
            # no_grad: pure inference — avoids building an autograd graph.
            with torch.no_grad():
                out = self.model.generate(idx, max_new_tokens)[0]
            return {
                "outputs": "".join(self.itos[i.item()] for i in out)
            }
        except Exception as e:
            # Serving boundary: report the failure to the caller as JSON
            # instead of crashing the endpoint.
            return {
                "error": str(e)
            }