from typing import Optional

import torch
import torch.nn.functional as F


def load_model(checkpoint_path, model):
    # Load weights saved under the "model" key and put the model in eval mode.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model


def generate_text(
    model,
    data_processor,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    device: str = "cpu",
):
    model.eval()
    tokens = data_processor.tokenize(prompt)
    # Shape (1, seq_len): add a batch dimension for the model.
    input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop input_ids if it exceeds the context size.
            if input_ids.size(1) > model.config.max_token_len:
                input_ids = input_ids[:, -model.config.max_token_len :]
            logits = model(input_ids)
            # Take the logits for the last position and apply temperature scaling.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Keep only the top_k most likely tokens; mask everything else out.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            # Sample the next token from the distribution and append it to the sequence.
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat((input_ids, next_token), dim=1)
    output_tokens = input_ids[0].tolist()
    generated_text = data_processor.detokenize(output_tokens)
    return generated_text