gpt2 / src /engine /generate.py
triton329's picture
Upload folder using huggingface_hub
c21e887 verified
Raw
History Blame Contribute Delete
3.24 kB
import torch
import torch.nn.functional as F
from src.data.tokenizer import BaseTokenizer
def generate_greedy(
    model,
    tokenizer: BaseTokenizer,
    device,
    prompt,
    num_tokens=20,
    max_context=1023,
):
    """Greedy generation - always append the highest-scoring next token.

    Args:
        model: Callable taking a ``(1, seq_len)`` tensor of token ids and
            returning logits; ``logits[0, -1, :]`` is read as the scores for
            the next token. The model is switched to eval mode and left
            in that mode.
        tokenizer: Provides ``encode(prompt)`` (a sequence of token ids) and
            ``decode(ids)``.
        device: Device on which the input tensor is created.
        prompt: Text to continue.
        num_tokens: Number of tokens to append.
        max_context: Number of most recent token ids fed to the model at each
            step. The default (1023) is the window that was previously
            hard-coded.

    Returns:
        The result of ``tokenizer.decode`` on the prompt ids followed by the
        generated ids.

    Raises:
        ValueError: If ``max_context`` is below 1, or if the prompt encodes
            to no tokens while ``num_tokens`` is positive.
    """
    # ids[-0:] would select the *whole* list, so a zero window must be rejected.
    if max_context < 1:
        raise ValueError(f"max_context must be >= 1, got {max_context}")

    model.eval()
    # Copy so the sequence returned by the tokenizer is never mutated.
    ids = list(tokenizer.encode(prompt))
    if num_tokens > 0 and not ids:
        # An empty input would otherwise fail obscurely inside the model.
        raise ValueError("prompt must encode to at least one token")

    with torch.no_grad():
        for _ in range(num_tokens):
            window = torch.tensor(
                [ids[-max_context:]], dtype=torch.long, device=device
            )
            logits = model(window)
            # Only the last position predicts the token after the sequence.
            ids.append(logits[0, -1, :].argmax().item())
    return tokenizer.decode(ids)
def generate_temperature(
    model,
    tokenizer: BaseTokenizer,
    device,
    prompt,
    num_tokens=20,
    temperature=0.7,
    max_context=1023,
):
    """Temperature sampling - sample from the temperature-scaled distribution.

    Args:
        model: Callable taking a ``(1, seq_len)`` tensor of token ids and
            returning logits; ``logits[0, -1, :]`` is read as the scores for
            the next token. The model is switched to eval mode and left
            in that mode.
        tokenizer: Provides ``encode(prompt)`` (a sequence of token ids) and
            ``decode(ids)``.
        device: Device on which the input tensor is created.
        prompt: Text to continue.
        num_tokens: Number of tokens to append.
        temperature: Positive divisor applied to the logits before softmax;
            values below 1 sharpen the distribution, values above 1 flatten it.
        max_context: Number of most recent token ids fed to the model at each
            step. The default (1023) is the window that was previously
            hard-coded.

    Returns:
        The result of ``tokenizer.decode`` on the prompt ids followed by the
        sampled ids.

    Raises:
        ValueError: If ``temperature`` is not positive, ``max_context`` is
            below 1, or the prompt encodes to no tokens while ``num_tokens``
            is positive.
    """
    # Dividing by zero yields NaN probabilities that torch.multinomial rejects
    # with an unhelpful message, so fail early with a clear one.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    # ids[-0:] would select the *whole* list, so a zero window must be rejected.
    if max_context < 1:
        raise ValueError(f"max_context must be >= 1, got {max_context}")

    model.eval()
    # Copy so the sequence returned by the tokenizer is never mutated.
    ids = list(tokenizer.encode(prompt))
    if num_tokens > 0 and not ids:
        raise ValueError("prompt must encode to at least one token")

    with torch.no_grad():
        for _ in range(num_tokens):
            window = torch.tensor(
                [ids[-max_context:]], dtype=torch.long, device=device
            )
            # Scores for the token that follows the current sequence.
            next_logits = model(window)[0, -1, :] / temperature
            probs = torch.softmax(next_logits, dim=-1)
            ids.append(torch.multinomial(probs, 1).item())
    return tokenizer.decode(ids)
