Tarush-AI committed on
Commit
aa18873
·
verified ·
1 Parent(s): 758a810

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +109 -0
handler.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, re
2
+ from typing import Dict, List, Any, Union
3
+ import torch
4
+
5
# Make the repository root importable so the local `model` package and
# `config` module resolve regardless of the current working directory.
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
8
+
9
+ from model.model import Transformer
10
+ from model.vocab.tokenizer import Tokenizer
11
+ import config
12
+
13
+
14
class EndpointHandler:
    """Inference-endpoint handler for the AURELIUS text-generation model.

    Loads a ``Transformer`` checkpoint (``epoch_10.pt``) and tokenizer
    weights (``tokenizer.model``) from a base directory, then serves
    autoregressive generation through ``__call__`` in the Hugging Face
    custom-handler format: input dicts in, ``[{"generated_text": ...}]`` out.
    """

    def __init__(self, path: str = ""):
        """Load model and tokenizer from *path* (defaults to the repo root).

        Raises:
            FileNotFoundError: if the checkpoint or tokenizer file is absent.
        """
        self.base_dir = path or REPO_ROOT
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # --- model loading ---
        ckpt_path = os.path.join(self.base_dir, "epoch_10.pt")
        if not os.path.isfile(ckpt_path):
            raise FileNotFoundError(f"Missing checkpoint at: {ckpt_path}")

        self.model = Transformer().to(self.device)

        # SECURITY NOTE: torch.load unpickles arbitrary objects; only load
        # checkpoints from a trusted source. If the checkpoint holds tensors
        # only, consider passing weights_only=True (torch >= 2.0).
        ckpt = torch.load(ckpt_path, map_location=self.device)

        # Accept either a raw state dict or a wrapped training checkpoint.
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            state_dict = ckpt["state_dict"]
        elif isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            state_dict = ckpt["model_state_dict"]
        else:
            state_dict = ckpt

        self.model.load_state_dict(state_dict, strict=True)
        self.model.eval()

        # --- tokenizer loading ---
        token_path = os.path.join(self.base_dir, "tokenizer.model")
        if not os.path.isfile(token_path):
            raise FileNotFoundError(f"Missing tokenizer weights at: {token_path}")

        self.tokenizer = Tokenizer()
        self.tokenizer.load_weights(token_path)

    def _last_token_logits(self, model_out: torch.Tensor) -> torch.Tensor:
        """Return the logits for the final sequence position.

        Accepts model output shaped (batch, seq, vocab) — batch 0 is used —
        or (seq, vocab).

        Raises:
            ValueError: for any other output rank.
        """
        if model_out.dim() == 3:
            return model_out[0, -1, :]
        if model_out.dim() == 2:
            return model_out[-1, :]
        raise ValueError(f"Unexpected model output shape: {tuple(model_out.shape)}")

    @torch.inference_mode()
    def _generate_one(self, prompt: str) -> str:
        """Autoregressively generate a reply to *prompt*.

        Decodes one token per full forward pass (no KV cache) until an
        ``<END>`` token appears or ``config.max_tokens`` tokens were emitted,
        keeping at most ``config.max_seq_length`` ids of context.
        """
        encoded = torch.as_tensor(
            self.tokenizer.encode(prompt),
            dtype=torch.long,
            device=self.device,
        )

        if encoded.numel() == 0:
            return "AURELIUS: (No input processed)"

        currtoken = ""
        outputstring = ""
        countcheck = 0

        while currtoken != "<END>" and countcheck < config.max_tokens:
            logits = self._last_token_logits(self.model(encoded))

            # Greedy decode or temperature sampling, per config.
            if config.argmax:
                next_id = int(torch.argmax(logits).item())
            else:
                probs = torch.softmax(logits / config.temperature, dim=-1)
                next_id = int(torch.multinomial(probs, num_samples=1).item())

            currtoken = self.tokenizer.decode([next_id]).strip()

            # Glue leading punctuation onto the previous word by dropping the
            # pending space. (Fix: the original duplicated the identical
            # append in both branches of an if/else.)
            if currtoken.startswith((".", ",", "!", "?", ";", ":")) and outputstring.endswith(" "):
                outputstring = outputstring[:-1]
            outputstring += currtoken + " "

            encoded = torch.cat(
                [encoded, torch.tensor([next_id], dtype=torch.long, device=self.device)],
                dim=0,
            )
            # Sliding context window: keep only the newest ids.
            if encoded.numel() > config.max_seq_length:
                encoded = encoded[-config.max_seq_length:]

            countcheck += 1

        # The markers are literal strings, not patterns; str.replace is the
        # correct (regex-safe) tool here.
        text = outputstring.replace("<BEGIN>", "\n\n").replace("<END>", "\n\n")
        return "AURELIUS: " + text

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Hugging Face inference-endpoint entry point.

        Accepts ``{"inputs": str | list | {"text": str}}`` (or a bare value)
        and returns one ``{"generated_text": str}`` dict per input.
        """
        # Tolerate non-dict payloads instead of crashing on .get().
        inputs = data.get("inputs", data) if isinstance(data, dict) else data

        if isinstance(inputs, dict):
            inputs = inputs.get("text", "")

        if isinstance(inputs, list):
            return [{"generated_text": self._generate_one(str(x))} for x in inputs]

        return [{"generated_text": self._generate_one(str(inputs))}]