# GPT-124m / scripts/generate.py
# Last commit by abhinavv3: "minor changes" (498886e)
import torch
import torch.nn.functional as F
import tiktoken
from model import GPT
def generate_text(model, prompt, num_return_sequences=4, max_length=32, device='cuda',
                  top_k=50, seed=42):
    """Sample continuations of ``prompt`` from ``model`` using top-k sampling.

    The prompt is encoded with the GPT-2 tiktoken encoding, replicated
    ``num_return_sequences`` times, and extended one token at a time until
    every row holds ``max_length`` tokens (prompt tokens included). Each
    decoded sequence is printed and the list of decoded strings is returned.

    Args:
        model: Callable module invoked as ``model(x)`` on a ``(B, T)`` tensor
            of token ids; it must return a ``(logits, loss)`` pair where
            ``logits`` is ``(B, T, vocab_size)``. It is switched to eval mode.
        prompt: Text to continue.
        num_return_sequences: Number of independent samples to draw.
        max_length: Total length, in tokens, of each returned sequence.
            A prompt already longer than this is truncated on output.
        device: Device on which sampling runs; the model is expected to
            already live there.
        top_k: Number of highest-probability tokens to sample from at each
            step. Clamped to the number of decodable vocabulary entries.
        seed: Seed for the sampling generator, so runs are reproducible.

    Returns:
        A list of ``num_return_sequences`` decoded strings.

    Raises:
        ValueError: If ``top_k`` is smaller than 1, or if ``prompt`` encodes
            to zero tokens while at least one token has to be generated.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k}")

    model.eval()
    enc = tiktoken.get_encoding('gpt2')

    prompt_ids = enc.encode(prompt)
    if not prompt_ids and max_length > 0:
        # With zero context tokens there is no "last position" to read the
        # next-token logits from, so fail early with a clear message.
        raise ValueError("prompt must encode to at least one token")

    xgen = torch.tensor(prompt_ids, dtype=torch.long)
    xgen = xgen.unsqueeze(0).repeat(num_return_sequences, 1).to(device)

    # Dedicated generator: sampling stays reproducible without touching the
    # global RNG state.
    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(seed)

    with torch.no_grad():
        while xgen.size(1) < max_length:
            logits, _ = model(xgen)  # (B, T, vocab_size)
            logits = logits[:, -1, :]  # (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            # The model's vocabulary may be padded beyond what the tokenizer
            # knows; ids past enc.n_vocab cannot be decoded, so never offer
            # them to the sampler.
            probs = probs[:, :enc.n_vocab]
            k = min(top_k, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
            # multinomial renormalises the k weights and returns a position
            # inside the top-k list; gather maps it back to a token id.
            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)
            xcol = torch.gather(topk_indices, -1, ix)
            xgen = torch.cat((xgen, xcol), dim=1)

    generated_texts = []
    for i in range(num_return_sequences):
        decoded = enc.decode(xgen[i, :max_length].tolist())
        generated_texts.append(decoded)
        print(f"Sample {i + 1}: {decoded}")
    return generated_texts
# --- Script body: load the trained checkpoint and print a few samples ---

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"running with {device}")

checkpoint_path = 'log/model_final.pt'
print(f"Loading checkpoint from {checkpoint_path}")
# The checkpoint carries a config object next to the weights (see the
# attribute assignment below), which the restricted "weights only" unpickler
# -- the default since PyTorch 2.6 -- refuses to load. Request full
# unpickling explicitly so the script behaves the same on every version.
# NOTE(security): full unpickling can execute arbitrary code; only load
# checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

model_config = checkpoint['config']
# Force the vocabulary size before building the model.
# NOTE(review): presumably the padded size used at training time -- confirm.
model_config.vocab_size = 50304
model = GPT(model_config)
model.load_state_dict(checkpoint['model'])
model.to(device)

prompt = "Hello, I'm a language model,"
generated_texts = generate_text(
    model=model,
    prompt=prompt,
    num_return_sequences=4,
    max_length=32,
    device=device,
)