botverse committed
Commit db7c8e4 · Parent: 8f20917

added hf url

Files changed (1):
  models/inference_memory_wrapper.py (+45 -17)
models/inference_memory_wrapper.py CHANGED
@@ -94,6 +94,7 @@ class InferenceMemoryWrapper(PreTrainedModel):
     def apply_surprise_update(self):
         """ Applies the TITANS-style surprise update rule using self.memory_buffer.grad """
         if self.memory_buffer.grad is None:
+            print("DEBUG: apply_surprise_update called but memory_buffer.grad is None.")
             return
 
         # Ensure surprise_state is on the same device and dtype
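For reference, the rule this method applies follows the TITANS paper ("Titans: Learning to Memorize at Test Time"): the gradient of the associative loss w.r.t. the memory is the "surprise", accumulated with momentum and written into the memory with a forgetting term. Below is a minimal standalone sketch of that rule, not this repo's implementation; the hyperparameter names `eta`, `theta`, `alpha` and the exact handling of `surprise_state` are assumptions.

```python
import torch

def titans_surprise_update(memory: torch.Tensor, surprise_state: torch.Tensor,
                           eta: float = 0.9, theta: float = 0.1, alpha: float = 0.01) -> None:
    """Sketch: S_t = eta * S_{t-1} - theta * grad;  M_t = (1 - alpha) * M_{t-1} + S_t."""
    if memory.grad is None:
        return
    with torch.no_grad():
        # Momentum over the surprise signal (the associative-loss gradient).
        surprise_state.mul_(eta).add_(memory.grad, alpha=-theta)
        # Write into memory; the (1 - alpha) factor acts as forgetting.
        memory.mul_(1.0 - alpha).add_(surprise_state)
    memory.grad.zero_()

memory = torch.randn(16, 64, requires_grad=True)  # stands in for memory_buffer
state = torch.zeros_like(memory)                   # stands in for surprise_state
memory.sum().backward()                            # fills memory.grad, like assoc_loss.backward()
titans_surprise_update(memory, state)
```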
@@ -182,12 +183,14 @@ class InferenceMemoryWrapper(PreTrainedModel):
             print("Warning: update_rule='surprise' requires use_memory=True.")
             update_rule = 'none'
 
+        # Ensure buffer requires grad only when needed
+        original_requires_grad = self.memory_buffer.requires_grad
         if update_rule == 'surprise':
             self.memory_buffer.requires_grad_(True)
+            print(f"DEBUG: Set memory_buffer.requires_grad = {self.memory_buffer.requires_grad}")
         else:
-            # Ensure no grads are computed if not needed
-            # Note: Llama part is already frozen and in eval mode
-            pass # No specific action needed if not surprise
+            self.memory_buffer.requires_grad_(False)
+
 
         bsz, seq_len_start = input_ids.shape
         device = input_ids.device
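The save/restore pair added here keeps gradient tracking scoped to a single generate() call. A minimal standalone demo of the pattern, with a plain tensor standing in for `memory_buffer`:

```python
import torch

mem = torch.zeros(4, 8)                       # stands in for memory_buffer
original_requires_grad = mem.requires_grad    # False for a plain buffer

mem.requires_grad_(True)                      # enable tracking for this call only
loss = (mem * 2.0).sum()
loss.backward()
assert mem.grad is not None                   # gradient accumulated on the leaf

mem.requires_grad_(original_requires_grad)    # restore, as the commit does at the end
```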
@@ -278,16 +281,19 @@ class InferenceMemoryWrapper(PreTrainedModel):
             # We need past_key_values AND not be doing surprise update AND base model supports caching
             use_kv_cache_this_step = past_key_values is not None and update_rule != 'surprise' and self.llama.config.use_cache
 
-            outputs = self.llama(
-                input_ids=cur_input_ids_for_llama, # None if using embeds
-                inputs_embeds=model_inputs_embeds,
-                attention_mask=current_mask, # Pass the correctly shaped mask for this step
-                position_ids=position_ids, # Pass adjusted position IDs
-                past_key_values=past_key_values,
-                use_cache=use_kv_cache_this_step,
-                output_hidden_states=True, # Needed for query/target/update
-                return_dict=True,
-            )
+            # Ensure context manager enables grads only when needed
+            context = torch.enable_grad() if update_rule == 'surprise' else torch.no_grad()
+            with context:
+                outputs = self.llama(
+                    input_ids=cur_input_ids_for_llama, # None if using embeds
+                    inputs_embeds=model_inputs_embeds,
+                    attention_mask=current_mask, # Pass the correctly shaped mask for this step
+                    position_ids=position_ids, # Pass adjusted position IDs
+                    past_key_values=past_key_values,
+                    use_cache=use_kv_cache_this_step,
+                    output_hidden_states=True, # Needed for query/target/update
+                    return_dict=True,
+                )
 
             # --- Associative Loss Calculation (if surprise update) ---
             if update_rule == 'surprise' and use_memory and retrieved_mem is not None:
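Selecting the autograd context up front is what lets the same forward serve both modes: `torch.enable_grad()` overrides an outer `torch.no_grad()` (e.g. a caller wrapping generation in inference code), so the associative loss can still backpropagate into the memory. A small demo of the pattern:

```python
import torch

def forward_pass(x: torch.Tensor, needs_grads: bool) -> torch.Tensor:
    # Choose the autograd context up front, mirroring the hunk above.
    ctx = torch.enable_grad() if needs_grads else torch.no_grad()
    with ctx:
        return (x * x).sum()

x = torch.ones(3, requires_grad=True)
print(forward_pass(x, needs_grads=True).requires_grad)   # True: graph was built
print(forward_pass(x, needs_grads=False).requires_grad)  # False: no graph, cheaper
```

Without `enable_grad`, a caller's `no_grad()` would silently make `assoc_loss.requires_grad` False, which is exactly the failure mode the debug prints below check for.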
@@ -299,12 +305,33 @@ class InferenceMemoryWrapper(PreTrainedModel):
                 # pred_repr comes from retrieve_memory, should already match buffer dtype
                 pred_repr = retrieved_mem.squeeze(1) # (B, C)
 
-                assoc_loss = F.mse_loss(pred_repr, target_repr.detach())
+                # --- DEBUG PRINTS ---
+                print(f"\n--- Surprise Update Debug (Step {step}) ---")
+                print(f"  memory_buffer requires_grad: {self.memory_buffer.requires_grad}")
+                print(f"  retrieved_mem requires_grad: {retrieved_mem.requires_grad if retrieved_mem is not None else 'N/A'}")
+                print(f"  pred_repr requires_grad: {pred_repr.requires_grad if pred_repr is not None else 'N/A'}")
+                print(f"  target_repr requires_grad: {target_repr.requires_grad}") # Should be False due to .detach() below
+                # --- END DEBUG ---
+
+                assoc_loss = F.mse_loss(pred_repr, target_repr.detach()) # TARGET IS DETACHED
+                print(f"  assoc_loss: {assoc_loss.item():.4f}, requires_grad: {assoc_loss.requires_grad}")
+
 
                 if self.memory_buffer.grad is not None:
+                    print("  Zeroing existing memory_buffer gradient.")
                     self.memory_buffer.grad.zero_()
-                assoc_loss.backward() # Compute grads for memory_buffer
-                self.apply_surprise_update() # Apply update and zero grad
+
+                if assoc_loss.requires_grad:
+                    print("  Calling assoc_loss.backward()")
+                    assoc_loss.backward() # Compute grads for memory_buffer
+                    print(f"  memory_buffer.grad is None after backward: {self.memory_buffer.grad is None}")
+                    if self.memory_buffer.grad is not None:
+                        print(f"  memory_buffer.grad norm: {torch.norm(self.memory_buffer.grad).item():.4f}")
+                    self.apply_surprise_update() # Apply update and zero grad
+                else:
+                    print("  ERROR: assoc_loss does not require grad! Skipping backward and update.")
+                print("--- End Surprise Update Debug ---")
+
 
             # --- Standard Generation Logic ---
             # Get logits for the very last position in the output sequence (corresponds to the token we just fed in)
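The new `assoc_loss.requires_grad` guard protects against the case the debug prints are hunting: if the buffer was not grad-enabled (or the forward ran under `no_grad`), `backward()` would raise. A minimal reproduction of both sides:

```python
import torch
import torch.nn.functional as F

mem = torch.randn(8, 32)              # requires_grad=False, like a frozen buffer
pred = mem.mean(dim=0, keepdim=True)  # stand-in for the retrieved representation
target = torch.randn(1, 32)

loss = F.mse_loss(pred, target.detach())
print(loss.requires_grad)             # False: nothing in the graph needs grad
# loss.backward() here raises RuntimeError("element 0 of tensors does not require grad ...")

mem.requires_grad_(True)              # what this commit ensures before the forward
loss = F.mse_loss(mem.mean(dim=0, keepdim=True), target.detach())
loss.backward()                       # now populates mem.grad
print(mem.grad.norm())
```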
@@ -362,7 +389,8 @@ class InferenceMemoryWrapper(PreTrainedModel):
             if eos_token_id is not None and (next_token == eos_token_id).all():
                 break
 
-        # self.eval() # Already in eval mode if llama is frozen
+        # Restore original requires_grad state
+        self.memory_buffer.requires_grad_(original_requires_grad)
 
         return generated_ids
 
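One caveat with restoring `requires_grad` only at the end of the method: if decoding raises partway through, the buffer stays grad-enabled. A `try`/`finally` would make the restore unconditional; a sketch of that pattern (the loop body is a placeholder, not the repo's decoding code):

```python
import torch

def generate_with_restore(memory_buffer: torch.Tensor, steps: int = 4) -> None:
    original = memory_buffer.requires_grad
    memory_buffer.requires_grad_(True)
    try:
        for _ in range(steps):
            pass  # decoding step + surprise update would run here
    finally:
        # Runs even if a decoding step raises, so the buffer never stays grad-enabled.
        memory_buffer.requires_grad_(original)

buf = torch.zeros(4, 8)
generate_with_restore(buf)
assert buf.requires_grad is False
```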
 