# FlavioRubensOttaviani's picture
# Upload handler.py
# b37982e verified
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 8 11:47:14 2026
@author: ottav
"""
import torch
import torch.nn.functional as F
from model import GPT
from tokenizers import Tokenizer
import os
class EndpointHandler():
    """Inference handler that answers a question with the project's GPT model.

    The handler loads ``tokenizer.json`` and ``pytorch_model.bin`` from the
    directory given at construction time and, when called, autoregressively
    samples a continuation of a ``Domanda: ... / Risposta: `` prompt using
    temperature scaling followed by top-k and top-p filtering.
    """

    # Vocabulary size the GPT model is instantiated with.
    VOCAB_SIZE = 14000
    # Number of most recent tokens passed to the model at every step.
    CONTEXT_WINDOW = 384

    def __init__(self, path=""):
        """Load the tokenizer, build the model and restore its weights.

        Args:
            path: Directory containing ``tokenizer.json`` and
                ``pytorch_model.bin``.
        """
        # Tokenizer and the ids of the special begin/end tokens.
        self.tokenizer = Tokenizer.from_file(os.path.join(path, "tokenizer.json"))
        self.bos_id = self.tokenizer.token_to_id("[BOS]")
        self.eos_id = self.tokenizer.token_to_id("[EOS]")

        # Pick the device and instantiate the GPT class on it.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = GPT(vocab_size=self.VOCAB_SIZE).to(self.device)

        # Restore the weights from the .bin file.
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where the
        # installed torch version supports it).
        weights_path = os.path.join(path, "pytorch_model.bin")
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def __call__(self, data: dict) -> str:
        """Generate an answer for the request payload.

        Args:
            data: Request dictionary. ``"inputs"`` holds the question; the
                options ``"temperature"`` (default 0.7), ``"top_k"``
                (default 60), ``"top_p"`` (default 0.85) and
                ``"max_new_tokens"`` (default 200) may be given either at the
                top level or inside a nested ``"parameters"`` dictionary
                (top-level values win). The dictionary is not modified.

        Returns:
            The decoded text of the generated tokens (the prompt and the
            end-of-sequence token are not included).

        Raises:
            ValueError: If ``temperature`` is negative.
        """
        # Work on a shallow copy so the caller's payload is left untouched.
        payload = dict(data)
        inputs = payload.pop("inputs", payload)

        nested = payload.get("parameters")
        if not isinstance(nested, dict):
            nested = {}

        temperature = float(
            payload.pop("temperature", nested.get("temperature", 0.7))
        )
        top_k = payload.pop("top_k", nested.get("top_k", 60))
        top_p = payload.pop("top_p", nested.get("top_p", 0.85))
        max_new_tokens = int(
            payload.pop("max_new_tokens", nested.get("max_new_tokens", 200))
        )

        if temperature < 0:
            raise ValueError("temperature must be non-negative")

        # Format the prompt with the template visible in this handler.
        prompt_text = f"Domanda: {inputs}\nRisposta: "
        prompt_ids = [self.bos_id] + self.tokenizer.encode(prompt_text).ids
        idx = torch.tensor(
            prompt_ids, dtype=torch.long, device=self.device
        ).unsqueeze(0)

        generated_ids = []
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Only the most recent CONTEXT_WINDOW tokens are fed back in.
                logits = self.model(idx[:, -self.CONTEXT_WINDOW:])[:, -1, :]

                if temperature == 0:
                    # Zero temperature means greedy decoding; dividing by it
                    # would produce inf/nan logits.
                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    filtered = self._filter_logits(
                        logits / temperature, top_k, top_p
                    )
                    probs = F.softmax(filtered, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)

                # Single host sync per step.
                token_id = idx_next.item()
                # Stop as soon as the end-of-sequence token is sampled.
                if token_id == self.eos_id:
                    break

                idx = torch.cat((idx, idx_next), dim=1)
                generated_ids.append(token_id)

        return self.tokenizer.decode(generated_ids)

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Return *logits* with tokens outside the top-k / top-p sets at -inf.

        ``top_k`` is clamped to the vocabulary size and skipped when it is
        ``None`` or not positive; ``top_p`` is skipped when it is ``None``.
        """
        neg_inf = float("-inf")

        # Top-K: keep only the k highest-scoring tokens.
        if top_k is not None and int(top_k) > 0:
            k = min(int(top_k), logits.size(-1))
            kth_value = torch.topk(logits, k).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_value, neg_inf)

        # Top-P: keep the smallest prefix of the sorted distribution whose
        # cumulative probability exceeds top_p (the best token always stays).
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            # Map the mask from sorted order back to vocabulary order, row by
            # row, so the filter is also correct for batch sizes above one.
            remove = torch.zeros_like(sorted_remove).scatter(
                -1, sorted_indices, sorted_remove
            )
            logits = logits.masked_fill(remove, neg_inf)

        return logits