# C2C_demo/rosetta/model/sampling.py
import torch
import torch.nn.functional as F
from typing import Union


def sample_token(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = -1,
) -> Union[int, torch.Tensor]:
"""Sample a token from logits using temperature, top-p, and top-k sampling.
Args:
logits: Token logits of shape [vocab_size] or [batch_size, vocab_size]
temperature: Temperature for sampling (>0). Higher values produce more random samples.
top_p: Top-p probability threshold for nucleus sampling (0 < top_p ≤ 1)
top_k: Top-k threshold for sampling (if -1, no top-k filtering is applied)
Returns:
Sampled token ID (int for single sample, tensor for batch)
"""
    if not isinstance(logits, torch.Tensor):
        raise TypeError("logits must be a torch.Tensor")
    if logits.dim() not in [1, 2]:
        raise ValueError("logits must have shape [vocab_size] or [batch_size, vocab_size]")

    # Handle single-dimension input by adding a batch dimension
    is_single_input = logits.dim() == 1
    if is_single_input:
        logits = logits.unsqueeze(0)

    # For greedy sampling (temperature at or near zero), just return the argmax
    if temperature <= 1e-5:
        tokens = torch.argmax(logits, dim=-1)
        return tokens.item() if is_single_input else tokens
    # Convert temperature-scaled logits to probabilities
    probs = F.softmax(logits / temperature, dim=-1)
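    # Example (illustrative): at temperature 0.5, logits [2.0, 1.0] are scaled to
    # [4.0, 2.0] before the softmax, sharpening the distribution relative to
    # temperature 1.0; temperatures above 1.0 flatten it instead.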
    # Apply top-k filtering first (if specified)
    if top_k > 0:
        # Get the indices of the k largest probabilities
        _, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]), dim=-1)
        # Create a mask that keeps only the top-k positions
        mask = torch.zeros_like(probs, dtype=torch.bool)
        mask.scatter_(-1, top_k_indices, True)
        # Zero out non-top-k probabilities
        probs = probs * mask.float()
        # Renormalize probabilities
        probs = probs / probs.sum(dim=-1, keepdim=True)
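        # Example (illustrative): with probs [0.5, 0.3, 0.15, 0.05] and top_k=2,
        # only the two largest entries survive and renormalize to [0.625, 0.375]
        # before sampling.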
    # Apply top-p (nucleus) sampling
    if top_p < 1.0:
        # Sort probabilities in descending order
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
        # Calculate cumulative probabilities
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        # Keep the smallest set of tokens whose cumulative probability reaches top_p.
        # Comparing the cumulative mass *before* each token against top_p keeps the
        # token that crosses the threshold, and always keeps at least the top token
        # (its preceding mass is 0).
        mask = (cumulative_probs - sorted_probs) < top_p
        # Zero out masked positions to exclude them from sampling
        sorted_probs = sorted_probs * mask.float()
        # Renormalize probabilities
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        # Sample from the filtered distribution
        sampled_indices = torch.multinomial(sorted_probs, num_samples=1)
        # Map back to original vocabulary indices
        tokens = torch.gather(sorted_indices, dim=-1, index=sampled_indices)
        tokens = tokens.squeeze(-1)  # Remove sample dimension
    else:
        # Direct sampling if no top-p filtering
        tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

    return tokens.item() if is_single_input else tokens
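

# Minimal usage sketch: draws one token from a single toy distribution and from a
# small batch. The vocabulary size and logit values are illustrative placeholders,
# chosen only to demonstrate the single-input (int) and batched (tensor) return types.
if __name__ == "__main__":
    # Single distribution: returns a plain int token id
    single_logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])
    token_id = sample_token(single_logits, temperature=0.8, top_p=0.9, top_k=3)
    print(f"sampled token id: {token_id}")

    # Batched distributions: returns a tensor of token ids, one per row
    batch_logits = torch.randn(2, 5)
    token_ids = sample_token(batch_logits, temperature=1.0, top_p=0.95)
    print(f"sampled batch token ids: {token_ids.tolist()}")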