Update custom_generate/generate.py

custom_generate/generate.py (+27 -2)
@@ -1,6 +1,7 @@
 from collections import deque
 from typing import Any, Optional, Union
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 
@@ -40,6 +41,12 @@ def generate(
         depending on `return_dict_in_generate` and model type.
     """
 
+    # Ensure processors/criteria are defined
+    if logits_processor is None:
+        logits_processor = LogitsProcessorList()
+    if stopping_criteria is None:
+        stopping_criteria = StoppingCriteriaList()
+
     # Get DeepCONF parameters from generation_config or set defaults
     enable_conf = getattr(generation_config, "enable_conf", False)
     window_size = getattr(generation_config, "window_size", 2048)
@@ -74,9 +81,29 @@ def generate(
     output_logits = generation_config.output_logits
     return_dict_in_generate = generation_config.return_dict_in_generate
     output_confidences = getattr(generation_config, "output_confidences", False)
+    # Optional DeepConf variant helpers (compute threshold from warmup confidences)
+    deepconf_variant = getattr(generation_config, "deepconf_variant", None)  # "low" or "high"
+    deepconf_eta = getattr(generation_config, "deepconf_eta", None)  # float in (0,1)
+    deepconf_warmup_confidences = getattr(generation_config, "deepconf_warmup_confidences", None)  # list/1D tensor
     has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
     do_sample = generation_config.do_sample
 
+    # If a variant is requested and a warmup set of confidences is provided, derive the threshold
+    if enable_conf and threshold is not None:
+        pass
+    elif enable_conf and deepconf_variant is not None and deepconf_warmup_confidences is not None:
+        confs = deepconf_warmup_confidences
+        if hasattr(confs, "detach"):
+            confs = confs.detach().cpu().numpy()
+        elif isinstance(confs, torch.Tensor):
+            confs = confs.cpu().numpy()
+        confs = np.asarray(confs, dtype=np.float32).ravel()
+        eta = deepconf_eta
+        if eta is None:
+            eta = 0.1 if deepconf_variant == "low" else 0.9 if deepconf_variant == "high" else 0.5
+        pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
+        threshold = float(np.percentile(confs, pct))
+
     # Initialize attention / hidden states / scores tuples
     scores = () if (return_dict_in_generate and output_scores) else None
     raw_logits = () if (return_dict_in_generate and output_logits) else None
@@ -99,8 +126,6 @@
     conf_group_lists = [deque(maxlen=window_size) for _ in range(batch_size)]
     conf_grouped_sums = [0.0 for _ in range(batch_size)]  # Running sums for efficient mean calculation
 
-    # Initialize via prepare_inputs_for_generation
-
     # Optional per-step confidences for debugging/visualization
    step_confidences = [] if (return_dict_in_generate and output_confidences) else None
 
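For reference, a minimal sketch of how a caller might exercise the new fields, assuming this `generate` is loaded as a custom generation function and that the extra attributes are simply attached to a `transformers` `GenerationConfig` (the warmup values are toy numbers, and the `do_sample`/`max_new_tokens` settings are placeholders, not part of the diff):

import numpy as np
from transformers import GenerationConfig

# Toy warmup confidences, e.g. one score per warmup trace (illustrative values only).
warmup_confidences = np.array([0.42, 0.48, 0.53, 0.55, 0.61, 0.70], dtype=np.float32)

gen_config = GenerationConfig(do_sample=True, max_new_tokens=256)
gen_config.enable_conf = True
gen_config.deepconf_variant = "low"                          # with no explicit eta, "low" defaults to eta = 0.1
gen_config.deepconf_warmup_confidences = warmup_confidences  # threshold is derived from these
gen_config.output_confidences = True                         # also collect per-step confidences

# Mirroring the diff's threshold math: eta = 0.1 gives pct = 90.0, so the
# threshold becomes the 90th percentile of the warmup confidences.
pct = 100.0 - 0.1 * 100.0
print(float(np.percentile(warmup_confidences, pct)))  # ~0.655 for the toy values above

Note that an explicitly supplied `threshold` takes precedence: the `if enable_conf and threshold is not None: pass` branch leaves it untouched, and the percentile derivation only runs as a fallback when a variant plus warmup confidences are provided.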