import torch
from typing import Callable


def sample_next_token(model, context: torch.Tensor, device, temperature: float = 1.0) -> torch.Tensor:
    """Sample the next token from the model given a context sequence.

    Args:
        model: autoregressive LM; assumed to return logits of shape
            (batch_size, seq_len, vocab_size) — TODO confirm against model.
        context: (batch_size, seq_len) tensor of token ids.
        device: device the model lives on; context is moved there.
        temperature: softmax temperature; must be > 0 (division by it).

    Returns:
        (batch_size, 1) tensor of sampled token ids.
    """
    assert context.ndim == 2, "context should be (batch_size, seq_len)"
    model.eval()
    with torch.no_grad():
        context = context.to(device)
        # BUG FIX: the last *time step* lives on dim 1 of the
        # (batch, seq, vocab) output. The original `[:, :, -1]` took the
        # last vocab logit at every position, producing (batch, seq_len)
        # and sampling a position index instead of a token id.
        logits = model(context)[:, -1, :]  # (batch_size, vocab_size)
        scaled_logits = logits / temperature
        probs = torch.softmax(scaled_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # (batch_size, 1)


def generate_sequence(
    model,
    prompt: torch.Tensor,
    max_len: int,
    device,
    temperature: float = 1.0,
    include_prompt: bool = False,
) -> torch.Tensor:
    """Autoregressively generate a sequence of tokens from a prompt.

    Args:
        model: autoregressive LM (see sample_next_token).
        prompt: (batch_size, seq_len) tensor of token ids.
        max_len: number of new tokens to generate.
        device: device to run generation on.
        temperature: softmax temperature forwarded to sampling.
        include_prompt: if True, the returned sequence includes the prompt.

    Returns:
        1-D tensor of token ids for batch element 0 only (existing
        contract: only the first batch row is returned).
    """
    assert prompt.ndim == 2, "prompt should be (batch_size, seq_len)"
    # BUG FIX: len(prompt) on a 2-D tensor is the batch size (dim 0);
    # the prompt *length* is dim 1. The original sliced off the wrong
    # number of tokens when include_prompt=False.
    prompt_len = prompt.shape[1]
    context = prompt.to(device)
    for _ in range(max_len):
        next_token = sample_next_token(model, context, device, temperature=temperature)
        context = torch.cat([context, next_token], dim=-1)
    return context[0, :] if include_prompt else context[0, prompt_len:]


def generate_text(
    model,
    prompt: str,
    text_encoder: Callable[[str], torch.Tensor],
    text_decoder: Callable[[torch.Tensor], str],
    device,
    max_len: int = 128,
    temperature: float = 1.0,
    include_prompt: bool = True,
) -> str:
    """Generate text from a string prompt using the model and encoder/decoder.

    Args:
        model: autoregressive LM (see sample_next_token).
        prompt: input text to condition on.
        text_encoder: maps a string to a 1-D tensor of token ids.
        text_decoder: maps a 1-D tensor of token ids back to a string.
        device: device to run generation on.
        max_len: number of new tokens to generate.
        temperature: softmax temperature forwarded to sampling.
        include_prompt: if True, the decoded output includes the prompt.

    Returns:
        The decoded generated text.
    """
    enc_text = text_encoder(prompt).reshape(1, -1)  # (batch_size=1, seq_len)
    generated = generate_sequence(
        model,
        enc_text,
        max_len,
        device,
        temperature=temperature,
        include_prompt=include_prompt,
    )
    return text_decoder(generated)