cdpearlman committed on
Commit
455f0c2
·
1 Parent(s): 89704eb

Fix ablation: capture ablated attention module output to populate layer data

Browse files
Files changed (1) hide show
  1. utils/model_patterns.py +11 -7
utils/model_patterns.py CHANGED
@@ -337,9 +337,11 @@ def execute_forward_pass_with_head_ablation(model, tokenizer, prompt: str, confi
337
  def make_hook(mod_name: str):
338
  return lambda module, inputs, output: captured.update({mod_name: {"output": safe_to_serializable(output)}})
339
 
340
- # Create head ablation hook
341
  def head_ablation_hook(module, input, output):
342
- """Zero out specific attention heads in the output."""
 
 
343
  if isinstance(output, tuple):
344
  # Attention modules typically return (hidden_states, attention_weights, ...)
345
  hidden_states = output[0] # [batch, seq_len, hidden_dim]
@@ -369,12 +371,14 @@ def execute_forward_pass_with_head_ablation(model, tokenizer, prompt: str, confi
369
 
370
  # Reconstruct output tuple
371
  if len(output) > 1:
372
- return (ablated_hidden,) + output[1:]
373
  else:
374
- return (ablated_hidden,)
375
- else:
376
- # If output is not a tuple, just return as is (shouldn't happen for attention)
377
- return output
 
 
378
 
379
  # Register hooks
380
  hooks = []
 
337
  def make_hook(mod_name: str):
338
  return lambda module, inputs, output: captured.update({mod_name: {"output": safe_to_serializable(output)}})
339
 
340
+ # Create head ablation hook that both ablates and captures
341
  def head_ablation_hook(module, input, output):
342
+ """Zero out specific attention heads in the output AND capture it."""
343
+ ablated_output = output # Default to original output
344
+
345
  if isinstance(output, tuple):
346
  # Attention modules typically return (hidden_states, attention_weights, ...)
347
  hidden_states = output[0] # [batch, seq_len, hidden_dim]
 
371
 
372
  # Reconstruct output tuple
373
  if len(output) > 1:
374
+ ablated_output = (ablated_hidden,) + output[1:]
375
  else:
376
+ ablated_output = (ablated_hidden,)
377
+
378
+ # Capture the ablated output (CRITICAL: this was missing!)
379
+ captured.update({target_attention_module: {"output": safe_to_serializable(ablated_output)}})
380
+
381
+ return ablated_output
382
 
383
  # Register hooks
384
  hooks = []