kashif (HF Staff) committed
Commit 93f1f4c · verified · 1 Parent(s): 094fb45

Update custom_generate/generate.py

Files changed (1)
  1. custom_generate/generate.py +27 -2
custom_generate/generate.py CHANGED

@@ -1,6 +1,7 @@
 from collections import deque
 from typing import Any, Optional, Union
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 
@@ -40,6 +41,12 @@ def generate(
         depending on `return_dict_in_generate` and model type.
     """
 
+    # Ensure processors/criteria are defined
+    if logits_processor is None:
+        logits_processor = LogitsProcessorList()
+    if stopping_criteria is None:
+        stopping_criteria = StoppingCriteriaList()
+
     # Get DeepCONF parameters from generation_config or set defaults
     enable_conf = getattr(generation_config, "enable_conf", False)
     window_size = getattr(generation_config, "window_size", 2048)
@@ -74,9 +81,29 @@ def generate(
     output_logits = generation_config.output_logits
     return_dict_in_generate = generation_config.return_dict_in_generate
     output_confidences = getattr(generation_config, "output_confidences", False)
+    # Optional DeepConf variant helpers (compute threshold from warmup confidences)
+    deepconf_variant = getattr(generation_config, "deepconf_variant", None) # "low" or "high"
+    deepconf_eta = getattr(generation_config, "deepconf_eta", None) # float in (0,1)
+    deepconf_warmup_confidences = getattr(generation_config, "deepconf_warmup_confidences", None) # list/1D tensor
     has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
     do_sample = generation_config.do_sample
 
+    # If a variant is requested and a warmup set of confidences is provided, derive the threshold
+    if enable_conf and threshold is not None:
+        pass
+    elif enable_conf and deepconf_variant is not None and deepconf_warmup_confidences is not None:
+        confs = deepconf_warmup_confidences
+        if hasattr(confs, "detach"):
+            confs = confs.detach().cpu().numpy()
+        elif isinstance(confs, torch.Tensor):
+            confs = confs.cpu().numpy()
+        confs = np.asarray(confs, dtype=np.float32).ravel()
+        eta = deepconf_eta
+        if eta is None:
+            eta = 0.1 if deepconf_variant == "low" else 0.9 if deepconf_variant == "high" else 0.5
+        pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
+        threshold = float(np.percentile(confs, pct))
+
     # Initialize attention / hidden states / scores tuples
     scores = () if (return_dict_in_generate and output_scores) else None
     raw_logits = () if (return_dict_in_generate and output_logits) else None
@@ -99,8 +126,6 @@
     conf_group_lists = [deque(maxlen=window_size) for _ in range(batch_size)]
     conf_grouped_sums = [0.0 for _ in range(batch_size)] # Running sums for efficient mean calculation
 
-    # Initialize via prepare_inputs_for_generation
-
     # Optional per-step confidences for debugging/visualization
     step_confidences = [] if (return_dict_in_generate and output_confidences) else None
 
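
For context, a minimal usage sketch of the options this commit adds. The attribute names (enable_conf, window_size, output_confidences, deepconf_variant, deepconf_eta, deepconf_warmup_confidences) mirror the getattr() calls in the diff; the model-free snippet below only shows how such a config could be built and reproduces the threshold the new block would derive from it. The warmup confidence values are illustrative assumptions, not values produced by this repository.

import numpy as np
from transformers import GenerationConfig

# Confidences collected from a few warmup traces (illustrative values only).
warmup_confidences = [1.32, 0.97, 1.18, 0.85, 1.41, 1.05, 0.92, 1.27]

generation_config = GenerationConfig(
    do_sample=True,
    max_new_tokens=256,
    return_dict_in_generate=True,
    # Extra attributes read via getattr() in custom_generate/generate.py:
    enable_conf=True,
    window_size=2048,
    output_confidences=True,
    deepconf_variant="low",                          # "low" or "high"
    deepconf_eta=None,                               # None -> defaults to 0.1 ("low") or 0.9 ("high")
    deepconf_warmup_confidences=warmup_confidences,  # list or 1-D tensor
)

# Reproduce the threshold derivation from the diff for this config:
# variant "low" -> eta = 0.1 -> pct = 100 - 10 = 90 -> 90th percentile of the warmup confidences.
eta = 0.1
pct = max(0.0, min(100.0, 100.0 - eta * 100.0))
print(float(np.percentile(np.asarray(warmup_confidences, dtype=np.float32), pct)))

With deepconf_variant="high" the default eta becomes 0.9 and the percentile drops to 10. When a generation_config like this is handed to the repository's generate() (for example via transformers' custom_generate loading with trust_remote_code=True), the same derivation runs internally whenever no explicit threshold is supplied.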