fix generation
Browse files
- README.md (+25, -1)
- custom_generate/generate.py (+10, -20)
README.md
CHANGED
|
@@ -13,13 +13,14 @@ This repository implements the DeepCONF (Deep Confidence-based Early Stopping) g
|
|
| 13 |
|
| 14 |
## Overview
|
| 15 |
|
| 16 |
-
DeepCONF monitors the confidence of generated tokens and stops generation when confidence falls below a threshold.
|
| 17 |
|
| 18 |
## Parameters
|
| 19 |
|
| 20 |
- `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`.
|
| 21 |
- `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`.
|
| 22 |
- `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`.
|
|
|
|
| 23 |
- `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside generated sequences for debugging/visualization.
|
| 24 |
|
| 25 |
## Usage
|
|
@@ -108,6 +109,29 @@ out = model.generate(
|
|
| 108 |
)
|
| 109 |
```
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
## Requirements
|
| 112 |
|
| 113 |
- PyTorch >= 1.13.0
|
|
|
|
| 13 |
|
| 14 |
## Overview
|
| 15 |
|
| 16 |
+
DeepCONF monitors the confidence of generated tokens and stops generation when confidence falls below a threshold. The confidence is calculated as the negative mean log probability of the top-k tokens from the full vocabulary (before sampling/filtering is applied), following the methodology from the [official DeepConf implementation](https://github.com/facebookresearch/deepconf).
|
| 17 |
|
| 18 |
## Parameters
|
| 19 |
|
| 20 |
- `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`.
|
| 21 |
- `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`.
|
| 22 |
- `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`.
|
| 23 |
+
- `conf_topk` (int): Number of top tokens to use for confidence calculation from the full vocabulary. Defaults to `20` (matches official implementation).
|
| 24 |
- `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside generated sequences for debugging/visualization.
|
| 25 |
|
| 26 |
## Usage
|
|
|
|
| 109 |
)
|
| 110 |
```
|
| 111 |
|
| 112 |
+
## Technical Details
|
| 113 |
+
|
| 114 |
+
### Confidence Calculation
|
| 115 |
+
|
| 116 |
+
The confidence score for each generated token is calculated as follows:
|
| 117 |
+
|
| 118 |
+
1. **Extract top-k tokens**: Get the top-k (default: 20) tokens with highest probabilities from the full vocabulary
|
| 119 |
+
2. **Compute log probabilities**: Calculate log probabilities for these top-k tokens
|
| 120 |
+
3. **Average**: The confidence score is `-mean(log_probs)` of the top-k tokens
|
| 121 |
+
|
| 122 |
+
This approach:
|
| 123 |
+
- Uses the **full probability distribution** (before any top-k/top-p/temperature filtering)
|
| 124 |
+
- Always considers a **fixed number of tokens** (conf_topk=20)
|
| 125 |
+
- Naturally **includes the sampled token** if it's in the top-k
|
| 126 |
+
- Matches the **official DeepConf implementation** exactly
|
| 127 |
+
|
| 128 |
+
### Online Stopping
|
| 129 |
+
|
| 130 |
+
The online method uses a sliding window of confidence scores:
|
| 131 |
+
- Maintains a window of the last `window_size` (default: 2048) confidence scores
|
| 132 |
+
- Calculates the mean confidence over this window
|
| 133 |
+
- Stops generation when: `mean_confidence < threshold`
|
| 134 |
+
|
| 135 |
## Requirements
|
| 136 |
|
| 137 |
- PyTorch >= 1.13.0
|
custom_generate/generate.py
CHANGED
|
@@ -51,6 +51,7 @@ def generate(
|
|
| 51 |
enable_conf = getattr(generation_config, "enable_conf", False)
|
| 52 |
window_size = getattr(generation_config, "window_size", 2048)
|
| 53 |
threshold = getattr(generation_config, "threshold", 17.0) # Default threshold for confidence (positive value)
|
|
|
|
| 54 |
|
| 55 |
# If DeepCONF is not enabled, fall back to standard sampling
|
| 56 |
if not enable_conf:
|
|
@@ -197,11 +198,10 @@ def generate(
|
|
| 197 |
else:
|
| 198 |
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
| 199 |
|
| 200 |
- # Calculate confidence using
- #
- #
-
-
candidate_mask = torch.isfinite(next_token_scores)
|
| 205 |
|
| 206 |
deepconf_stopping = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
|
| 207 |
step_conf_values = [0.0] * batch_size # collect per-sequence confidences for this step (full batch)
|
|
@@ -210,21 +210,11 @@ def generate(
|
|
| 210 |
if not unfinished_sequences[i]:
|
| 211 |
continue
|
| 212 |
|
| 213 |
-
#
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
# Sum logprobs over valid candidates and exclude the sampled token's logprob
|
| 219 |
-
total_lp = torch.sum(logprobs[i][candidate_mask[i]])
|
| 220 |
-
selected_lp = (
|
| 221 |
-
logprobs[i, next_tokens[i]]
|
| 222 |
-
if candidate_mask[i, next_tokens[i]]
|
| 223 |
-
else torch.tensor(0.0, device=logprobs.device)
|
| 224 |
-
)
|
| 225 |
-
denom = num_candidates - 1
|
| 226 |
-
# Negative mean of non-selected candidate logprobs
|
| 227 |
-
conf = -((total_lp - selected_lp) / denom).item()
|
| 228 |
|
| 229 |
# Update tracking structures
|
| 230 |
if len(conf_group_lists[i]) >= window_size:
|
|
|
|
| 51 |
enable_conf = getattr(generation_config, "enable_conf", False)
|
| 52 |
window_size = getattr(generation_config, "window_size", 2048)
|
| 53 |
threshold = getattr(generation_config, "threshold", 17.0) # Default threshold for confidence (positive value)
|
| 54 |
+
conf_topk = getattr(generation_config, "conf_topk", 20) # Number of top tokens for confidence calculation
|
| 55 |
|
| 56 |
# If DeepCONF is not enabled, fall back to standard sampling
|
| 57 |
if not enable_conf:
|
|
|
|
| 198 |
else:
|
| 199 |
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
| 200 |
|
| 201 |
+
# Calculate confidence using top-k tokens from the full probability distribution
|
| 202 |
+
# (before any filtering), following the official DeepConf implementation.
|
| 203 |
+
# This uses the raw logits (next_token_logits) before warpers are applied.
|
| 204 |
+
probs = F.softmax(next_token_logits, dim=-1)
|
|
|
|
| 205 |
|
| 206 |
deepconf_stopping = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
|
| 207 |
step_conf_values = [0.0] * batch_size # collect per-sequence confidences for this step (full batch)
|
|
|
|
| 210 |
if not unfinished_sequences[i]:
|
| 211 |
continue
|
| 212 |
|
| 213 |
+
# Get top-k tokens from full probability distribution
|
| 214 |
+
top_probs, _ = torch.topk(probs[i], k=conf_topk, dim=-1)
|
| 215 |
+
log_probs = torch.log(top_probs)
|
| 216 |
+
# Confidence is negative mean of log probabilities of top-k tokens
|
| 217 |
+
conf = -log_probs.mean().item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
# Update tracking structures
|
| 220 |
if len(conf_group_lists[i]) >= window_size:
|