---
license: apache-2.0
library_name: transformers
tags:
- custom_generate
- sampling
---
# DeepCONF Custom Generation Strategy
This repository implements the DeepCONF (Deep Confidence-based Early Stopping) generation strategy for Hugging Face Transformers models, following the approach from the paper [Deep Think with Confidence](https://huggingface.co/papers/2508.15260) ([project page](https://jiaweizzhao.github.io/deepconf/)).
## Overview
DeepCONF monitors the confidence of generated tokens and stops generation when confidence falls below a threshold. The confidence is calculated as the negative mean log probability of the top-k tokens from the full vocabulary (before sampling/filtering is applied), following the methodology from the [official DeepConf implementation](https://github.com/facebookresearch/deepconf).
## Parameters
- `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`.
- `enable_early_stopping` (bool): Whether to apply early stopping during generation (online mode) or just track confidences for post-processing (batch mode). Defaults to `True`.
- `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`.
- `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`.
- `conf_topk` (int): Number of top tokens to use for confidence calculation from the full vocabulary. Defaults to `20`.
- `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside generated sequences for debugging/visualization.
- `deepconf_variant` (str): Optional variant for automatic threshold calibration (`"low"` or `"high"`). Requires `deepconf_warmup_confidences`.
- `deepconf_warmup_confidences` (list/tensor): Warmup confidence values for threshold calibration. Used with `deepconf_variant`.
- `deepconf_eta` (float): Optional override for eta value in threshold calculation (defaults: 0.1 for low, 0.9 for high).
## Usage
### Basic Usage
To use this custom generation strategy, you can pass it directly to the `generate` method:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
"your-model",
torch_dtype="auto",
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("your-model")
# Prepare your prompt
question = "What is the square root of 144?"
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Configure generation with DeepCONF
gen_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=512,
    enable_conf=True,              # Enable DeepCONF
    window_size=2048,              # Sliding window size
    threshold=17.0,                # Confidence threshold
    conf_topk=20,                  # Top-k for confidence (default: 20)
    output_confidences=True,       # Return confidence scores
    return_dict_in_generate=True,  # Required for confidence output
)
# Generate with DeepCONF (Hub repo)
outputs = model.generate(
    **inputs,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",  # Hugging Face Hub repo
    trust_remote_code=True
)
# Access results
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
print(f"Generated: {generated_text}")
# Access per-step confidences if requested
if hasattr(outputs, 'confidences'):
    confidences = outputs.confidences  # Shape: (batch_size, num_generated_tokens)
    print(f"Min confidence: {confidences.min().item():.3f}")
    print(f"Mean confidence: {confidences.mean().item():.3f}")
```
### Calibration (DeepConf-low/high)
DeepConf's online stopping threshold can be automatically derived from a warmup phase. This allows you to calibrate the threshold based on actual model behavior rather than using a fixed value.
**Step 1: Warmup Phase** - Generate multiple sequences and collect their minimum confidences:
```python
from transformers import GenerationConfig
# Prepare inputs
question = "What is 2 + 2?"
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Configure warmup generation
warmup_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=256,
    enable_conf=True,             # Enable confidence tracking
    enable_early_stopping=False,  # Warmup must run without early stopping
    return_dict_in_generate=True,
    output_confidences=True,
    num_return_sequences=8,       # Generate 8 warmup sequences
    # Note: do NOT set a threshold here - warmup collects unconstrained confidences
)
# Generate warmup sequences
warmup_out = model.generate(
    **inputs,
    generation_config=warmup_cfg,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)
# Extract minimum confidence per sequence (C_t = min over all steps)
warmup_C = warmup_out.confidences.min(dim=1).values.tolist()
print(f"Warmup min confidences: {warmup_C}")
```
**Step 2: Production Generation** - Use warmup confidences to auto-derive threshold:
```python
# Configure production generation with calibrated threshold
gen_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=512,
    enable_conf=True,
    return_dict_in_generate=True,
    output_confidences=True,
    # Automatic threshold calibration
    deepconf_variant="low",                # "low" (aggressive, 90th percentile) or "high" (permissive, 10th percentile)
    deepconf_warmup_confidences=warmup_C,  # Pass warmup confidences
    # Optional: deepconf_eta=0.1,          # Override eta (defaults: 0.1 for low, 0.9 for high)
)
# Generate with calibrated threshold
outputs = model.generate(
    **inputs,
    generation_config=gen_cfg,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)
print(f"Generated: {tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)}")
```
**Variant Explanation:**
- **DeepConf-low** (eta=0.1): Uses 90th percentile threshold → More aggressive early stopping
- **DeepConf-high** (eta=0.9): Uses 10th percentile threshold → More permissive, allows longer generation
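If you want to sanity-check the calibrated value, the eta-to-percentile mapping above is easy to reproduce. A minimal sketch, assuming plain `numpy` (the helper name `sketch_threshold` is illustrative, not this repository's API; `compute_warmup_threshold()` from `custom_generate/utils.py` is the supported path):
```python
import numpy as np

def sketch_threshold(warmup_min_confs, variant="low"):
    # DeepConf-low:  eta=0.1 -> keep the top 10% of traces -> stop below the 90th percentile
    # DeepConf-high: eta=0.9 -> keep the top 90% of traces -> stop below the 10th percentile
    eta = 0.1 if variant == "low" else 0.9
    return float(np.percentile(warmup_min_confs, 100 * (1 - eta)))

print(sketch_threshold([15.2, 18.9, 16.4, 21.0, 17.7], variant="low"))  # 90th percentile
```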
### Two Modes of Operation
DeepConf supports two modes that match different use cases:
#### Mode 1: Online Early Stopping (Default)
This is the default behavior where early stopping happens **during** generation:
```python
# Online mode: Stop immediately when confidence drops
gen_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=True,  # Default: True (online stopping)
    threshold=17.0,
    window_size=2048,
    max_new_tokens=512,
)
outputs = model.generate(**inputs, generation_config=gen_config, custom_generate="kashif/DeepConf", trust_remote_code=True)
```
**Use cases:**
- Interactive generation where you want immediate results
- Real-time applications
- Single-sequence generation
- Lower memory usage (no need to store full sequences)
#### Mode 2: Batch Generation + Post-Processing
Generate multiple sequences without early stopping, then analyze them afterward:
```python
import torch
# Phase 1: Generate multiple sequences WITHOUT early stopping
gen_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=False,  # Disable online stopping
    output_confidences=True,
    return_dict_in_generate=True,
    max_new_tokens=64,
)
# Expand inputs for batch generation (e.g., 8 sequences)
num_sequences = 8
expanded_input_ids = inputs.input_ids.repeat(num_sequences, 1)
if 'attention_mask' in inputs and inputs.attention_mask is not None:
    expanded_attention_mask = inputs.attention_mask.repeat(num_sequences, 1)
else:
    expanded_attention_mask = None
# Generate batch
outputs = model.generate(
    input_ids=expanded_input_ids,
    attention_mask=expanded_attention_mask,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True
)
# Phase 2: Post-process to analyze confidence patterns
from custom_generate.utils import process_batch_results
results = process_batch_results(
    outputs,
    tokenizer,
    window_size=2048,
    threshold=17.0
)
# Analyze results
print(f"Generated {results['num_traces']} sequences")
print(f"Min confidences: {results['min_confs']}")
for i, trace in enumerate(results['traces']):
    print(f"\nSequence {i+1}:")
    print(f" Text: {trace['text'][:100]}...")
    print(f" Min confidence: {trace['min_conf']:.3f}")
    print(f" Would stop early: {trace['stopped_early']}")
    if trace['stopped_early']:
        print(f" Stop position: {trace['stop_position']}")
```
**Use cases:**
- Research and experimentation (try different thresholds without regenerating)
- Batch serving (generate multiple candidates at once)
- Analysis and voting (like the official implementation)
- Calibration and threshold tuning
**Utility Functions:**
The `custom_generate/utils.py` module provides helper functions:
- `process_batch_results()`: Analyze batch outputs to detect early stopping positions
- `analyze_early_stopping()`: Calculate statistics on early stopping behavior
- `compute_warmup_threshold()`: Derive threshold from warmup confidences
- `extract_answer()`: Parse LaTeX `\boxed{answer}` patterns
#### Complete Workflow Example (Like Official DeepConf)
This demonstrates the full workflow matching the official implementation:
```python
# Step 1: Warmup phase - generate multiple sequences
warmup_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    max_new_tokens=64,
    enable_conf=True,
    enable_early_stopping=False,  # No stopping during warmup
    output_confidences=True,
    return_dict_in_generate=True,
)
# Expand for 8 warmup sequences
num_warmup = 8
expanded_ids = inputs.input_ids.repeat(num_warmup, 1)
expanded_mask = inputs.attention_mask.repeat(num_warmup, 1) if 'attention_mask' in inputs else None
warmup_outputs = model.generate(
    input_ids=expanded_ids,
    attention_mask=expanded_mask,
    generation_config=warmup_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True
)
# Process warmup to get min confidences
from custom_generate.utils import process_batch_results, compute_warmup_threshold
warmup_results = process_batch_results(warmup_outputs, tokenizer, window_size=10)
print(f"Warmup min confidences: {warmup_results['min_confs']}")
# Step 2: Compute threshold from warmup
threshold = compute_warmup_threshold(
    warmup_results['min_confs'],
    variant="low"  # or "high"
)
print(f"Calibrated threshold: {threshold:.3f}")
# Step 3: Final generation with calibrated threshold
final_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=True,    # Online stopping with calibrated threshold
    threshold=threshold,
    window_size=10,
    max_new_tokens=128,
    return_dict_in_generate=True,  # Needed to access .sequences below
)
final_output = model.generate(**inputs, generation_config=final_config, custom_generate="kashif/DeepConf", trust_remote_code=True)
print(tokenizer.decode(final_output.sequences[0], skip_special_tokens=True))
```
## Technical Details
### Confidence Calculation
The confidence score for each generated token is calculated as follows (a runnable sketch appears after the lists below):
1. **Extract top-k tokens**: Get the top-k (default: 20) tokens with highest probabilities from the full vocabulary
2. **Compute log probabilities**: Calculate log probabilities for these top-k tokens
3. **Average**: The confidence score is `-mean(log_probs)` of the top-k tokens
This approach:
- Uses the **full probability distribution** (before any top-k/top-p/temperature filtering)
- Always considers a **fixed number of tokens** (conf_topk=20)
- Naturally **includes the sampled token** if it's in the top-k
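To make this concrete, here is a minimal sketch of the per-step computation from raw logits (variable and function names are illustrative assumptions, not this repository's internals):
```python
import torch
import torch.nn.functional as F

def step_confidence(logits: torch.Tensor, conf_topk: int = 20) -> torch.Tensor:
    """Confidence for one decoding step.

    logits: (batch_size, vocab_size) raw scores, BEFORE any top-k/top-p/temperature filtering.
    Returns a (batch_size,) tensor: -mean(log p) over the conf_topk most probable tokens.
    """
    log_probs = F.log_softmax(logits, dim=-1)              # full-vocabulary log probabilities
    topk_log_probs, _ = log_probs.topk(conf_topk, dim=-1)  # the k largest log probabilities
    return -topk_log_probs.mean(dim=-1)
```
A sharply peaked distribution pushes the top-k tail toward very negative log probabilities, yielding a high confidence score; a flat distribution over the top-k yields a low score, which is what the stopping rule penalizes.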
### Online Stopping
The online method uses a sliding window of confidence scores (illustrated after the list below):
- Maintains a window of the last `window_size` (default: 2048) confidence scores
- Calculates the mean confidence over this window
- Stops generation when: `mean_confidence < threshold`
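As a self-contained illustration of this rule (the toy scores and small window here are stand-ins for the per-step confidences computed during decoding, not the repository's actual loop):
```python
from collections import deque

per_step_confidences = [18.4, 17.9, 17.2, 16.1, 15.8]  # toy scores for illustration
window_size, threshold = 3, 17.0                       # small window for the toy example
window = deque(maxlen=window_size)  # automatically drops scores older than window_size

for step, conf in enumerate(per_step_confidences):
    window.append(conf)
    if sum(window) / len(window) < threshold:
        print(f"stop at step {step}: windowed mean fell below {threshold}")
        break
```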
## Requirements
- PyTorch >= 1.13.0
- Transformers >= 4.35.0