"""CC_linear_probe/linear_probe.py

Linear probe over frozen GPT-2 hidden states: per-layer activations are
captured with forward hooks, concatenated along the feature dimension, and
fed to a single trainable linear head.
"""
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2Model, GPT2Tokenizer
class GPT2ActivationHook:
    """Captures the hidden-state output of every GPT-2 transformer block."""

    def __init__(self, model: GPT2Model):
        self.model = model
        self.activations: list[torch.Tensor] = []
        self._hooks = []
        self._register_hooks()

    def _register_hooks(self):
        # One forward hook per transformer block (model.h).
        for block in self.model.h:
            hook = block.register_forward_hook(self._capture)
            self._hooks.append(hook)

    def _capture(self, module, input, output):
        # Each GPT2Block returns a tuple; output[0] is the hidden states
        # of shape (batch, seq_len, n_embd).
        self.activations.append(output[0])

    def get_activations(self) -> torch.Tensor:
        # Concatenate per-layer hidden states along the feature dimension:
        # (batch, seq_len, n_layer * n_embd).
        return torch.cat(self.activations, dim=-1)

    def clear(self):
        self.activations = []

    def remove(self):
        for h in self._hooks:
            h.remove()
        self._hooks = []
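
# Usage sketch (assumes the "gpt2" checkpoint: 12 layers, n_embd=768):
#
#   model = GPT2Model.from_pretrained("gpt2")
#   hooks = GPT2ActivationHook(model)
#   _ = model(input_ids)              # input_ids: (batch, seq_len)
#   acts = hooks.get_activations()    # (batch, seq_len, 12 * 768)
#   hooks.clear()                     # reset before the next forward pass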
class LinearProbe(nn.Module):
    """Single linear layer mapping concatenated layer activations to one logit."""

    def __init__(self, num_layers: int, hidden_dim: int):
        super().__init__()
        input_dim = num_layers * hidden_dim
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, activations: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, num_layers * hidden_dim) -> (batch, seq_len)
        logits = self.linear(activations)
        return logits.squeeze(-1)
class ConstitutionalProbe(nn.Module):
    """Frozen GPT-2 backbone with a trainable linear probe on its activations."""

    def __init__(self, gpt2_model_name: str = "gpt2"):
        super().__init__()
        self.gpt2 = GPT2Model.from_pretrained(gpt2_model_name)
        # Freeze the backbone; only the probe receives gradient updates.
        for param in self.gpt2.parameters():
            param.requires_grad = False
        config: GPT2Config = self.gpt2.config
        self.num_layers = config.n_layer
        self.hidden_dim = config.n_embd
        self.probe = LinearProbe(self.num_layers, self.hidden_dim)
        self.hook_manager = GPT2ActivationHook(self.gpt2)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Drop activations left over from any previous forward pass.
        self.hook_manager.clear()
        _ = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        activations = self.hook_manager.get_activations()
        # Per-token probe logits of shape (batch, seq_len).
        return self.probe(activations)

    def trainable_parameters(self):
        # Only the probe's weights are trainable.
        return self.probe.parameters()
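
if __name__ == "__main__":
    # Minimal training sketch with placeholder data; the real dataset and
    # labeling scheme are not part of this file. This reads the probe's
    # logit at the last non-padding token of each sequence.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    texts = ["example text one", "example text two"]  # placeholder inputs
    labels = torch.tensor([0.0, 1.0])                 # placeholder binary labels

    batch = tokenizer(texts, return_tensors="pt", padding=True)
    model = ConstitutionalProbe("gpt2")
    optimizer = torch.optim.Adam(model.trainable_parameters(), lr=1e-3)
    loss_fn = nn.BCEWithLogitsLoss()

    logits = model(batch["input_ids"], batch["attention_mask"])  # (batch, seq_len)
    last_idx = batch["attention_mask"].sum(dim=1) - 1            # last real token
    last_logits = logits[torch.arange(logits.size(0)), last_idx]

    loss = loss_fn(last_logits, labels)
    loss.backward()
    optimizer.step()
    print(f"probe loss: {loss.item():.4f}")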