from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from transformer_lens import HookedTransformer

def load_gpt2():
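    """Load the Hugging Face GPT-2 model and tokenizer."""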
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return model, tokenizer

def generate_text(prompt, max_length=50):
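    """Generate a continuation of `prompt` with GPT-2, capped at
    `max_length` tokens (the prompt tokens count toward the limit)."""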
    model, tokenizer = load_gpt2()
    inputs = tokenizer(prompt, return_tensors="pt")
    output = model.generate(**inputs, max_length=max_length)
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    return text

def run_activation_patching(prompt):
    """Cache the MLP post-activations for `prompt`; these cached values
    are the starting point for an activation patching experiment."""
    # Load the TransformerLens version of GPT-2
    model = HookedTransformer.from_pretrained("gpt2-small")

    # Tokenize the text
    tokens = model.to_tokens(prompt)

    # Dict to store activations, keyed by hook name
    activations = {}

    # Hook function: copy each activation to CPU as a numpy array
    def hook_fn(value, hook):
        activations[hook.name] = value.detach().cpu().numpy()

    # Hook every MLP post-activation hook point. TransformerLens's
    # add_hook does not return a removable handle, so run_with_hooks is
    # used instead; it detaches the hooks after the forward pass.
    fwd_hooks = [
        (f"blocks.{i}.mlp.hook_post", hook_fn)
        for i in range(model.cfg.n_layers)
    ]

    # Run the model forward with the hooks attached
    with torch.no_grad():
        _ = model.run_with_hooks(tokens, fwd_hooks=fwd_hooks)

    return activations
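

# Example usage (an illustrative sketch, not part of the original file;
# the prompt string is a placeholder).
if __name__ == "__main__":
    prompt = "The capital of France is"

    # Generate a short continuation with the Hugging Face model
    print(generate_text(prompt))

    # Cache MLP post-activations with the TransformerLens model and
    # print the shape captured at each hook point
    mlp_acts = run_activation_patching(prompt)
    for name, act in mlp_acts.items():
        print(name, act.shape)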