import torch


def generate_text(model, tokenizer, prompt, max_length=50, device='cuda'):
    """Greedy decoding: repeatedly append the highest-probability next token."""
    model.eval()
    # Encode the prompt and move it to the same device as the model.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            # Hugging Face causal LMs return an output object; the scores live
            # on `outputs.logits`. Take the logits for the last position only.
            next_token_logits = outputs.logits[:, -1, :]
            # Greedy choice: the single most likely token id, shaped (1, 1).
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
            # Append the new token and feed the extended sequence back in.
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            # Stop early once the model emits the end-of-sequence token.
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
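A minimal usage sketch follows; it assumes the function is paired with a Hugging Face `transformers` causal language model and its tokenizer, with "gpt2" used purely as an illustrative checkpoint.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed example checkpoint: any causal LM works; "gpt2" is just an illustration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

print(generate_text(model, tokenizer, "The quick brown fox", max_length=30))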