File size: 1,139 Bytes
5db0083
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4e74be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import json
import os

import torch

from modeling import BigramLanguageModel

class EndpointHandler:
    """Character-level text-generation handler wrapping ``BigramLanguageModel``.

    The class name, the ``path`` constructor argument and the
    ``__call__(data) -> dict`` shape appear to follow the Hugging Face
    Inference Endpoints custom-handler convention (confirm against the
    deployment). ``path`` is a directory expected to contain:

    * ``config.json`` - passed as-is to ``BigramLanguageModel``;
    * ``vocab.json`` - a mapping from single characters to token ids;
    * ``pytorch_model.bin`` - a state dict for ``BigramLanguageModel``.
    """

    # Default and upper bound for the number of tokens generated per request.
    DEFAULT_MAX_NEW_TOKENS = 200
    MAX_NEW_TOKENS_LIMIT = 256

    def __init__(self, path=""):
        """Load the config, vocabulary and weights found under ``path``.

        Args:
            path: Model directory. The default ``""`` means the current
                working directory; ``os.path.join`` keeps the file names
                relative instead of producing ``/config.json``.
        """
        with open(os.path.join(path, "config.json"), encoding="utf-8") as f:
            self.config = json.load(f)

        with open(os.path.join(path, "vocab.json"), encoding="utf-8") as f:
            # Normalise ids to int so they work both as tensor values and as
            # keys of the reverse mapping below.
            self.stoi = {ch: int(idx) for ch, idx in json.load(f).items()}

        self.itos = {idx: ch for ch, idx in self.stoi.items()}

        self.model = BigramLanguageModel(self.config)
        # NOTE(review): torch.load unpickles the checkpoint. If this file can
        # come from an untrusted source, pass weights_only=True (torch>=1.13).
        self.model.load_state_dict(
            torch.load(os.path.join(path, "pytorch_model.bin"), map_location="cpu")
        )
        self.model.eval()

    def __call__(self, data):
        """Generate a continuation of ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt string.
                The optional ``data["max_new_tokens"]`` (default 200) is
                converted with ``int`` and clamped to ``[0, 256]``.

        Returns:
            ``{"outputs": text}`` on success. ``text`` is the first row
            returned by ``model.generate`` decoded back to characters; it
            presumably includes the prompt (confirm in ``modeling``).
            Returns ``{"error": message}`` if anything raises. This is the
            request boundary, so failures are reported in the payload
            rather than propagated.
        """
        try:
            text = data["inputs"]
            requested = int(data.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS))
            # Clamp both ends: a negative count must not reach generate().
            max_new_tokens = max(0, min(requested, self.MAX_NEW_TOKENS_LIMIT))

            idx = torch.tensor([self._encode(text)], dtype=torch.long)

            # Inference only: skip autograd bookkeeping while sampling.
            with torch.no_grad():
                out = self.model.generate(idx, max_new_tokens)[0]

            return {"outputs": "".join(self.itos[i] for i in out.tolist())}

        except Exception as e:
            return {"error": str(e)}

    def _encode(self, text):
        """Map each character of ``text`` to its token id.

        Characters missing from the vocabulary fall back to the id of ``" "``.

        Raises:
            ValueError: if ``text`` contains an unknown character and the
                vocabulary has no ``" "`` entry to fall back to.
        """
        fallback = self.stoi.get(" ")
        ids = [self.stoi.get(ch, fallback) for ch in text]
        if fallback is None and None in ids:
            raise ValueError(
                "input contains characters outside the vocabulary and "
                "the vocabulary has no ' ' fallback"
            )
        return ids