from typing import Optional

import torch
import torch.nn.functional as F


def load_model(checkpoint_path, model):
    """Load weights from a checkpoint into ``model`` and put it in eval mode.

    Args:
        checkpoint_path: Anything accepted by ``torch.load`` (e.g. a file path).
        model: Module whose ``load_state_dict`` receives ``checkpoint["model"]``.

    Returns:
        The same ``model`` instance, with weights loaded and ``eval()`` called.

    Raises:
        KeyError: if the loaded checkpoint mapping has no ``"model"`` entry.
    """
    # map_location="cpu" lets checkpoints saved on a GPU be opened on any host.
    # This function does not move the model to any other device.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (recent PyTorch versions default to weights_only=True, older ones do not).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model


def _sample_next_token(logits, temperature, top_k):
    """Choose one next-token id per row of the 2-D ``logits`` tensor.

    With ``temperature == 0`` the highest-scoring id is picked (greedy).
    Otherwise the logits are divided by ``temperature``, optionally limited to
    the ``top_k`` largest values, and one id is sampled from their softmax.

    Returns:
        A long tensor of shape ``(rows, 1)``.
    """
    if temperature == 0:
        # Greedy decoding; dividing by zero would produce inf/nan probabilities.
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / temperature
    if top_k is not None:
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k).values[:, [-1]]
        # Drop everything strictly below the k-th largest logit; ties at the
        # threshold survive, so slightly more than k ids may remain.
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def generate_text(
    model,
    data_processor,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    device: str = "cpu",
):
    """Autoregressively extend ``prompt`` by ``max_new_tokens`` tokens.

    Args:
        model: Callable returning logits indexed as ``logits[:, -1, :]``
            (presumably ``(batch, seq, vocab)`` -- confirm against the model),
            exposing ``eval()`` and ``config.max_token_len``. It is assumed to
            already live on ``device``; this function does not move it.
        data_processor: Object with ``tokenize(str) -> list[int]`` and
            ``detokenize(list[int]) -> str``.
        prompt: Text to condition on.
        max_new_tokens: Number of tokens to append.
        temperature: Softmax temperature; ``0`` selects greedy decoding.
        top_k: If given, sample only among the ``top_k`` highest logits.
        device: Device on which the token tensor is created.

    Returns:
        The detokenized prompt followed by all generated tokens.

    Raises:
        ValueError: if ``temperature`` is negative or ``top_k`` is not positive.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if top_k is not None and top_k <= 0:
        raise ValueError(f"top_k must be a positive integer, got {top_k}")

    model.eval()
    tokens = data_processor.tokenize(prompt)
    # Full sequence (prompt + generated tokens); never truncated.
    input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            context_len = model.config.max_token_len
            # Feed the model at most its context window, but keep input_ids
            # intact so the returned text includes the whole sequence.
            model_input = (
                input_ids[:, -context_len:]
                if input_ids.size(1) > context_len
                else input_ids
            )
            logits = model(model_input)
            next_token = _sample_next_token(logits[:, -1, :], temperature, top_k)
            input_ids = torch.cat((input_ids, next_token), dim=1)

    return data_processor.detokenize(input_ids[0].tolist())