kashif committed
Commit cfa4f52 · 1 Parent(s): 9ed69b6

fix generation

Files changed (2):
  1. README.md +25 -1
  2. custom_generate/generate.py +10 -20
README.md CHANGED
@@ -13,13 +13,14 @@ This repository implements the DeepCONF (Deep Confidence-based Early Stopping) g
 
 ## Overview
 
-DeepCONF monitors the confidence of generated tokens and stops generation when confidence falls below a threshold.
+DeepCONF monitors the confidence of generated tokens and stops generation when confidence falls below a threshold. The confidence is calculated as the negative mean log probability of the top-k tokens from the full vocabulary (before sampling/filtering is applied), following the methodology from the [official DeepConf implementation](https://github.com/facebookresearch/deepconf).
 
 ## Parameters
 
 - `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`.
 - `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`.
 - `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`.
+- `conf_topk` (int): Number of top tokens to use for confidence calculation from the full vocabulary. Defaults to `20` (matches official implementation).
 - `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside generated sequences for debugging/visualization.
 
 ## Usage
@@ -108,6 +109,29 @@ out = model.generate(
 )
 ```
 
+## Technical Details
+
+### Confidence Calculation
+
+The confidence score for each generated token is calculated as follows:
+
+1. **Extract top-k tokens**: Get the top-k (default: 20) tokens with highest probabilities from the full vocabulary
+2. **Compute log probabilities**: Calculate log probabilities for these top-k tokens
+3. **Average**: The confidence score is `-mean(log_probs)` of the top-k tokens
+
+This approach:
+- Uses the **full probability distribution** (before any top-k/top-p/temperature filtering)
+- Always considers a **fixed number of tokens** (`conf_topk=20`)
+- Naturally **includes the sampled token** if it's in the top-k
+- Matches the **official DeepConf implementation** exactly
+
+### Online Stopping
+
+The online method uses a sliding window of confidence scores:
+- Maintains a window of the last `window_size` (default: 2048) confidence scores
+- Calculates the mean confidence over this window
+- Stops generation when: `mean_confidence < threshold`
+
 ## Requirements
 
 - PyTorch >= 1.13.0
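The confidence formula this commit introduces (negative mean log probability of the top-`conf_topk` tokens of the unfiltered distribution) can be sketched as a standalone helper. The function name and the toy vocabulary size below are illustrative, not part of the repository:

```python
import torch
import torch.nn.functional as F


def token_confidence(logits: torch.Tensor, conf_topk: int = 20) -> float:
    """Negative mean log probability of the top-k tokens (illustrative helper).

    `logits` is the raw next-token logit vector for one sequence, *before*
    top-k/top-p/temperature warpers are applied. A peaked distribution pushes
    the non-top-1 probabilities toward zero, making their logs very negative,
    so the score is large; a near-uniform distribution yields a small score.
    Generation stops once the windowed mean of these scores drops below the
    configured threshold.
    """
    probs = F.softmax(logits, dim=-1)              # full distribution, no filtering
    top_probs, _ = torch.topk(probs, k=conf_topk)  # k highest-probability tokens
    return -torch.log(top_probs).mean().item()     # -mean(log p) over the top-k
```

As a rough calibration, a near-uniform distribution over a 32k-token vocabulary scores about `log(32000) ≈ 10.4`, well below the default threshold of `17.0`, while a sharply peaked distribution scores much higher.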
custom_generate/generate.py CHANGED
@@ -51,6 +51,7 @@ def generate(
     enable_conf = getattr(generation_config, "enable_conf", False)
     window_size = getattr(generation_config, "window_size", 2048)
     threshold = getattr(generation_config, "threshold", 17.0)  # Default threshold for confidence (positive value)
+    conf_topk = getattr(generation_config, "conf_topk", 20)  # Number of top tokens for confidence calculation
 
     # If DeepCONF is not enabled, fall back to standard sampling
     if not enable_conf:
@@ -197,11 +198,10 @@ def generate(
         else:
             next_tokens = torch.argmax(next_token_scores, dim=-1)
 
-        # Calculate confidence using only top-k/top-p filtered candidates (post-logits processors),
-        # excluding the sampled token.
-        # We consider candidates where logits are finite after warpers (e.g., top-k/top-p/temperature).
-        logprobs = F.log_softmax(next_token_scores, dim=-1)
-        candidate_mask = torch.isfinite(next_token_scores)
+        # Calculate confidence using top-k tokens from the full probability distribution
+        # (before any filtering), following the official DeepConf implementation.
+        # This uses the raw logits (next_token_logits) before warpers are applied.
+        probs = F.softmax(next_token_logits, dim=-1)
 
         deepconf_stopping = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
         step_conf_values = [0.0] * batch_size  # collect per-sequence confidences for this step (full batch)
@@ -210,21 +210,11 @@ def generate(
            if not unfinished_sequences[i]:
                continue
 
-           # Count valid candidates
-           num_candidates = int(candidate_mask[i].sum().item())
-           if num_candidates <= 1:
-               conf = 0.0
-           else:
-               # Sum logprobs over valid candidates and exclude the sampled token's logprob
-               total_lp = torch.sum(logprobs[i][candidate_mask[i]])
-               selected_lp = (
-                   logprobs[i, next_tokens[i]]
-                   if candidate_mask[i, next_tokens[i]]
-                   else torch.tensor(0.0, device=logprobs.device)
-               )
-               denom = num_candidates - 1
-               # Negative mean of non-selected candidate logprobs
-               conf = -((total_lp - selected_lp) / denom).item()
+           # Get top-k tokens from full probability distribution
+           top_probs, _ = torch.topk(probs[i], k=conf_topk, dim=-1)
+           log_probs = torch.log(top_probs)
+           # Confidence is negative mean of log probabilities of top-k tokens
+           conf = -log_probs.mean().item()
 
            # Update tracking structures
            if len(conf_group_lists[i]) >= window_size:
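The sliding-window early-stopping rule these per-token scores feed into can be sketched independently of `generate.py`. The class name below is hypothetical, and whether the repository's loop also checks the mean before the window first fills is not visible in this diff; this sketch checks at every step:

```python
from collections import deque


class SlidingWindowStopper:
    """Hypothetical sketch of DeepCONF's online stopping rule."""

    def __init__(self, window_size: int = 2048, threshold: float = 17.0):
        self.scores = deque(maxlen=window_size)  # oldest score is evicted automatically
        self.threshold = threshold

    def update(self, conf: float) -> bool:
        """Record one per-token confidence; return True if generation should stop."""
        self.scores.append(conf)
        mean_conf = sum(self.scores) / len(self.scores)
        return mean_conf < self.threshold
```

Using `deque(maxlen=window_size)` keeps the window bounded without manual popping; once the deque is full, appending a new score drops the oldest one, so the mean always covers at most the last `window_size` tokens.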