File size: 3,025 Bytes
b37982e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# -*- coding: utf-8 -*-
"""

Created on Sun Feb  8 11:47:14 2026



@author: ottav

"""

import torch
import torch.nn.functional as F
from model import GPT
from tokenizers import Tokenizer
import os

class EndpointHandler:
    """Custom inference handler that generates an answer with the local GPT model.

    The handler loads ``tokenizer.json`` and ``pytorch_model.bin`` from *path*
    and, on each call, wraps the user text in the prompt template used during
    training before sampling a continuation token by token.
    """

    # Vocabulary size the GPT model is instantiated with (must match the weights).
    VOCAB_SIZE = 14000
    # Maximum number of trailing tokens fed to the model at each step.
    BLOCK_SIZE = 384

    def __init__(self, path=""):
        """Load the tokenizer, special-token ids and model weights from *path*.

        Raises:
            ValueError: if the tokenizer does not define the ``[BOS]`` token,
                which every prompt is prefixed with.
        """
        # Load the tokenizer and the special-token ids.
        self.tokenizer = Tokenizer.from_file(os.path.join(path, "tokenizer.json"))
        self.bos_id = self.tokenizer.token_to_id("[BOS]")
        # May be None if the vocabulary has no [EOS]; generation then simply
        # runs until max_new_tokens.
        self.eos_id = self.tokenizer.token_to_id("[EOS]")
        if self.bos_id is None:
            # token_to_id returns None for unknown tokens; fail early instead of
            # crashing later when building the prompt tensor.
            raise ValueError("tokenizer.json does not define the [BOS] token")

        # Pick the device and build the GPT model.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = GPT(vocab_size=self.VOCAB_SIZE).to(self.device)

        # Load the weights from the .bin file.
        # SECURITY NOTE(review): torch.load unpickles the file, so it can run
        # arbitrary code. Only load trusted checkpoints, or pass
        # weights_only=True when the installed torch version supports it.
        weights_path = os.path.join(path, "pytorch_model.bin")
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    @staticmethod
    def _filter_logits(logits, top_k=None, top_p=None):
        """Return a copy of *logits* with tokens outside top-k / top-p set to -inf.

        Args:
            logits: 2-D tensor of next-token logits, one row per sequence.
            top_k: keep only the *top_k* highest logits per row. It is clamped
                to the vocabulary size; ``None`` or ``<= 0`` disables it.
            top_p: nucleus threshold. The smallest prefix of tokens (by
                probability) whose cumulative mass exceeds *top_p* is kept.
                ``None`` or ``>= 1`` disables it.
        """
        neg_inf = float("-inf")

        if top_k is not None and top_k > 0:
            k = min(int(top_k), logits.size(-1))
            kth_best = torch.topk(logits, k).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_best, neg_inf)

        if top_p is not None and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept, and
            # always keep the most likely token so the distribution is never empty.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            # Map the mask back from sorted order to vocabulary order, row by row.
            remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, neg_inf)

        return logits

    def __call__(self, data):
        """Generate an answer for ``data["inputs"]`` and return it as text.

        Generation options can be given either at the top level of *data* or
        inside ``data["parameters"]``; the latter takes precedence.

        Options:
            temperature (default 0.7): ``<= 0`` selects greedy decoding.
            top_k (default 60).
            top_p (default 0.85).
            max_new_tokens (default 200).

        Note that *data* is mutated: the recognised keys are popped from it.
        """
        inputs = data.pop("inputs", data)
        params = data.pop("parameters", None)
        if not isinstance(params, dict):
            params = {}

        def param(name, default):
            # Always pop the top-level key (as before), but let "parameters" win.
            top_level = data.pop(name, default)
            return params.get(name, top_level)

        temp = float(param("temperature", 0.7))
        top_k = param("top_k", 60)
        top_k = int(top_k) if top_k is not None else None
        top_p = param("top_p", 0.85)
        top_p = float(top_p) if top_p is not None else None
        max_new_tokens = int(param("max_new_tokens", 200))

        # Format the prompt exactly as during training.
        prompt_text = f"Domanda: {inputs}\nRisposta: "
        prompt_ids = [self.bos_id] + self.tokenizer.encode(prompt_text).ids
        idx = torch.tensor([prompt_ids], dtype=torch.long, device=self.device)

        generated_ids = []
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Feed at most the last BLOCK_SIZE tokens (the model's context).
                idx_cond = idx[:, -self.BLOCK_SIZE:]
                logits = self.model(idx_cond)[:, -1, :]

                if temp <= 0:
                    # Greedy decoding; also avoids dividing by zero.
                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = self._filter_logits(logits / temp, top_k, top_p)
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)

                next_id = idx_next.item()
                # Stop as soon as the end-of-sequence token is produced.
                if next_id == self.eos_id:
                    break

                idx = torch.cat((idx, idx_next), dim=1)
                generated_ids.append(next_id)

        return self.tokenizer.decode(generated_ids)