# CharRNN / src / inference.py
# (Hugging Face upload metadata: hoom4n — "Upload 18 files", commit b6447fa, verified)
import torch
from typing import Callable
def sample_next_token(model, context: torch.Tensor, device, temperature: float = 1.0) -> torch.Tensor:
    """Draw one next-token id per batch element from the model's output distribution.

    Args:
        model: Callable producing logits for ``context``. The slice below takes
            the last index of the final dimension, which yields the advertised
            (batch_size, vocab_size) only if the model emits (batch, vocab, seq).
            NOTE(review): for the common (batch, seq, vocab) layout this would
            need to be ``[:, -1, :]`` — confirm against the model definition.
        context: Token ids of shape (batch_size, seq_len).
        device: Device the model runs on; ``context`` is moved there.
        temperature: Softmax temperature; >1 flattens, <1 sharpens the distribution.

    Returns:
        Sampled token ids of shape (batch_size, 1).
    """
    assert context.ndim == 2, "context should be (batch_size, seq_len)"
    model.eval()
    with torch.no_grad():
        last_logits = model(context.to(device))[:, :, -1]  # (batch_size, vocab_size)
        distribution = torch.softmax(last_logits / temperature, dim=-1)
        return torch.multinomial(distribution, num_samples=1)  # (batch_size, 1)
def generate_sequence(
    model,
    prompt: torch.Tensor,
    max_len: int,
    device,
    temperature: float = 1.0,
    include_prompt: bool = False,
) -> torch.Tensor:
    """Autoregressively generate up to ``max_len`` tokens after ``prompt``.

    Args:
        model: Model passed through to ``sample_next_token``.
        prompt: Token ids of shape (batch_size, seq_len). Only batch element 0
            of the result is returned.
        max_len: Number of new tokens to sample.
        device: Device to run generation on.
        temperature: Softmax temperature forwarded to ``sample_next_token``.
        include_prompt: If True, the returned sequence starts with the prompt
            tokens; otherwise only the newly generated tokens are returned.

    Returns:
        1-D tensor of token ids for batch element 0.
    """
    assert prompt.ndim == 2, "prompt should be (batch_size, seq_len)"
    # BUG FIX: the original used len(prompt), which is the size of dim 0
    # (batch_size) on a 2-D tensor, not the prompt length — so the slice
    # stripped the wrong number of tokens. Capture the true sequence length.
    prompt_len = prompt.size(1)
    context = prompt.to(device)
    for _ in range(max_len):
        next_token = sample_next_token(model, context, device, temperature=temperature)
        context = torch.cat([context, next_token], dim=-1)  # append along the time axis
    return context[0, :] if include_prompt else context[0, prompt_len:]
def generate_text(
    model,
    prompt: str,
    text_encoder: Callable[[str], torch.Tensor],
    text_decoder: Callable[[torch.Tensor], str],
    device,
    max_len: int = 128,
    temperature: float = 1.0,
    include_prompt: bool = True,
) -> str:
    """Generate a text continuation of ``prompt``.

    Encodes the prompt with ``text_encoder``, samples up to ``max_len`` tokens
    autoregressively via ``generate_sequence``, and decodes the resulting token
    ids back to a string with ``text_decoder``.
    """
    # Add a leading batch dimension: (1, seq_len).
    batched_prompt = text_encoder(prompt).reshape(1, -1)
    token_ids = generate_sequence(
        model,
        batched_prompt,
        max_len,
        device,
        temperature=temperature,
        include_prompt=include_prompt,
    )
    return text_decoder(token_ids)