Spaces:
Sleeping
Sleeping
Commit ·
455f0c2
1
Parent(s): 89704eb
Fix ablation: capture ablated attention module output to populate layer data
Browse files- utils/model_patterns.py +11 -7
utils/model_patterns.py
CHANGED
|
@@ -337,9 +337,11 @@ def execute_forward_pass_with_head_ablation(model, tokenizer, prompt: str, confi
|
|
| 337 |
def make_hook(mod_name: str):
|
| 338 |
return lambda module, inputs, output: captured.update({mod_name: {"output": safe_to_serializable(output)}})
|
| 339 |
|
| 340 |
-
# Create head ablation hook
|
| 341 |
def head_ablation_hook(module, input, output):
|
| 342 |
-
"""Zero out specific attention heads in the output."""
|
|
|
|
|
|
|
| 343 |
if isinstance(output, tuple):
|
| 344 |
# Attention modules typically return (hidden_states, attention_weights, ...)
|
| 345 |
hidden_states = output[0] # [batch, seq_len, hidden_dim]
|
|
@@ -369,12 +371,14 @@ def execute_forward_pass_with_head_ablation(model, tokenizer, prompt: str, confi
|
|
| 369 |
|
| 370 |
# Reconstruct output tuple
|
| 371 |
if len(output) > 1:
|
| 372 |
-
|
| 373 |
else:
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
|
|
|
|
|
|
| 378 |
|
| 379 |
# Register hooks
|
| 380 |
hooks = []
|
|
|
|
| 337 |
def make_hook(mod_name: str):
|
| 338 |
return lambda module, inputs, output: captured.update({mod_name: {"output": safe_to_serializable(output)}})
|
| 339 |
|
| 340 |
+
# Create head ablation hook that both ablates and captures
|
| 341 |
def head_ablation_hook(module, input, output):
|
| 342 |
+
"""Zero out specific attention heads in the output AND capture it."""
|
| 343 |
+
ablated_output = output # Default to original output
|
| 344 |
+
|
| 345 |
if isinstance(output, tuple):
|
| 346 |
# Attention modules typically return (hidden_states, attention_weights, ...)
|
| 347 |
hidden_states = output[0] # [batch, seq_len, hidden_dim]
|
|
|
|
| 371 |
|
| 372 |
# Reconstruct output tuple
|
| 373 |
if len(output) > 1:
|
| 374 |
+
ablated_output = (ablated_hidden,) + output[1:]
|
| 375 |
else:
|
| 376 |
+
ablated_output = (ablated_hidden,)
|
| 377 |
+
|
| 378 |
+
# Capture the ablated output (CRITICAL: this was missing!)
|
| 379 |
+
captured.update({target_attention_module: {"output": safe_to_serializable(ablated_output)}})
|
| 380 |
+
|
| 381 |
+
return ablated_output
|
| 382 |
|
| 383 |
# Register hooks
|
| 384 |
hooks = []
|