|
|
|
|
|
"""
|
|
|
Created on Sun Feb 8 11:47:14 2026
|
|
|
|
|
|
@author: ottav
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from model import GPT
|
|
|
from tokenizers import Tokenizer
|
|
|
import os
|
|
|
|
|
|
class EndpointHandler():
    """Hugging Face Inference Endpoints handler for a custom GPT model.

    Loads a `tokenizers` Tokenizer and a GPT checkpoint from ``path`` and
    serves autoregressive text generation with temperature, top-k and
    top-p (nucleus) sampling. The request payload is a dict; recognized
    keys are ``inputs``, ``temperature``, ``top_k``, ``top_p`` and
    ``max_new_tokens``.
    """

    def __init__(self, path=""):
        """Load tokenizer and model weights from ``path``.

        Parameters
        ----------
        path : str
            Directory containing ``tokenizer.json`` and
            ``pytorch_model.bin``.
        """
        self.tokenizer = Tokenizer.from_file(os.path.join(path, "tokenizer.json"))
        self.bos_id = self.tokenizer.token_to_id("[BOS]")
        self.eos_id = self.tokenizer.token_to_id("[EOS]")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # vocab_size must match the tokenizer the checkpoint was trained with.
        self.model = GPT(vocab_size=14000).to(self.device)

        weights_path = os.path.join(path, "pytorch_model.bin")
        # weights_only=True: deserialize tensors only, never arbitrary
        # pickled objects — a state_dict needs nothing more, and this
        # removes the code-execution risk of loading an untrusted pickle.
        self.model.load_state_dict(
            torch.load(weights_path, map_location=self.device, weights_only=True)
        )
        self.model.eval()

    def __call__(self, data):
        """Generate a completion for ``data["inputs"]``.

        Parameters
        ----------
        data : dict
            Request payload. ``inputs`` is the user question (falls back
            to the whole payload if absent, matching the standard HF
            handler convention); sampling parameters are optional.

        Returns
        -------
        str
            The decoded generated text (prompt excluded).
        """
        inputs = data.pop("inputs", data)
        temp = data.pop("temperature", 0.7)
        top_k = data.pop("top_k", 60)
        top_p = data.pop("top_p", 0.85)
        max_new_tokens = data.pop("max_new_tokens", 200)

        # Guard: temperature <= 0 would turn logits into inf/NaN and make
        # torch.multinomial fail. Clamp to a tiny positive value, which
        # approximates greedy decoding.
        temp = max(float(temp), 1e-6)

        # NOTE: the "Domanda/Risposta" template is part of the training
        # format — do not change it.
        prompt_text = f"Domanda: {inputs}\nRisposta: "
        prompt_ids = [self.bos_id] + self.tokenizer.encode(prompt_text).ids
        idx = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).to(self.device)

        generated_ids = []
        for _ in range(max_new_tokens):
            # Crop the context to the model's window (presumably the GPT
            # block size is 384 — keep in sync with model.py).
            idx_cond = idx[:, -384:]
            with torch.no_grad():
                logits = self.model(idx_cond)
            # Keep only the last position's logits and apply temperature.
            logits = logits[:, -1, :] / temp

            # Top-k filter. Clamp k to the vocab size so oversized
            # requests don't crash torch.topk; k <= 0 disables the filter
            # instead of indexing an empty tensor.
            k = min(int(top_k), logits.size(-1))
            if k > 0:
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Top-p (nucleus) filter: drop the tail of the sorted
            # distribution whose cumulative probability exceeds top_p.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cumulative_probs = torch.cumsum(probs, dim=-1)

            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the mask right so the first token crossing the
            # threshold is kept (at least one token always survives).
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            # NOTE(review): this flat advanced indexing is only correct
            # for batch size 1, which is the case here (one prompt per
            # call).
            logits[:, sorted_indices[sorted_indices_to_remove]] = -float("Inf")

            # Sample the next token from the filtered distribution.
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Stop on end-of-sequence without emitting the EOS token.
            if idx_next.item() == self.eos_id:
                break

            idx = torch.cat((idx, idx_next), dim=1)
            generated_ids.append(idx_next.item())

        return self.tokenizer.decode(generated_ids)