updated inference
Browse files — modeling.py: +16 −5
modeling.py
CHANGED
|
@@ -13,7 +13,7 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
|
|
| 13 |
def __init__(self, config):
|
| 14 |
super().__init__(config)
|
| 15 |
|
| 16 |
-
self.
|
| 17 |
self.num_new_tokens = 2
|
| 18 |
self.deletion_threshold = config.deletion_threshold if "deletion_threshold" in config else 0.7
|
| 19 |
|
|
@@ -112,7 +112,6 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
|
|
| 112 |
):
|
| 113 |
"""
|
| 114 |
Custom generate method to orchestrate self-correction.
|
| 115 |
-
|
| 116 |
NOTE: This implementation currently only supports a batch size of 1.
|
| 117 |
"""
|
| 118 |
# Set the model to evaluation mode and cache instruction tokens.
|
|
@@ -123,6 +122,10 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
|
|
| 123 |
generated_ids = input_ids
|
| 124 |
attention_mask = torch.ones_like(input_ids)
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
# The first forward pass processes the prompt and gets the initial KV cache.
|
| 127 |
outputs = self(
|
| 128 |
input_ids=input_ids,
|
|
@@ -142,11 +145,16 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
|
|
| 142 |
# Apply softmax to get hallucination probabilities.
|
| 143 |
hallucination_probs = F.softmax(hallucination_logits, dim=-1)
|
| 144 |
|
| 145 |
-
#
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
| 147 |
current_tokens = self.rewrite_sentence_ids
|
| 148 |
-
|
|
|
|
| 149 |
current_tokens = self.rewrite_response_ids
|
|
|
|
| 150 |
else:
|
| 151 |
if temperature > 0.0:
|
| 152 |
scaled_logits = next_token_logits / temperature
|
|
@@ -154,6 +162,9 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
|
|
| 154 |
current_tokens = torch.multinomial(probs, num_samples=1)
|
| 155 |
else:
|
| 156 |
current_tokens = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
generated_ids = torch.cat([generated_ids, current_tokens], dim=-1)
|
| 159 |
|
|
|
|
| 13 |
def __init__(self, config):
|
| 14 |
super().__init__(config)
|
| 15 |
|
| 16 |
+
self.correction_cooldown = getattr(config, "correction_cooldown", 30)
|
| 17 |
self.num_new_tokens = 2
|
| 18 |
self.deletion_threshold = config.deletion_threshold if "deletion_threshold" in config else 0.7
|
| 19 |
|
|
|
|
| 112 |
):
|
| 113 |
"""
|
| 114 |
Custom generate method to orchestrate self-correction.
|
|
|
|
| 115 |
NOTE: This implementation currently only supports a batch size of 1.
|
| 116 |
"""
|
| 117 |
# Set the model to evaluation mode and cache instruction tokens.
|
|
|
|
| 122 |
generated_ids = input_ids
|
| 123 |
attention_mask = torch.ones_like(input_ids)
|
| 124 |
|
| 125 |
+
# Initialize a counter to track tokens since the last correction.
|
| 126 |
+
# Start it at the cooldown value to allow immediate correction if needed.
|
| 127 |
+
tokens_since_correction = self.correction_cooldown
|
| 128 |
+
|
| 129 |
# The first forward pass processes the prompt and gets the initial KV cache.
|
| 130 |
outputs = self(
|
| 131 |
input_ids=input_ids,
|
|
|
|
| 145 |
# Apply softmax to get hallucination probabilities.
|
| 146 |
hallucination_probs = F.softmax(hallucination_logits, dim=-1)
|
| 147 |
|
| 148 |
+
# Check if the cooldown period has passed.
|
| 149 |
+
can_correct = tokens_since_correction >= self.correction_cooldown
|
| 150 |
+
|
| 151 |
+
# Conditionally choose the next tokens based on the detector's output and the cooldown.
|
| 152 |
+
if can_correct and hallucination_probs[0, 1] > self.deletion_threshold:
|
| 153 |
current_tokens = self.rewrite_sentence_ids
|
| 154 |
+
tokens_since_correction = 0 # Reset the counter
|
| 155 |
+
elif can_correct and hallucination_probs[0, 2] > self.deletion_threshold:
|
| 156 |
current_tokens = self.rewrite_response_ids
|
| 157 |
+
tokens_since_correction = 0 # Reset the counter
|
| 158 |
else:
|
| 159 |
if temperature > 0.0:
|
| 160 |
scaled_logits = next_token_logits / temperature
|
|
|
|
| 162 |
current_tokens = torch.multinomial(probs, num_samples=1)
|
| 163 |
else:
|
| 164 |
current_tokens = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
|
| 165 |
+
|
| 166 |
+
# Increment the counter by the number of tokens just generated.
|
| 167 |
+
tokens_since_correction += current_tokens.shape[1]
|
| 168 |
|
| 169 |
generated_ids = torch.cat([generated_ids, current_tokens], dim=-1)
|
| 170 |
|