| import torch |
| from typing import Callable |
|
|
|
|
def sample_next_token(model, context: torch.Tensor, device, temperature: float = 1.0) -> torch.Tensor:
    """Sample the next token id from the model's predictive distribution.

    Args:
        model: Callable returning logits; assumed shape (batch, seq_len,
            vocab_size) — the standard layout. TODO(review): confirm against
            the model definition.
        context: Token ids of shape (batch_size, seq_len).
        device: Device the context is moved to before the forward pass.
        temperature: Softmax temperature; >1 flattens, <1 sharpens the
            distribution.

    Returns:
        Sampled token ids of shape (batch_size, 1).
    """
    assert context.ndim == 2, "context should be (batch_size, seq_len)"
    model.eval()
    with torch.no_grad():
        context = context.to(device)
        # Take the logits at the LAST position: (batch, seq, vocab) -> (batch, vocab).
        # Bug fix: the original indexed [:, :, -1], which selects the last
        # *vocab* entry at every position (shape (batch, seq)), making
        # multinomial sample over sequence positions instead of the vocabulary.
        logits = model(context)[:, -1, :]
        scaled_logits = logits / temperature
        probs = torch.softmax(scaled_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)
|
|
|
|
def generate_sequence(
    model,
    prompt: torch.Tensor,
    max_len: int,
    device,
    temperature: float = 1.0,
    include_prompt: bool = False,
) -> torch.Tensor:
    """Autoregressively generate a sequence of tokens from a prompt.

    Args:
        model: Model passed through to ``sample_next_token``.
        prompt: Token ids of shape (batch_size, seq_len).
        max_len: Number of NEW tokens to generate.
        device: Device the prompt/context is moved to.
        temperature: Sampling temperature forwarded to ``sample_next_token``.
        include_prompt: If True, the returned sequence keeps the prompt tokens.

    Returns:
        1-D tensor of token ids for the FIRST batch element only (with or
        without the prompt prefix, per ``include_prompt``).
    """
    assert prompt.ndim == 2, "prompt should be (batch_size, seq_len)"
    # Bug fix: len(prompt) on a 2-D tensor is the BATCH size (dim 0), not the
    # prompt length, so the original sliced off the wrong number of tokens
    # when include_prompt=False. Use the sequence length (dim 1) instead.
    prompt_len = prompt.size(1)
    context = prompt.to(device)
    for _ in range(max_len):
        next_token = sample_next_token(model, context, device, temperature=temperature)
        context = torch.cat([context, next_token], dim=-1)
    if include_prompt:
        return context[0]
    return context[0, prompt_len:]
|
|
|
|
def generate_text(
    model,
    prompt: str,
    text_encoder: Callable[[str], torch.Tensor],
    text_decoder: Callable[[torch.Tensor], str],
    device,
    max_len: int = 128,
    temperature: float = 1.0,
    include_prompt: bool = True,
) -> str:
    """Generate a text continuation for a string prompt.

    The prompt is encoded into token ids, extended autoregressively via
    ``generate_sequence``, and the resulting ids are decoded back to a string.

    Args:
        model: Model forwarded to ``generate_sequence``.
        prompt: The input text to continue.
        text_encoder: Maps a string to a 1-D tensor of token ids.
        text_decoder: Maps a tensor of token ids back to a string.
        device: Device used for generation.
        max_len: Number of new tokens to generate.
        temperature: Sampling temperature.
        include_prompt: If True, the returned string includes the prompt text.

    Returns:
        The decoded generated text.
    """
    # Add a leading batch dimension: (seq_len,) -> (1, seq_len).
    token_ids = text_encoder(prompt).reshape(1, -1)
    sampled_ids = generate_sequence(
        model,
        token_ids,
        max_len,
        device,
        temperature=temperature,
        include_prompt=include_prompt,
    )
    return text_decoder(sampled_ids)
|
|