AmnaHassan committed on
Commit
72d268b
·
verified ·
1 Parent(s): 2a011b9

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +11 -9
model_utils.py CHANGED
@@ -18,27 +18,29 @@ def run_activation_patching(prompt):
18
  # Load HookedTransformer
19
  model = HookedTransformer.from_pretrained("gpt2-small")
20
 
21
- # Tokenize
22
  tokens = model.to_tokens(prompt)
23
 
24
- # Store activations here
25
  activations = {}
26
 
27
  # Hook function
28
- def hook_fn_hook(value, hook):
29
  activations[hook.name] = value.detach().cpu().numpy()
30
 
31
- # Register hooks using new `.hooks` API
32
- hooks = []
33
  for i in range(model.cfg.n_layers):
34
  hook_name = f"blocks.{i}.mlp.hook_post"
35
- hooks.append(model.hooks.add_hook(hook_name, hook_fn_hook))
 
36
 
37
- # Run model forward
38
  _ = model(tokens)
39
 
40
- # Remove hooks
41
- for h in hooks:
42
  h.remove()
43
 
44
  return activations
 
 
18
  # Load HookedTransformer
19
  model = HookedTransformer.from_pretrained("gpt2-small")
20
 
21
+ # Tokenize the text
22
  tokens = model.to_tokens(prompt)
23
 
24
+ # Dict to store activations
25
  activations = {}
26
 
27
  # Hook function
28
+ def hook_fn(value, hook):
29
  activations[hook.name] = value.detach().cpu().numpy()
30
 
31
+ # Register hooks for all MLP post-activations
32
+ hook_handles = []
33
  for i in range(model.cfg.n_layers):
34
  hook_name = f"blocks.{i}.mlp.hook_post"
35
+ handle = model.add_hook(hook_name, hook_fn)
36
+ hook_handles.append(handle)
37
 
38
+ # Run the model forward
39
  _ = model(tokens)
40
 
41
+ # Remove all hooks
42
+ for h in hook_handles:
43
  h.remove()
44
 
45
  return activations
46
+