import torch
import torch.nn.functional as F
import tiktoken

from model import GPT


def generate_text(model, prompt, num_return_sequences=4, max_length=32, device='cuda'):
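    """Generate num_return_sequences completions of prompt with top-k (k=50) sampling."""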
    model.eval()

    # Encode the prompt and replicate it once per requested sample: (B, T).
    enc = tiktoken.get_encoding('gpt2')
    tokens = enc.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long)
    tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
    xgen = tokens.to(device)

    # Dedicated RNG so sampling is reproducible and independent of global seeds.
    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(42)
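
    # Autoregressively append one sampled token per step until max_length.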
    while xgen.size(1) < max_length:
        with torch.no_grad():
            logits, _ = model(xgen)            # model returns (logits, loss); loss is unused here
            logits = logits[:, -1, :]          # keep only the last position: (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            # Top-k sampling: keep the 50 most likely tokens and draw one per sequence
            # (multinomial renormalizes the truncated distribution internally).
            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)
            xcol = torch.gather(topk_indices, -1, ix)  # map top-k index back to vocab id
            xgen = torch.cat((xgen, xcol), dim=1)
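
    # Decode and print each completed sequence.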
    generated_texts = []
    for i in range(num_return_sequences):
        tokens = xgen[i, :max_length].tolist()
        decoded = enc.decode(tokens)
        generated_texts.append(decoded)
        print(f"Sample {i + 1}: {decoded}")

    return generated_texts
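

# Script entry: pick a device, load the trained checkpoint, and sample from it.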
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"running on {device}")

checkpoint_path = 'log/model_final.pt'
print(f"Loading checkpoint from {checkpoint_path}")
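# Note: PyTorch >= 2.6 defaults torch.load to weights_only=True; if the pickled
# config below fails to load, pass weights_only=False for this trusted checkpoint.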
checkpoint = torch.load(checkpoint_path, map_location=device)

# Rebuild the model from the saved config; vocab_size is padded to 50304
# (divisible by 128) to match the embedding size used during training.
model_config = checkpoint['config']
model_config.vocab_size = 50304
model = GPT(model_config)
model.load_state_dict(checkpoint['model'])
model.to(device)
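
# Sample a few completions from a fixed prompt.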
prompt = "Hello, I'm a language model,"

generated_texts = generate_text(
    model=model,
    prompt=prompt,
    num_return_sequences=4,
    max_length=32,
    device=device,
)