File size: 1,968 Bytes
ccfb646
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498886e
 
 
 
 
ccfb646
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498886e
ccfb646
 
 
 
 
498886e
ccfb646
 
498886e
ccfb646
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn.functional as F
import tiktoken
from model import GPT

def generate_text(model, prompt, num_return_sequences=4, max_length=32, device='cuda',
                  top_k=50, seed=42):
    """Sample continuations of *prompt* from *model* using top-k sampling.

    The prompt is encoded with the GPT-2 tiktoken encoding, replicated
    ``num_return_sequences`` times, and extended one token at a time until
    every row holds ``max_length`` tokens. Each decoded sample is printed as
    ``Sample <n>: <text>`` and the list of decoded strings is returned.

    Args:
        model: Callable invoked as ``model(tokens)`` that returns a
            ``(logits, loss)`` pair, with logits shaped
            ``(batch, seq_len, vocab_size)``. It is switched to eval mode.
        prompt: Text to continue.
        num_return_sequences: Number of independent samples to draw.
        max_length: Total length in tokens (prompt included) of each sample.
            A prompt longer than this is truncated in the output.
        device: Torch device used for the token tensor and the sampling RNG.
        top_k: Number of highest-probability tokens to sample from at each
            step. Clamped to the number of usable vocabulary entries.
        seed: Seed for the dedicated sampling generator, so repeated calls
            with the same arguments draw from the same random stream.

    Returns:
        A list of ``num_return_sequences`` decoded strings.

    Raises:
        ValueError: If ``top_k`` is smaller than 1, or if the prompt encodes
            to zero tokens while generation is requested.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    model.eval()
    enc = tiktoken.get_encoding('gpt2')
    prompt_ids = enc.encode(prompt)
    if not prompt_ids and max_length > 0:
        # With no context there is no "last position" to read logits from.
        raise ValueError("prompt must encode to at least one token")

    xgen = torch.tensor(prompt_ids, dtype=torch.long)
    xgen = xgen.unsqueeze(0).repeat(num_return_sequences, 1).to(device)

    # A private generator keeps sampling reproducible without touching the
    # global torch RNG state.
    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(seed)

    # The model's output dimension may be larger than the tokenizer's
    # vocabulary (e.g. padded for efficiency). Ids past the tokenizer's range
    # cannot be decoded, so they are excluded from sampling.
    decodable_vocab = enc.n_vocab

    with torch.no_grad():
        while xgen.size(1) < max_length:
            logits, _ = model(xgen)  # (B, T, vocab_size)
            logits = logits[:, -1, :decodable_vocab]  # (B, usable_vocab)
            probs = F.softmax(logits, dim=-1)
            # torch.topk raises if k exceeds the size of the dimension.
            k = min(top_k, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
            # multinomial returns positions within the top-k slice; gather
            # maps them back to vocabulary ids.
            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)
            xcol = torch.gather(topk_indices, -1, ix)
            xgen = torch.cat((xgen, xcol), dim=1)

    generated_texts = []
    for sample_no, row in enumerate(xgen[:, :max_length].tolist(), start=1):
        decoded = enc.decode(row)
        generated_texts.append(decoded)
        print(f"Sample {sample_no}: {decoded}")

    return generated_texts
    

# Prefer the GPU when one is available; otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"running with {device}")


checkpoint_path = 'log/model_final.pt'

print(f"Loading checkpoint from {checkpoint_path}")
# The checkpoint holds a config object next to the weights, so it needs full
# unpickling. PyTorch 2.6 changed the torch.load default to weights_only=True,
# which rejects such objects; request the full load explicitly.
# NOTE(security): full unpickling can execute arbitrary code. Only load
# checkpoint files that come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
model_config = checkpoint['config']
# Overrides whatever vocabulary size the checkpoint stored.
# NOTE(review): presumably the padded vocabulary the model was trained
# with; confirm against the training script.
model_config.vocab_size = 50304
model = GPT(model_config)


model.load_state_dict(checkpoint['model'])
model.to(device)



prompt = "Hello, I'm a language model,"

generated_texts = generate_text(
    model=model,
    prompt=prompt,
    num_return_sequences=4,
    max_length=32,
    device=device,
)