added hf url
models/inference_memory_wrapper.py CHANGED
@@ -94,6 +94,7 @@ class InferenceMemoryWrapper(PreTrainedModel):
     def apply_surprise_update(self):
         """ Applies the TITANS-style surprise update rule using self.memory_buffer.grad """
         if self.memory_buffer.grad is None:
+            print("DEBUG: apply_surprise_update called but memory_buffer.grad is None.")
             return

         # Ensure surprise_state is on the same device and dtype
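Note on the hunk above: apply_surprise_update consumes whatever gradient the associative loss (added later in this diff) leaves on self.memory_buffer; its body is not shown in this commit. The snippet below is only a generic sketch of a TITANS-style update in that spirit. titans_surprise_update and the hyperparameters surprise_lr, surprise_momentum and memory_decay are made-up names; only the memory buffer, the surprise state and the use of .grad come from the surrounding code.

import torch

def titans_surprise_update(memory_buffer: torch.nn.Parameter,
                           surprise_state: torch.Tensor,
                           surprise_lr: float = 0.1,
                           surprise_momentum: float = 0.9,
                           memory_decay: float = 0.01) -> None:
    """Momentum over the negative gradient of the associative loss, plus a forgetting term."""
    if memory_buffer.grad is None:
        return
    with torch.no_grad():
        # Momentary surprise is the negative gradient; smooth it with momentum.
        surprise_state.mul_(surprise_momentum).add_(memory_buffer.grad, alpha=-surprise_lr)
        # Forget a little of the old memory, then write in the accumulated surprise.
        memory_buffer.mul_(1.0 - memory_decay).add_(surprise_state)
        memory_buffer.grad.zero_()

The momentum term lets a single surprising token keep influencing the memory for a few steps, while the decay term keeps the buffer bounded.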
@@ -182,12 +183,14 @@ class InferenceMemoryWrapper(PreTrainedModel):
             print("Warning: update_rule='surprise' requires use_memory=True.")
             update_rule = 'none'

+        # Ensure buffer requires grad only when needed
+        original_requires_grad = self.memory_buffer.requires_grad
         if update_rule == 'surprise':
             self.memory_buffer.requires_grad_(True)
+            print(f"DEBUG: Set memory_buffer.requires_grad = {self.memory_buffer.requires_grad}")
         else:
-
-
-            pass # No specific action needed if not surprise
+            self.memory_buffer.requires_grad_(False)
+

         bsz, seq_len_start = input_ids.shape
         device = input_ids.device
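Note: the lines added here save the buffer's requires_grad flag, force it on for update_rule == 'surprise' and off otherwise; the flag is restored in the last hunk of this diff. An equivalent way to express the same save/toggle/restore pattern is a context manager; the helper below is a sketch, not code from this repository.

import contextlib
import torch

@contextlib.contextmanager
def buffer_grad(buffer: torch.nn.Parameter, enabled: bool):
    """Temporarily toggle requires_grad on the memory buffer, restoring it on exit."""
    original = buffer.requires_grad
    buffer.requires_grad_(enabled)
    try:
        yield buffer
    finally:
        buffer.requires_grad_(original)

# Hypothetical usage around the decoding loop:
# with buffer_grad(self.memory_buffer, update_rule == 'surprise'):
#     ...decode tokens...

The try/finally variant also restores the flag if generation raises part-way through, which the explicit restore before the final return does not.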
@@ -278,16 +281,19 @@ class InferenceMemoryWrapper(PreTrainedModel):
             # We need past_key_values AND not be doing surprise update AND base model supports caching
             use_kv_cache_this_step = past_key_values is not None and update_rule != 'surprise' and self.llama.config.use_cache

-
-
-
-
-
-
-
-
-
-
+            # Ensure context manager enables grads only when needed
+            context = torch.enable_grad() if update_rule == 'surprise' else torch.no_grad()
+            with context:
+                outputs = self.llama(
+                    input_ids=cur_input_ids_for_llama, # None if using embeds
+                    inputs_embeds=model_inputs_embeds,
+                    attention_mask=current_mask, # Pass the correctly shaped mask for this step
+                    position_ids=position_ids, # Pass adjusted position IDs
+                    past_key_values=past_key_values,
+                    use_cache=use_kv_cache_this_step,
+                    output_hidden_states=True, # Needed for query/target/update
+                    return_dict=True,
+                )

             # --- Associative Loss Calculation (if surprise update) ---
             if update_rule == 'surprise' and use_memory and retrieved_mem is not None:
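Note: the forward pass is now wrapped in torch.enable_grad() when the surprise rule is active, so autograd can reach self.memory_buffer through the retrieved memory, and in torch.no_grad() otherwise, so ordinary decoding builds no graph. The same selection can be exercised in isolation; the snippet below is a standalone illustration with a throwaway linear layer, not repository code.

import torch

def forward_with_optional_grad(module: torch.nn.Module, x: torch.Tensor, needs_grad: bool) -> torch.Tensor:
    """Run a forward pass, building an autograd graph only when needs_grad is True."""
    context = torch.enable_grad() if needs_grad else torch.no_grad()
    with context:
        return module(x)

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)
print(forward_with_optional_grad(layer, x, needs_grad=True).requires_grad)   # True
print(forward_with_optional_grad(layer, x, needs_grad=False).requires_grad)  # False

This pairs with use_kv_cache_this_step above, which is already defined to be False whenever update_rule == 'surprise'.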
@@ -299,12 +305,33 @@ class InferenceMemoryWrapper(PreTrainedModel):
                 # pred_repr comes from retrieve_memory, should already match buffer dtype
                 pred_repr = retrieved_mem.squeeze(1) # (B, C)

-
+                # --- DEBUG PRINTS ---
+                print(f"\n--- Surprise Update Debug (Step {step}) ---")
+                print(f" memory_buffer requires_grad: {self.memory_buffer.requires_grad}")
+                print(f" retrieved_mem requires_grad: {retrieved_mem.requires_grad if retrieved_mem is not None else 'N/A'}")
+                print(f" pred_repr requires_grad: {pred_repr.requires_grad if pred_repr is not None else 'N/A'}")
+                print(f" target_repr requires_grad: {target_repr.requires_grad}") # Should be False due to .detach() below
+                # --- END DEBUG ---
+
+                assoc_loss = F.mse_loss(pred_repr, target_repr.detach()) # TARGET IS DETACHED
+                print(f" assoc_loss: {assoc_loss.item():.4f}, requires_grad: {assoc_loss.requires_grad}")
+

                 if self.memory_buffer.grad is not None:
+                    print(" Zeroing existing memory_buffer gradient.")
                     self.memory_buffer.grad.zero_()
-
-
+
+                if assoc_loss.requires_grad:
+                    print(" Calling assoc_loss.backward()")
+                    assoc_loss.backward() # Compute grads for memory_buffer
+                    print(f" memory_buffer.grad is None after backward: {self.memory_buffer.grad is None}")
+                    if self.memory_buffer.grad is not None:
+                        print(f" memory_buffer.grad norm: {torch.norm(self.memory_buffer.grad).item():.4f}")
+                    self.apply_surprise_update() # Apply update and zero grad
+                else:
+                    print(" ERROR: assoc_loss does not require grad! Skipping backward and update.")
+                print("--- End Surprise Update Debug ---")
+

             # --- Standard Generation Logic ---
             # Get logits for the very last position in the output sequence (corresponds to the token we just fed in)
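Note: the heart of the surprise path added above is an MSE between the memory-derived prediction and a detached target, a backward pass whose only trainable leaf is the memory buffer, and then apply_surprise_update consuming the resulting gradient. A self-contained toy version of that sequence, with dummy shapes and a plain gradient step standing in for the real update rule:

import torch
import torch.nn.functional as F

# Toy stand-ins: a 4-slot memory of width 8, a fake "retrieval" that mixes the slots,
# and a random target playing the role of the detached hidden state.
memory_buffer = torch.nn.Parameter(torch.randn(4, 8))
read_weights = torch.softmax(torch.randn(1, 4), dim=-1)   # (B, slots), carries no grad of its own
target_repr = torch.randn(1, 8)

memory_buffer.requires_grad_(True)
pred_repr = read_weights @ memory_buffer                    # (B, 8), depends on the buffer
assoc_loss = F.mse_loss(pred_repr, target_repr.detach())   # detached target: gradient flows only into the memory

if memory_buffer.grad is not None:
    memory_buffer.grad.zero_()

if assoc_loss.requires_grad:                                # same guard as the diff's ERROR branch
    assoc_loss.backward()                                    # leaves a gradient on memory_buffer only
    with torch.no_grad():                                    # plain SGD step standing in for apply_surprise_update
        memory_buffer -= 0.1 * memory_buffer.grad
        memory_buffer.grad.zero_()

Detaching the target matters: without it, backward() would also propagate through the base model's graph and accumulate gradients on its weights, not just on the memory buffer.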
@@ -362,7 +389,8 @@ class InferenceMemoryWrapper(PreTrainedModel):
             if eos_token_id is not None and (next_token == eos_token_id).all():
                 break

-        #
+        # Restore original requires_grad state
+        self.memory_buffer.requires_grad_(original_requires_grad)

         return generated_ids

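General note: this commit adds a number of unconditional print statements along the surprise path. A likely follow-up (not part of this commit) is to gate them behind the logging module so normal generation stays quiet; a minimal sketch:

import logging

logger = logging.getLogger("inference_memory_wrapper")

def log_surprise_debug(step: int, assoc_loss, memory_buffer) -> None:
    """Emit the per-step surprise diagnostics only when DEBUG logging is enabled."""
    if not logger.isEnabledFor(logging.DEBUG):
        return
    grad = memory_buffer.grad
    logger.debug("Surprise step %d: loss=%.4f, grad_norm=%s",
                 step, assoc_loss.item(),
                 "None" if grad is None else f"{grad.norm().item():.4f}")

# logging.basicConfig(level=logging.DEBUG)  # switch the diagnostics on when investigating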