def generate_top_k(
    model,
    tokenizer: BaseTokenizer,
    device,
    prompt,
    num_tokens=20,
    k=50,
    temperature=0.7,
    max_context=1023,
):
    """Top-k sampling - sample from the k most likely tokens.

    Args:
        model: Callable taking a ``(1, seq_len)`` tensor of token ids and
            returning logits; ``logits[0, -1, :]`` is read as the scores for
            the next token. The model is switched to eval mode and left
            in that mode.
        tokenizer: Provides ``encode(prompt)`` (a sequence of token ids) and
            ``decode(ids)``.
        device: Device on which the input tensor is created.
        prompt: Text to continue.
        num_tokens: Number of tokens to append.
        k: Number of highest-scoring candidates to sample from. Values larger
            than the vocabulary size are clamped to it.
        temperature: Positive divisor applied to the logits before softmax.
        max_context: Number of most recent token ids fed to the model at each
            step. The default (1023) is the window that was previously
            hard-coded.

    Returns:
        The result of ``tokenizer.decode`` on the prompt ids followed by the
        sampled ids.

    Raises:
        ValueError: If ``k`` is below 1, ``temperature`` is not positive,
            ``max_context`` is below 1, or the prompt encodes to no tokens
            while ``num_tokens`` is positive.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # Dividing by zero yields NaN probabilities that torch.multinomial rejects.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    # ids[-0:] would select the *whole* list, so a zero window must be rejected.
    if max_context < 1:
        raise ValueError(f"max_context must be >= 1, got {max_context}")

    model.eval()
    # Copy so the sequence returned by the tokenizer is never mutated.
    ids = list(tokenizer.encode(prompt))
    if num_tokens > 0 and not ids:
        raise ValueError("prompt must encode to at least one token")

    with torch.no_grad():
        for _ in range(num_tokens):
            window = torch.tensor(
                [ids[-max_context:]], dtype=torch.long, device=device
            )
            next_logits = model(window)[0, -1, :] / temperature
            # torch.topk raises when k exceeds the dimension size, so clamp
            # to the vocabulary size instead of crashing on small models.
            num_candidates = min(k, next_logits.size(-1))
            topk_logits, topk_indices = torch.topk(next_logits, num_candidates)
            # Softmax over the survivors only renormalizes them to sum to 1.
            topk_probs = torch.softmax(topk_logits, dim=-1)
            choice = torch.multinomial(topk_probs, 1).item()
            # Map the position within the top-k slice back to a vocabulary id.
            ids.append(topk_indices[choice].item())
    return tokenizer.decode(ids)
def generate_top_p(
    model,
    tokenizer: BaseTokenizer,
    device,
    prompt,
    num_tokens=20,
    p=0.9,
    temperature=0.7,
    max_context=1023,
):
    """Top-p sampling - sample from the smallest set of most likely tokens
    whose cumulative probability exceeds p.

    Args:
        model: Callable taking a ``(1, seq_len)`` tensor of token ids and
            returning logits; ``logits[0, -1, :]`` is read as the scores for
            the next token. The model is switched to eval mode and left
            in that mode.
        tokenizer: Provides ``encode(prompt)`` (a sequence of token ids) and
            ``decode(ids)``.
        device: Device on which the input tensor is created.
        prompt: Text to continue.
        num_tokens: Number of tokens to append.
        p: Cumulative-probability threshold. The most likely token is always
            kept, so ``p <= 0`` reduces to picking the top token each step.
        temperature: Positive divisor applied to the logits before softmax.
        max_context: Number of most recent token ids fed to the model at each
            step. The default (1023) is the window that was previously
            hard-coded.

    Returns:
        The result of ``tokenizer.decode`` on the prompt ids followed by the
        sampled ids.

    Raises:
        ValueError: If ``temperature`` is not positive, ``max_context`` is
            below 1, or the prompt encodes to no tokens while ``num_tokens``
            is positive.
    """
    # Dividing by zero yields NaN probabilities that torch.multinomial rejects.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    # ids[-0:] would select the *whole* list, so a zero window must be rejected.
    if max_context < 1:
        raise ValueError(f"max_context must be >= 1, got {max_context}")

    model.eval()
    # Copy so the sequence returned by the tokenizer is never mutated.
    ids = list(tokenizer.encode(prompt))
    if num_tokens > 0 and not ids:
        raise ValueError("prompt must encode to at least one token")

    with torch.no_grad():
        for _ in range(num_tokens):
            window = torch.tensor(
                [ids[-max_context:]], dtype=torch.long, device=device
            )
            next_logits = model(window)[0, -1, :] / temperature
            probs = torch.softmax(next_logits, dim=-1)
            sorted_probs, sorted_indices = torch.sort(
                probs, dim=-1, descending=True
            )
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            # Drop a token only when the mass *before* it already exceeds p.
            # That keeps the token that crosses the threshold and guarantees
            # the most likely token (position 0) always survives.
            drop = torch.zeros_like(cumulative, dtype=torch.bool)
            drop[1:] = cumulative[:-1] > p
            nucleus = sorted_probs.masked_fill(drop, 0.0)
            nucleus = nucleus / nucleus.sum()
            choice = torch.multinomial(nucleus, 1).item()
            # Map the position in sorted order back to a vocabulary id.
            ids.append(sorted_indices[choice].item())
    return tokenizer.decode(ids)