MathBite committed on
Commit
86b391d
·
verified ·
1 Parent(s): 231a0c8

updated inference

Browse files
Files changed (1) hide show
  1. modeling.py +16 -5
modeling.py CHANGED
@@ -13,7 +13,7 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
13
  def __init__(self, config):
14
  super().__init__(config)
15
 
16
- self.lookup_length = getattr(config, "lookup_length", 30)
17
  self.num_new_tokens = 2
18
  self.deletion_threshold = config.deletion_threshold if "deletion_threshold" in config else 0.7
19
 
@@ -112,7 +112,6 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
112
  ):
113
  """
114
  Custom generate method to orchestrate self-correction.
115
-
116
  NOTE: This implementation currently only supports a batch size of 1.
117
  """
118
  # Set the model to evaluation mode and cache instruction tokens.
@@ -123,6 +122,10 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
123
  generated_ids = input_ids
124
  attention_mask = torch.ones_like(input_ids)
125
 
 
 
 
 
126
  # The first forward pass processes the prompt and gets the initial KV cache.
127
  outputs = self(
128
  input_ids=input_ids,
@@ -142,11 +145,16 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
142
  # Apply softmax to get hallucination probabilities.
143
  hallucination_probs = F.softmax(hallucination_logits, dim=-1)
144
 
145
- # Conditionally choose the next tokens based on the detector's output.
146
- if hallucination_probs[0, 1] > self.deletion_threshold:
 
 
 
147
  current_tokens = self.rewrite_sentence_ids
148
- elif hallucination_probs[0, 2] > self.deletion_threshold:
 
149
  current_tokens = self.rewrite_response_ids
 
150
  else:
151
  if temperature > 0.0:
152
  scaled_logits = next_token_logits / temperature
@@ -154,6 +162,9 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
154
  current_tokens = torch.multinomial(probs, num_samples=1)
155
  else:
156
  current_tokens = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
 
 
 
157
 
158
  generated_ids = torch.cat([generated_ids, current_tokens], dim=-1)
159
 
 
13
  def __init__(self, config):
14
  super().__init__(config)
15
 
16
+ self.correction_cooldown = getattr(config, "correction_cooldown", 30)
17
  self.num_new_tokens = 2
18
  self.deletion_threshold = config.deletion_threshold if "deletion_threshold" in config else 0.7
19
 
 
112
  ):
113
  """
114
  Custom generate method to orchestrate self-correction.
 
115
  NOTE: This implementation currently only supports a batch size of 1.
116
  """
117
  # Set the model to evaluation mode and cache instruction tokens.
 
122
  generated_ids = input_ids
123
  attention_mask = torch.ones_like(input_ids)
124
 
125
+ # Initialize a counter to track tokens since the last correction.
126
+ # Start it at the cooldown value to allow immediate correction if needed.
127
+ tokens_since_correction = self.correction_cooldown
128
+
129
  # The first forward pass processes the prompt and gets the initial KV cache.
130
  outputs = self(
131
  input_ids=input_ids,
 
145
  # Apply softmax to get hallucination probabilities.
146
  hallucination_probs = F.softmax(hallucination_logits, dim=-1)
147
 
148
+ # Check if the cooldown period has passed.
149
+ can_correct = tokens_since_correction >= self.correction_cooldown
150
+
151
+ # Conditionally choose the next tokens based on the detector's output and the cooldown.
152
+ if can_correct and hallucination_probs[0, 1] > self.deletion_threshold:
153
  current_tokens = self.rewrite_sentence_ids
154
+ tokens_since_correction = 0 # Reset the counter
155
+ elif can_correct and hallucination_probs[0, 2] > self.deletion_threshold:
156
  current_tokens = self.rewrite_response_ids
157
+ tokens_since_correction = 0 # Reset the counter
158
  else:
159
  if temperature > 0.0:
160
  scaled_logits = next_token_logits / temperature
 
162
  current_tokens = torch.multinomial(probs, num_samples=1)
163
  else:
164
  current_tokens = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
165
+
166
+ # Increment the counter by the number of tokens just generated.
167
+ tokens_since_correction += current_tokens.shape[1]
168
 
169
  generated_ids = torch.cat([generated_ids, current_tokens], dim=-1)
170