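"""Sample completions from a trained GPT checkpoint.

Loads a checkpoint (assumed to contain 'model' and 'config' entries), rebuilds
the model, and prints a few top-k sampled continuations of a fixed prompt.
"""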
import torch
import torch.nn.functional as F
import tiktoken
from model import GPT
def generate_text(model, prompt, num_return_sequences=4, max_length=32, device='cuda'):
    """Generate num_return_sequences continuations of prompt using top-k sampling."""
    model.eval()
    enc = tiktoken.get_encoding('gpt2')
    tokens = enc.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long)
    # Replicate the prompt so all sequences are sampled in a single batch.
    tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
    xgen = tokens.to(device)
    # Dedicated generator so sampling is reproducible regardless of other RNG use.
    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(42)
    while xgen.size(1) < max_length:
        with torch.no_grad():
            logits, loss = model(xgen)  # (B, T, vocab_size)
            # Only the last position is needed to predict the next token.
            logits = logits[:, -1, :]  # (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            # Top-k sampling: keep the 50 most likely tokens, sample among them,
            # then map the sampled position back to a vocabulary id.
            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # (B, 1)
            xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
            xgen = torch.cat((xgen, xcol), dim=1)
    generated_texts = []
    for i in range(num_return_sequences):
        tokens = xgen[i, :max_length].tolist()
        decoded = enc.decode(tokens)
        generated_texts.append(decoded)
        print(f"Sample {i + 1}: {decoded}")
    return generated_texts
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Running on {device}")

checkpoint_path = 'log/model_final.pt'
print(f"Loading checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)
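# The loaded dict is assumed to contain at least 'model' (a state dict) and
# 'config' entries, i.e. the training script saved something like
#     torch.save({'model': model.state_dict(), 'config': model.config}, checkpoint_path)
# (hypothetical save call; only these two keys are relied on here). If the config
# is a pickled custom class, PyTorch >= 2.6 may also require passing
# weights_only=False to torch.load above.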
# Rebuild the model from the saved config and load the trained weights.
model_config = checkpoint['config']
model_config.vocab_size = 50304  # vocab padded to a multiple of 128 at training time
model = GPT(model_config)
model.load_state_dict(checkpoint['model'])
model.to(device)

prompt = "Hello, I'm a language model,"
generated_texts = generate_text(
    model=model,
    prompt=prompt,
    num_return_sequences=4,
    max_length=32,
    device=device,
)