import torch
import torch.nn.functional as F
from typing import Union


def sample_token(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = -1,
) -> Union[int, torch.Tensor]:
    """Sample a token from logits using temperature, top-k, and top-p (nucleus) sampling.

    Args:
        logits: Token logits of shape [vocab_size] or [batch_size, vocab_size].
        temperature: Sampling temperature (> 0). Higher values produce more random
            samples; values at or below 1e-5 fall back to greedy argmax decoding.
        top_p: Cumulative probability threshold for nucleus sampling (0 < top_p <= 1).
        top_k: Number of highest-probability tokens to keep (<= 0 disables top-k filtering).

    Returns:
        Sampled token ID: an int for a single [vocab_size] input, or a tensor of shape
        [batch_size] for batched input.
    """
    if not isinstance(logits, torch.Tensor):
        raise TypeError("logits must be a torch.Tensor")
    if logits.dim() not in (1, 2):
        raise ValueError("logits must have shape [vocab_size] or [batch_size, vocab_size]")

    # Promote a single [vocab_size] input to a batch of size 1.
    is_single_input = logits.dim() == 1
    if is_single_input:
        logits = logits.unsqueeze(0)

    # Greedy decoding: a (near-)zero temperature selects the argmax token.
    if temperature <= 1e-5:
        tokens = torch.argmax(logits, dim=-1)
        return tokens.item() if is_single_input else tokens

    # Convert temperature-scaled logits to probabilities.
    probs = F.softmax(logits / temperature, dim=-1)

    # Apply top-k filtering first (if requested): keep only the k most likely tokens.
    if top_k > 0:
        _, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]), dim=-1)
        mask = torch.zeros_like(probs, dtype=torch.bool)
        mask.scatter_(-1, top_k_indices, True)
        # Zero out non-top-k probabilities and renormalize.
        probs = probs * mask.float()
        probs = probs / probs.sum(dim=-1, keepdim=True)

    # Apply top-p (nucleus) filtering: keep the smallest set of tokens whose
    # cumulative probability reaches top_p.
    if top_p < 1.0:
        # Sort probabilities in descending order and compute cumulative sums.
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Keep a token while the cumulative probability of the tokens *before* it is
        # still below top_p, so the token that crosses the threshold is included.
        mask = (cumulative_probs - sorted_probs) < top_p
        # Always keep at least the most likely token.
        mask[:, 0] = True

        # Zero out masked positions and renormalize.
        sorted_probs = sorted_probs * mask.float()
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

        # Sample in sorted space, then map back to original vocabulary indices.
        sampled_indices = torch.multinomial(sorted_probs, num_samples=1)
        tokens = torch.gather(sorted_indices, dim=-1, index=sampled_indices).squeeze(-1)
    else:
        # No nucleus filtering: sample directly from the (possibly top-k filtered) distribution.
        tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

    return tokens.item() if is_single_input else tokens
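

# Minimal usage sketch (illustrative only): the vocabulary size, batch size, and
# sampling parameters below are arbitrary assumptions, not values from the function above.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Single sequence: a [vocab_size] logit vector returns a plain int.
    logits = torch.randn(32000)
    token_id = sample_token(logits, temperature=0.8, top_p=0.9, top_k=50)
    print("single:", token_id)

    # Batched input: a [batch_size, vocab_size] tensor returns a [batch_size] tensor.
    batch_logits = torch.randn(4, 32000)
    token_ids = sample_token(batch_logits, temperature=0.8, top_p=0.9, top_k=50)
    print("batched:", token_ids.tolist())

    # Greedy decoding: temperature at (or near) zero picks the argmax token.
    greedy_id = sample_token(logits, temperature=0.0)
    print("greedy:", greedy_id)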