# -*- coding: utf-8 -*-
"""
Created on Sun Feb  8 11:47:14 2026

@author: ottav
"""
import torch
import torch.nn.functional as F
from model import GPT
from tokenizers import Tokenizer
import os


class EndpointHandler():
    """Inference endpoint handler for a custom GPT model.

    Loads the tokenizer and trained weights from *path*, then serves
    generation requests via ``__call__`` using temperature / top-k / top-p
    sampling.
    """

    # Context window used when cropping the running sequence; must match the
    # block size the model was trained with.
    MAX_CONTEXT = 384

    def __init__(self, path=""):
        # Load the tokenizer and resolve the special-token IDs used below.
        self.tokenizer = Tokenizer.from_file(os.path.join(path, "tokenizer.json"))
        self.bos_id = self.tokenizer.token_to_id("[BOS]")
        self.eos_id = self.tokenizer.token_to_id("[EOS]")

        # Pick a device and instantiate the model skeleton.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = GPT(vocab_size=14000).to(self.device)

        # Load the trained weights from the .bin checkpoint.
        weights_path = os.path.join(path, "pytorch_model.bin")
        self.model.load_state_dict(torch.load(weights_path, map_location=self.device))
        self.model.eval()

    def __call__(self, data):
        """Generate a reply for ``data["inputs"]``.

        Optional keys in *data*: ``temperature`` (default 0.7), ``top_k``
        (default 60), ``top_p`` (default 0.85), ``max_new_tokens``
        (default 200). Returns the decoded generated text, prompt excluded.
        """
        # Extract the request parameters or fall back to the defaults.
        inputs = data.pop("inputs", data)
        temp = data.pop("temperature", 0.7)
        top_k = data.pop("top_k", 60)
        top_p = data.pop("top_p", 0.85)
        max_new_tokens = data.pop("max_new_tokens", 200)

        # Guard against a zero/negative temperature, which would otherwise
        # divide the logits by zero.
        temp = max(float(temp), 1e-8)

        # Format the prompt exactly as done during training.
        prompt_text = f"Domanda: {inputs}\nRisposta: "
        prompt_ids = [self.bos_id] + self.tokenizer.encode(prompt_text).ids
        idx = torch.tensor(prompt_ids, dtype=torch.long, device=self.device).unsqueeze(0)

        generated_ids = []
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum block size.
            idx_cond = idx[:, -self.MAX_CONTEXT:]
            with torch.no_grad():
                logits = self.model(idx_cond)
            logits = logits[:, -1, :] / temp

            # Top-K: keep only the k most likely tokens. Clamp k so that
            # torch.topk never exceeds the vocabulary size (it raises if
            # k > vocab).
            k = min(int(top_k), logits.size(-1))
            if k > 0:
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float("Inf")

            # Top-P (nucleus): drop the probability tail whose cumulative
            # mass exceeds top_p.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the mask right so the first token crossing the threshold
            # is still kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            # Scatter the mask back to vocabulary order. (Fancy-indexing with
            # sorted_indices directly is only correct for batch size 1; the
            # scatter form is the batch-safe standard.)
            indices_to_remove = sorted_indices_to_remove.scatter(
                1, sorted_indices, sorted_indices_to_remove
            )
            logits[indices_to_remove] = -float("Inf")

            # Sample the next token from the filtered distribution.
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Stop as soon as the end-of-sequence token is produced
            # (the EOS token itself is not appended to the output).
            if idx_next.item() == self.eos_id:
                break
            idx = torch.cat((idx, idx_next), dim=1)
            generated_ids.append(idx_next.item())

        return self.tokenizer.decode(generated_ids)