# AureliusGPT / handler.py
# Tarush-AI's picture
# Upload handler.py with huggingface_hub
# aa18873 verified
import os, sys, re
from typing import Dict, List, Any, Union
import torch
# Absolute path of the directory that contains this file.
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
# Put that directory at the front of the import search path (once) before the
# project-local imports that follow (`model`, `config`) are attempted.
# NOTE(review): presumably those modules live next to this file -- confirm.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
from model.model import Transformer
from model.vocab.tokenizer import Tokenizer
import config
class EndpointHandler:
    """Callable text-generation endpoint built on the project's ``Transformer``.

    The instance loads a checkpoint and tokenizer weights once at construction
    time and is then called with a request payload (see ``__call__``). The
    overall shape (``__init__(path)`` + ``__call__(data)``) follows the
    Hugging Face custom-handler convention.
    """

    # A decoded token that starts with one of these characters is glued to the
    # preceding token instead of being separated from it by a space.
    _PUNCTUATION = tuple(".,!?;:")

    def __init__(self, path: str = ""):
        """Load the model checkpoint and the tokenizer weights.

        Args:
            path: Directory containing ``epoch_10.pt`` and ``tokenizer.model``.
                An empty string falls back to ``REPO_ROOT``.

        Raises:
            FileNotFoundError: If the checkpoint or the tokenizer file is
                missing from the chosen directory.
        """
        self.base_dir = path or REPO_ROOT
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # --- model loading from file ---
        ckpt_path = os.path.join(self.base_dir, "epoch_10.pt")
        if not os.path.isfile(ckpt_path):
            raise FileNotFoundError(f"Missing checkpoint at: {ckpt_path}")
        self.model = Transformer().to(self.device)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        ckpt = torch.load(ckpt_path, map_location=self.device)
        # Accept either a bare state dict or a wrapper dict that stores it
        # under one of the two commonly used keys.
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            state_dict = ckpt["state_dict"]
        elif isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            state_dict = ckpt["model_state_dict"]
        else:
            state_dict = ckpt
        self.model.load_state_dict(state_dict, strict=True)
        self.model.eval()

        # --- tokenizer loading from file ---
        token_path = os.path.join(self.base_dir, "tokenizer.model")
        if not os.path.isfile(token_path):
            raise FileNotFoundError(f"Missing tokenizer weights at: {token_path}")
        self.tokenizer = Tokenizer()
        self.tokenizer.load_weights(token_path)

    def _last_token_logits(self, model_out: torch.Tensor) -> torch.Tensor:
        """Return the logits row for the final position of ``model_out``.

        A 3-D output is indexed as ``[0, -1, :]`` (first entry of the leading
        axis, last position); a 2-D output as ``[-1, :]``.

        Raises:
            ValueError: If ``model_out`` is neither 2-D nor 3-D.
        """
        if model_out.dim() == 3:
            return model_out[0, -1, :]
        if model_out.dim() == 2:
            return model_out[-1, :]
        raise ValueError(f"Unexpected model output shape: {tuple(model_out.shape)}")

    @staticmethod
    def _truncate_context(token_ids: torch.Tensor, max_len: int) -> torch.Tensor:
        """Keep only the trailing ``max_len`` ids of a 1-D id tensor."""
        if token_ids.numel() > max_len:
            return token_ids[-max_len:]
        return token_ids

    @torch.inference_mode()
    def _generate_one(self, prompt: str) -> str:
        """Generate a continuation for ``prompt`` and return it as text.

        Tokens are produced one at a time until the decoded token equals
        ``"<END>"`` or ``config.max_tokens`` tokens have been produced. The
        result is prefixed with ``"AURELIUS: "``; ``<BEGIN>``/``<END>`` markers
        in the generated text are replaced by blank lines.

        Args:
            prompt: Text to condition the generation on.

        Returns:
            The formatted generation, or a fixed placeholder message when the
            tokenizer yields no ids for ``prompt``.
        """
        encoded = torch.as_tensor(
            self.tokenizer.encode(prompt),
            dtype=torch.long,
            device=self.device,
        )
        if encoded.numel() == 0:
            return "AURELIUS: (No input processed)"

        # Clip an over-long prompt *before* the first forward pass; previously
        # the window was only enforced after a token had been appended, so the
        # first call could receive more than config.max_seq_length ids.
        encoded = self._truncate_context(encoded, config.max_seq_length)

        # A non-positive temperature would turn the scaled logits into
        # inf/nan and make multinomial fail, so treat it as greedy decoding.
        greedy = config.argmax or config.temperature <= 0

        token = ""
        output = ""
        steps = 0
        while token != "<END>" and steps < config.max_tokens:
            logits = self._last_token_logits(self.model(encoded))
            if greedy:
                next_id = int(torch.argmax(logits).item())
            else:
                probs = torch.softmax(logits / config.temperature, dim=-1)
                next_id = int(torch.multinomial(probs, num_samples=1).item())

            token = self.tokenizer.decode([next_id]).strip()
            # Punctuation attaches to the previous token: drop the separator
            # space that was appended after it.
            if token.startswith(self._PUNCTUATION) and output.endswith(" "):
                output = output[:-1]
            output += token + " "

            encoded = torch.cat(
                [encoded, torch.tensor([next_id], dtype=torch.long, device=self.device)],
                dim=0,
            )
            # Slide the context window so it never exceeds the configured size.
            encoded = self._truncate_context(encoded, config.max_seq_length)
            steps += 1

        text = output.replace("<BEGIN>", "\n\n").replace("<END>", "\n\n")
        return "AURELIUS: " + text

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one request payload.

        Args:
            data: Request dict. The prompt is read from ``data["inputs"]``
                (falling back to ``data`` itself when that key is absent). If
                the value is a dict, its ``"text"`` entry is used (``""`` when
                missing). A list yields one generation per element; anything
                else is converted with ``str`` and treated as a single prompt.

        Returns:
            A list of ``{"generated_text": ...}`` dicts, one per prompt.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, dict):
            inputs = inputs.get("text", "")
        if isinstance(inputs, list):
            return [{"generated_text": self._generate_one(str(x))} for x in inputs]
        return [{"generated_text": self._generate_one(str(inputs))}]