import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, PreTrainedTokenizerBase

class LLMTrainer:
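    """Minimal helper around a causal language model and its tokenizer.

    Tracks basic training state (`train_loader`, `current_step`) and
    provides top-k sampling through `generate_text`.
    """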
    def __init__(self,
                 model: torch.nn.Module | None = None,
                 tokenizer: PreTrainedTokenizerBase | None = None,
                 model_returns_logits: bool = False):
        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_type)

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")

        self.tokenizer = tokenizer
        self.model = model

        self.train_loader = None
        self.current_step: int = 0

        self.model_returns_logits = model_returns_logits

    def generate_text(self, prompt: str = "Once upon a time", n_return_sequences: int = 4, length: int = 32) -> list[str]:
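        """Sample continuations of `prompt` with top-k (k=10) sampling.

        Each of the `n_return_sequences` sequences is grown one token at a
        time until it reaches `length` tokens, prompt included; the decoded
        strings are printed and returned.
        """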
        self.model.to(self.device)
        self.model.eval()

        # encode(..., return_tensors="pt") already yields a long tensor of shape (1, T)
        tokens = self.tokenizer.encode(prompt, return_tensors="pt")
        tokens = tokens.repeat(n_return_sequences, 1)  # (B, T): one row per sample

        generated_tokens = tokens.to(self.device)
        with torch.no_grad():
            while generated_tokens.size(1) < length:
                # Run the forward pass under bfloat16 autocast; some models return
                # raw logits, others a ModelOutput whose logits live on `.logits`.
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    if self.model_returns_logits:
                        logits = self.model(generated_tokens)
                    else:
                        logits = self.model(generated_tokens).logits

                logits = logits[:, -1, :]  # Get last token logits (B, vocab_size)
                probs = F.softmax(logits, dim=-1)  # Convert to probabilities

                # Top-k sampling: keep the 10 most likely tokens, then draw one
                # proportionally to its probability mass.
                topk_probs, topk_indices = torch.topk(probs, k=10, dim=-1)  # (B, 10)
                sampled_indices = torch.multinomial(topk_probs, 1)  # (B, 1) index into top-k
                next_tokens = torch.gather(topk_indices, -1, sampled_indices)  # (B, 1) token ids

                generated_tokens = torch.cat((generated_tokens, next_tokens), dim=1)

        # Decode each sampled sequence and collect the text.
        continuations = []
        for i in range(n_return_sequences):
            sample_tokens = generated_tokens[i, :length].tolist()
            decoded = self.tokenizer.decode(sample_tokens)
            print(f"=== sample {i} ===\n{decoded}")
            continuations.append(decoded)
        return continuations
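
# Example usage sketch, assuming a Hugging Face causal LM such as
# GPT2LMHeadModel, whose forward pass returns a ModelOutput carrying
# `.logits` (so the default model_returns_logits=False applies).
if __name__ == "__main__":
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    trainer = LLMTrainer(model=model)  # tokenizer falls back to the "gpt2" AutoTokenizer
    samples = trainer.generate_text(prompt="Once upon a time", n_return_sequences=2, length=32)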