# Simple_Calculator_Lecture / model_utils.py
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from transformer_lens import HookedTransformer

def load_gpt2():
    """Load the pretrained GPT-2 model and tokenizer from Hugging Face."""
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return model, tokenizer
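
# Example usage (a minimal sketch): load once and reuse the pair rather than
# reloading on every call, since from_pretrained is slow.
#     model, tokenizer = load_gpt2()
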
def generate_text(prompt, max_length=50):
    """Generate a continuation of `prompt` with GPT-2 (greedy decoding)."""
    model, tokenizer = load_gpt2()
    inputs = tokenizer(prompt, return_tensors="pt")
    # GPT-2 has no pad token; reuse EOS to avoid the generation warning.
    output = model.generate(**inputs, max_length=max_length,
                            pad_token_id=tokenizer.eos_token_id)
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    return text
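
# Example usage (a minimal sketch; the prompt is illustrative and the exact
# continuation depends on model weights and decoding settings):
#     print(generate_text("Once upon a time", max_length=30))
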
def run_activation_patching(prompt):
    """Cache every MLP post-activation for `prompt`.

    This performs the caching step of activation patching: it records the
    clean activations that a later run could patch in.
    """
    # Load the TransformerLens wrapper around GPT-2.
    model = HookedTransformer.from_pretrained("gpt2-small")
    # Tokenize the prompt (to_tokens prepends a BOS token by default).
    tokens = model.to_tokens(prompt)
    # Dict to store activations, keyed by hook name.
    activations = {}
    # Hook function: copy each activation to CPU as a NumPy array.
    def hook_fn(value, hook):
        activations[hook.name] = value.detach().cpu().numpy()
    # Attach the hook to every layer's MLP post-activation and run a single
    # forward pass. run_with_hooks registers the hooks, runs the model, and
    # removes the hooks again, so no manual cleanup is needed.
    fwd_hooks = [(f"blocks.{i}.mlp.hook_post", hook_fn)
                 for i in range(model.cfg.n_layers)]
    with torch.no_grad():
        _ = model.run_with_hooks(tokens, fwd_hooks=fwd_hooks)
    return activations
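
# Minimal smoke test (a sketch; the prompts are illustrative). Run the module
# directly to exercise both helpers:
#     python model_utils.py
if __name__ == "__main__":
    print(generate_text("2 + 2 =", max_length=20))
    acts = run_activation_patching("2 + 2 =")
    # Each cached array has shape [batch, seq_len, d_mlp]; for gpt2-small,
    # d_mlp is 3072.
    for name, arr in acts.items():
        print(name, arr.shape)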