---
license: apache-2.0
library_name: transformers
tags:
- custom_generate
- sampling
---


# DeepCONF Custom Generation Strategy

This repository implements the DeepCONF (Deep Confidence-based Early Stopping) generation strategy for Hugging Face Transformers models, following the approach of the paper [Deep Think with Confidence](https://huggingface.co/papers/2508.15260) ([project page](https://jiaweizzhao.github.io/deepconf/)).

## Overview

DeepCONF monitors the confidence of generated tokens and stops generation when the mean confidence over a sliding window falls below a threshold. The confidence is calculated as the negative mean log probability of the top-k tokens from the full vocabulary (before sampling/filtering is applied), so higher values correspond to a more peaked, more confident distribution. This follows the methodology of the [official DeepConf implementation](https://github.com/facebookresearch/deepconf).

## Parameters

- `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`.
- `enable_early_stopping` (bool): Whether to apply early stopping during generation (online mode) or just track confidences for post-processing (batch mode). Defaults to `True`.
- `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`.
- `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`.
- `conf_topk` (int): Number of top tokens to use for confidence calculation from the full vocabulary. Defaults to `20`.
- `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside generated sequences for debugging/visualization.
- `deepconf_variant` (str): Optional variant for automatic threshold calibration (`"low"` or `"high"`). Requires `deepconf_warmup_confidences`.
- `deepconf_warmup_confidences` (list/tensor): Warmup confidence values for threshold calibration. Used with `deepconf_variant`.
- `deepconf_eta` (float): Optional override for eta value in threshold calculation (defaults: 0.1 for low, 0.9 for high).

## Usage

### Basic Usage

To use this custom generation strategy, you can pass it directly to the `generate` method:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "your-model",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("your-model")

# Prepare your prompt
question = "What is the square root of 144?"
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Configure generation with DeepCONF
gen_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=512,
    enable_conf=True,              # Enable DeepCONF
    window_size=2048,              # Sliding window size
    threshold=17.0,                # Confidence threshold
    conf_topk=20,                  # Top-k for confidence (default: 20)
    output_confidences=True,       # Return confidence scores
    return_dict_in_generate=True,  # Required for confidence output
)

# Generate with DeepCONF (Hub repo)
outputs = model.generate(
    **inputs,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",  # Hugging Face Hub repo
    trust_remote_code=True
)

# Access results
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
print(f"Generated: {generated_text}")

# Access per-step confidences if requested
if hasattr(outputs, 'confidences'):
    confidences = outputs.confidences  # Shape: (batch_size, num_generated_tokens)
    print(f"Min confidence: {confidences.min().item():.3f}")
    print(f"Mean confidence: {confidences.mean().item():.3f}")
```

### Calibration (DeepConf-low/high)

DeepConf's online stopping threshold can be automatically derived from a warmup phase. This allows you to calibrate the threshold based on actual model behavior rather than using a fixed value.

**Step 1: Warmup Phase** - Generate multiple sequences and collect their minimum confidences:

```python
from transformers import GenerationConfig

# Prepare inputs
question = "What is 2 + 2?"
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Configure warmup generation
warmup_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=256,
    enable_conf=True,               # Enable confidence tracking
    return_dict_in_generate=True,
    output_confidences=True,
    num_return_sequences=8,         # Generate 8 warmup sequences
    enable_early_stopping=False,    # Warmup must run without early stopping (it defaults to True)
)

# Generate warmup sequences
warmup_out = model.generate(
    **inputs,
    generation_config=warmup_cfg,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

# Extract minimum confidence per sequence (C_t = min over all steps)
warmup_C = warmup_out.confidences.min(dim=1).values.tolist()
print(f"Warmup min confidences: {warmup_C}")
```

**Step 2: Production Generation** - Use warmup confidences to auto-derive threshold:

```python
# Configure production generation with calibrated threshold
gen_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=512,
    enable_conf=True,
    return_dict_in_generate=True,
    output_confidences=True,

    # Automatic threshold calibration
    deepconf_variant="low",  # "low" (aggressive, 90th percentile) or "high" (permissive, 10th percentile)
    deepconf_warmup_confidences=warmup_C,  # Pass warmup confidences
    # Optional: deepconf_eta=0.1,  # Override eta (defaults: 0.1 for low, 0.9 for high)
)

# Generate with calibrated threshold
outputs = model.generate(
    **inputs,
    generation_config=gen_cfg,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

print(f"Generated: {tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)}")
```

**Variant Explanation:**
- **DeepConf-low** (eta=0.1): Uses 90th percentile threshold → More aggressive early stopping
- **DeepConf-high** (eta=0.9): Uses 10th percentile threshold → More permissive, allows longer generation
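
The mapping from `eta` to a percentile can be sketched in plain Python. This is a minimal illustration, not the repository's `compute_warmup_threshold()`: the function name `percentile_threshold` is hypothetical, and the linear interpolation between neighbouring values is an assumption (the actual implementation may round or interpolate differently).

```python
def percentile_threshold(warmup_confidences, eta):
    """Return the (1 - eta) quantile of the warmup confidences.

    Hypothetical helper: eta=0.1 gives the 90th percentile (DeepConf-low),
    eta=0.9 gives the 10th percentile (DeepConf-high).
    """
    values = sorted(warmup_confidences)
    if not values:
        raise ValueError("need at least one warmup confidence")
    # Fractional index of the quantile on a 0..n-1 scale.
    pos = (1.0 - eta) * (len(values) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(values) - 1)
    frac = pos - lo
    # Assumption: linear interpolation between the two nearest values.
    return values[lo] + frac * (values[hi] - values[lo])


warmup = [float(v) for v in range(10, 21)]  # made-up confidences 10.0 .. 20.0
low = percentile_threshold(warmup, eta=0.1)
high = percentile_threshold(warmup, eta=0.9)
print(f"low threshold: {low:.2f}, high threshold: {high:.2f}")
```

Because generation stops when confidence drops below the threshold, the higher threshold produced by `eta=0.1` stops more traces than the lower one produced by `eta=0.9`.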

### Two Modes of Operation

DeepConf supports two modes that match different use cases:

#### Mode 1: Online Early Stopping (Default)

This is the default behavior where early stopping happens **during** generation:

```python
# Online mode: Stop immediately when confidence drops
gen_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=True,  # Default: True (online stopping)
    threshold=17.0,
    window_size=2048,
    max_new_tokens=512,
)

outputs = model.generate(
    **inputs,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)
```

**Use cases:**
- Interactive generation where you want immediate results
- Real-time applications
- Single-sequence generation
- Lower memory usage (no need to store full sequences)

#### Mode 2: Batch Generation + Post-Processing

Generate multiple sequences without early stopping, then analyze them afterward:

```python
import torch

# Phase 1: Generate multiple sequences WITHOUT early stopping
gen_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=False,  # Disable online stopping
    output_confidences=True,
    return_dict_in_generate=True,
    max_new_tokens=64,
)

# Expand inputs for batch generation (e.g., 8 sequences)
num_sequences = 8
expanded_input_ids = inputs.input_ids.repeat(num_sequences, 1)
if 'attention_mask' in inputs and inputs.attention_mask is not None:
    expanded_attention_mask = inputs.attention_mask.repeat(num_sequences, 1)
else:
    expanded_attention_mask = None

# Generate batch
outputs = model.generate(
    input_ids=expanded_input_ids,
    attention_mask=expanded_attention_mask,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

# Phase 2: Post-process to analyze confidence patterns
from custom_generate.utils import process_batch_results

results = process_batch_results(
    outputs,
    tokenizer,
    window_size=2048,
    threshold=17.0
)

# Analyze results
print(f"Generated {results['num_traces']} sequences")
print(f"Min confidences: {results['min_confs']}")

for i, trace in enumerate(results['traces']):
    print(f"\nSequence {i+1}:")
    print(f"  Text: {trace['text'][:100]}...")
    print(f"  Min confidence: {trace['min_conf']:.3f}")
    print(f"  Would stop early: {trace['stopped_early']}")
    if trace['stopped_early']:
        print(f"  Stop position: {trace['stop_position']}")
```

**Use cases:**
- Research and experimentation (try different thresholds without regenerating)
- Batch serving (generate multiple candidates at once)
- Analysis and voting (like the official implementation)
- Calibration and threshold tuning
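
For the voting use case, one possible post-processing step is a confidence-weighted majority vote over the extracted answers. The sketch below is an assumption about how such a vote could look, not code from this repository: `weighted_vote` is a hypothetical name, and both weighting each answer by its trace confidence and the optional threshold filter are illustrative choices.

```python
from collections import defaultdict


def weighted_vote(traces, threshold=None):
    """Pick the answer with the largest summed confidence.

    traces: iterable of (answer, confidence) pairs; answer may be None
    when nothing could be extracted from a trace.
    threshold: if given, traces with confidence below it are ignored.
    """
    scores = defaultdict(float)
    for answer, conf in traces:
        if answer is None:
            continue  # nothing to vote for
        if threshold is not None and conf < threshold:
            continue  # drop low-confidence traces
        scores[answer] += conf
    if not scores:
        return None
    return max(scores, key=scores.get)


# Made-up (answer, min confidence) pairs for illustration.
example_traces = [("12", 18.0), ("12", 17.5), ("13", 19.0), ("14", 9.0)]
print(weighted_vote(example_traces))
print(weighted_vote(example_traces, threshold=18.5))
```

In practice the pairs could be built from each trace's extracted answer and its `min_conf` value in the `results['traces']` list shown above.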

**Utility Functions:**

The `custom_generate/utils.py` module provides helper functions:

- `process_batch_results()`: Analyze batch outputs to detect early stopping positions
- `analyze_early_stopping()`: Calculate statistics on early stopping behavior
- `compute_warmup_threshold()`: Derive threshold from warmup confidences
- `extract_answer()`: Parse LaTeX `\boxed{answer}` patterns
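
To illustrate what parsing a `\boxed{answer}` pattern involves, here is a minimal standalone sketch. It is not the repository's `extract_answer()`: the name `extract_boxed` is hypothetical, and returning the last boxed expression in the text is an assumption. Brace counting is used so that nested LaTeX such as `\frac{1}{2}` is returned whole.

```python
def extract_boxed(text):
    """Return the contents of the last \\boxed{...} in text, or None."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    i = start + len(marker)
    depth = 1  # we are inside the opening brace of \boxed{
    out = []
    while i < len(text):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(out)  # matching close brace found
        out.append(ch)
        i += 1
    return None  # braces never balanced


print(extract_boxed(r"The answer is \boxed{\frac{1}{2}}."))
```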

#### Complete Workflow Example (Like Official DeepConf)

This demonstrates the full workflow matching the official implementation:

```python
# Step 1: Warmup phase - generate multiple sequences
warmup_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    max_new_tokens=64,
    enable_conf=True,
    enable_early_stopping=False,  # No stopping during warmup
    output_confidences=True,
    return_dict_in_generate=True,
)

# Expand for 8 warmup sequences
num_warmup = 8
expanded_ids = inputs.input_ids.repeat(num_warmup, 1)
expanded_mask = inputs.attention_mask.repeat(num_warmup, 1) if 'attention_mask' in inputs else None

warmup_outputs = model.generate(
    input_ids=expanded_ids,
    attention_mask=expanded_mask,
    generation_config=warmup_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

# Process warmup to get min confidences
from custom_generate.utils import process_batch_results, compute_warmup_threshold

warmup_results = process_batch_results(warmup_outputs, tokenizer, window_size=10)
print(f"Warmup min confidences: {warmup_results['min_confs']}")

# Step 2: Compute threshold from warmup
threshold = compute_warmup_threshold(
    warmup_results['min_confs'],
    variant="low"  # or "high"
)
print(f"Calibrated threshold: {threshold:.3f}")

# Step 3: Final generation with calibrated threshold
final_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=True,  # Online stopping with calibrated threshold
    threshold=threshold,
    window_size=10,
    max_new_tokens=128,
    return_dict_in_generate=True,  # Needed to read final_output.sequences below
)

final_output = model.generate(
    **inputs,
    generation_config=final_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)
print(tokenizer.decode(final_output.sequences[0], skip_special_tokens=True))
```

## Technical Details

### Confidence Calculation

The confidence score for each generated token is calculated as follows:

1. **Extract top-k tokens**: Get the top-k (default: 20) tokens with highest probabilities from the full vocabulary
2. **Compute log probabilities**: Calculate log probabilities for these top-k tokens
3. **Average**: The confidence score is `-mean(log_probs)` of the top-k tokens

This approach:
- Uses the **full probability distribution** (before any top-k/top-p/temperature filtering)
- Always considers a **fixed number of tokens** (conf_topk=20)
- Naturally **includes the sampled token** if it's in the top-k
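
The three steps can be sketched for a single decoding step in plain Python. This is a minimal illustration that takes raw logits as a list of floats; the function name `token_confidence` is hypothetical, and the actual implementation presumably operates on batched tensors instead.

```python
import math


def token_confidence(logits, k=20):
    """Negative mean log-probability of the k most likely tokens."""
    # Log-softmax over the full vocabulary (max subtracted for stability).
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    # Keep the k largest log-probabilities (fewer if the vocabulary is smaller).
    top = sorted(log_probs, reverse=True)[:k]
    return -sum(top) / len(top)


# Toy 4-token vocabulary: a peaked distribution scores higher than a flat one.
peaked = token_confidence([10.0, 0.0, 0.0, 0.0], k=4)
flat = token_confidence([0.0, 0.0, 0.0, 0.0], k=4)
print(f"peaked: {peaked:.3f}, flat: {flat:.3f}")
```

When one token dominates, the remaining top-k tokens have very low log-probabilities, which pushes the negative mean up. That is why a larger score means higher confidence here.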

### Online Stopping

The online method uses a sliding window of confidence scores:
- Maintains a window of the last `window_size` (default: 2048) confidence scores
- Calculates the mean confidence over this window
- Stops generation when: `mean_confidence < threshold`
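
The stopping rule can be sketched with a bounded deque and a running sum. The class name `SlidingWindowStopper` is hypothetical, and one detail is an assumption: this sketch checks the mean as soon as a single score is available, whereas the actual implementation may wait until the window is full.

```python
from collections import deque


class SlidingWindowStopper:
    """Tracks the mean of the last `window_size` confidence scores."""

    def __init__(self, window_size=2048, threshold=17.0):
        self.window = deque(maxlen=window_size)
        self.threshold = threshold
        self.total = 0.0

    def update(self, conf):
        """Add one per-token confidence; return True if generation should stop."""
        if len(self.window) == self.window.maxlen:
            # The oldest score is about to be evicted by append().
            self.total -= self.window[0]
        self.window.append(conf)
        self.total += conf
        mean_confidence = self.total / len(self.window)
        # Assumption: no warmup period before the check applies.
        return mean_confidence < self.threshold


stopper = SlidingWindowStopper(window_size=3, threshold=5.0)
decisions = [stopper.update(c) for c in [10.0, 10.0, 10.0, 1.0, 1.0, 1.0]]
print(decisions)
```

A single low score does not trigger a stop on its own; the window mean has to fall below the threshold, which smooths over isolated uncertain tokens.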

## Requirements

- PyTorch >= 1.13.0
- Transformers >= 4.35.0