# model.py
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast


class ModelWrapper:
    def __init__(self, model_name="gpt2", device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(
            model_name, output_hidden_states=True
        ).to(self.device)
        self.model.eval()

    # -----------------------------------------------------
    # TEXT GENERATION
    # -----------------------------------------------------
    def generate_text(self, prompt, max_length=50, top_k=10):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_new_tokens=max_length,  # generate up to `max_length` new tokens
                do_sample=True,
                top_k=top_k,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)

    # -----------------------------------------------------
    # HIDDEN STATES
    # -----------------------------------------------------
    def _get_hidden_states(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs)
        # hidden_states is populated because the model was loaded with
        # output_hidden_states=True.
        return out.hidden_states

    # -----------------------------------------------------
    # ACTIVATION PATCHING (LAYER IMPORTANCE)
    # -----------------------------------------------------
    def layer_importance(self, prompt, experiment_type="story_continuation"):
        """
        Compute a simple proxy for activation patching. For each
        transformer block:
          - run GPT-2 normally,
          - run GPT-2 with that block's hidden-state output zeroed,
          - measure the L1 difference in next-token logits.

        Returns a list of importance scores normalized to [0, 1].
        `experiment_type` is currently unused and kept for API
        compatibility.
        """
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Baseline forward pass
        with torch.no_grad():
            out = self.model(**inputs)
        baseline_logits = out.logits[0, -1, :].cpu().numpy()

        # gpt2-base has 12 transformer blocks
        n_layers = len(self.model.transformer.h)
        scores = []

        for layer_idx in range(n_layers):

            def hook(module, inp, outp):
                """
                A GPT-2 block returns a tuple whose first element is the
                hidden states; the remaining elements (present key/values,
                attentions) depend on the forward flags, so the tuple
                length varies. Zero only the hidden states and pass the
                rest of the tuple through unchanged.
                """
                hidden = outp[0]
                return (torch.zeros_like(hidden),) + outp[1:]

            # Register the hook on this block only
            handle = self.model.transformer.h[layer_idx].register_forward_hook(hook)

            # Patched forward pass; remove the hook even if the pass fails
            try:
                with torch.no_grad():
                    out2 = self.model(**inputs)
            finally:
                handle.remove()
            logits2 = out2.logits[0, -1, :].cpu().numpy()

            # L1 difference from the baseline next-token logits
            diff = np.sum(np.abs(baseline_logits - logits2))
            scores.append(float(diff))

        # Min-max normalize to [0, 1]; guard against a zero range
        # (all-equal scores would otherwise divide by zero)
        arr = np.array(scores)
        span = arr.max() - arr.min()
        if span > 0:
            arr = (arr - arr.min()) / span
        else:
            arr = np.zeros_like(arr)
        return arr.tolist()
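
# -----------------------------------------------------
# USAGE EXAMPLE
# -----------------------------------------------------
# Illustrative sketch, not part of the module's original API surface: it
# assumes the standard "gpt2" checkpoint can be downloaded and shows how
# the wrapper's two public methods are meant to be called together.
if __name__ == "__main__":
    wrapper = ModelWrapper("gpt2")

    # Sample a short continuation of a prompt
    print(wrapper.generate_text("Once upon a time", max_length=30))

    # Rank blocks by the zero-ablation proxy and print normalized scores
    scores = wrapper.layer_importance("Once upon a time")
    for i, score in enumerate(scores):
        print(f"layer {i:2d}: {score:.3f}")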