---
library_name: transformers
tags:
- custom_generate
- sampling
- kvcache
---

# Sampling with KV Cache

## Description
A clean, hackable implementation of sampling (also called ancestral or multinomial sampling) with full KV cache support. It is a simplified alternative to the complex generation mixin in `transformers`, designed for readability and ease of modification while maintaining full performance.

The implementation supports both sampling and greedy decoding, with optional temperature scaling and top-k/top-p filtering.

## Base model
- [HuggingFaceTB/SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct)

## Model compatibility
Most transformer LLM/VLM models trained for causal language modeling.

## Relevant Arguments
- `temperature` (float): Sampling temperature (default: 1.0; higher values give more random output)
- `top_k` (int): Only consider the top-k most probable tokens (default: None)
- `top_p` (float): Only consider the smallest set of most probable tokens whose cumulative probability reaches top_p (default: None)
- `do_sample` (bool): Whether to use sampling (True, default) or greedy decoding (False)

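To make the semantics of these arguments concrete, here is a minimal PyTorch sketch of how such arguments are commonly applied to next-token logits. The `filter_logits` helper is hypothetical, for illustration only; it is not the repository's actual code.

```py
import torch

def filter_logits(logits, temperature=1.0, top_k=None, top_p=None):
    """Illustrative sketch of temperature / top-k / top-p filtering."""
    # Temperature scaling: values < 1 sharpen the distribution, > 1 flatten it.
    logits = logits / temperature

    if top_k is not None:
        # Mask every token whose logit is below the k-th largest one.
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))

    if top_p is not None:
        # Nucleus filtering: keep the smallest set of most probable tokens
        # whose cumulative probability reaches top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_remove = cum_probs > top_p
        # Shift right so the first token that crosses the threshold is kept.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))

    return logits

# Sampling (do_sample=True) vs. greedy decoding (do_sample=False):
# probs = torch.softmax(filter_logits(logits, 0.8, 50, 0.9), dim=-1)
# next_token = torch.multinomial(probs, num_samples=1)   # sampling
# next_token = probs.argmax(dim=-1, keepdim=True)        # greedy

```
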
### Logits Processing Order
Logits processors are applied in sequence: `temperature → top_k → top_p → softmax` (matching the order of HuggingFace's `LogitsProcessor` system). Temperature scaling occurs before top-p filtering, so it affects the probability distribution that top-p operates on.

For example, with `temperature=1.0`, `top_p=0.9` might include tokens A, B, and C. With `temperature=0.5`, the probability mass is much more concentrated, so `top_p=0.9` might include only token A.

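This interaction can be checked numerically with a toy distribution (illustrative values only, not from the repository):

```py
import torch

logits = torch.tensor([2.0, 1.5, 1.0, -1.0])  # toy next-token logits

for temperature in (1.0, 0.5):
    probs = torch.softmax(logits / temperature, dim=-1)
    cum = probs.sort(descending=True).values.cumsum(dim=-1)
    # Tokens kept by top_p=0.9: everything up to and including the token
    # that pushes cumulative probability past 0.9.
    nucleus_size = int((cum < 0.9).sum()) + 1
    print(f"T={temperature}: nucleus size = {nucleus_size}")

# Prints nucleus size 3 at T=1.0 but only 2 at T=0.5: lower temperature
# concentrates probability mass, so top-p keeps fewer tokens.
```
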
## Outputs
When `return_dict_in_generate=True`, returns a dictionary with:
- `sequences`: Generated token IDs
- `scores`: Log probabilities of the sampled tokens (with temperature/sampling modifications applied)
- `logprobs`: Original model log probabilities (T=1, no modifications)

Otherwise, returns a tensor of generated token IDs.

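A toy computation illustrating the difference between the two fields (the shapes here are hypothetical; the actual tensor layout may differ):

```py
import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])   # toy next-token logits
token = torch.tensor([[1]])                # suppose token 1 was sampled
temperature = 0.7

# "scores": log probability under the modified (temperature-scaled) distribution
score = torch.log_softmax(logits / temperature, dim=-1).gather(-1, token)
# "logprobs": log probability under the unmodified model distribution (T=1)
logprob = torch.log_softmax(logits, dim=-1).gather(-1, token)
```
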
## Example usage

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")

inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)

# Basic sampling
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", trust_remote_code=True)

# With temperature
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", temperature=0.8, trust_remote_code=True)

# With top-k
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", top_k=50, trust_remote_code=True)

# With top-p (nucleus sampling)
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", top_p=0.9, trust_remote_code=True)

# Greedy decoding (no sampling)
gen_out = model.generate(**inputs, custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers", do_sample=False, trust_remote_code=True)

# Get detailed output with probabilities
gen_out = model.generate(
    **inputs,
    custom_generate="manueldeprada/sampling_with_kvcache_hf_helpers",
    return_dict_in_generate=True,
    trust_remote_code=True,
)
print(f"Generated text: {tokenizer.batch_decode(gen_out['sequences'], skip_special_tokens=True)}")
print(f"Sampling scores: {gen_out['scores']}")
print(f"Model log probabilities: {gen_out['logprobs']}")
```

## Algorithm
1. Initialize the KV cache and prepare the input sequences
2. For each generation step:
   - Run the model on the newest tokens, reusing the KV cache, to get next-token logits
   - Apply temperature scaling to the logits
   - Optionally apply top-k filtering (keep only the top-k tokens)
   - Optionally apply top-p filtering (nucleus sampling)
   - Convert the logits to probabilities with softmax
   - Sample from the probability distribution (or take the argmax for greedy decoding)
   - Append the selected token to each sequence
   - Update the KV cache and track sequence completion
3. Return the generated sequences and probability information
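
As a rough, self-contained sketch of this loop using the standard `transformers` forward API (a simplified approximation, not the repository's exact code):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")
input_ids = tokenizer(["The quick brown"], return_tensors="pt").input_ids.to(model.device)

past_key_values = None
with torch.no_grad():
    for _ in range(32):
        if past_key_values is None:
            # Prefill: run the full prompt once to populate the KV cache.
            out = model(input_ids, use_cache=True)
        else:
            # Decode: feed only the newest token, reusing the cached keys/values.
            out = model(input_ids[:, -1:], past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        logits = out.logits[:, -1, :] / 0.8                   # temperature = 0.8
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # ancestral sampling
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:       # sequence completion
            break

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
```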