import torch
import torch.nn as nn
import numpy as np
from transformers import GPT2Model, GPT2Tokenizer, GPT2Config
from torch.utils.data import Dataset, DataLoader


class GPT2ActivationHook:
    """Registers forward hooks on every GPT-2 transformer block and collects
    their hidden-state outputs for probing."""

    def __init__(self, model: GPT2Model):
        self.model = model
        self.activations: list[torch.Tensor] = []
        self._hooks = []
        self._register_hooks()

    def _register_hooks(self):
        # model.h is the ModuleList of GPT-2 transformer blocks.
        for block in self.model.h:
            hook = block.register_forward_hook(self._capture)
            self._hooks.append(hook)

    def _capture(self, module, input, output):
        # Each block returns a tuple; output[0] is the hidden state of
        # shape (batch, seq_len, hidden_dim).
        self.activations.append(output[0])

    def get_activations(self) -> torch.Tensor:
        # Concatenate per-layer activations along the feature dimension:
        # (batch, seq_len, num_layers * hidden_dim).
        return torch.cat(self.activations, dim=-1)

    def clear(self):
        self.activations = []

    def remove(self):
        for h in self._hooks:
            h.remove()
        self._hooks = []


class LinearProbe(nn.Module):
    """Single linear layer mapping concatenated layer activations to one logit."""

    def __init__(self, num_layers: int, hidden_dim: int):
        super().__init__()
        input_dim = num_layers * hidden_dim
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, activations: torch.Tensor) -> torch.Tensor:
        logits = self.linear(activations)
        return logits.squeeze(-1)  # (batch, seq_len)


class ConstitutionalProbe(nn.Module):
    """Frozen GPT-2 backbone with a trainable linear probe over the
    activations of all transformer blocks."""

    def __init__(self, gpt2_model_name: str = "gpt2"):
        super().__init__()
        self.gpt2 = GPT2Model.from_pretrained(gpt2_model_name)
        # Freeze the backbone; only the probe is trained.
        for param in self.gpt2.parameters():
            param.requires_grad = False

        config: GPT2Config = self.gpt2.config
        self.num_layers = config.n_layer
        self.hidden_dim = config.n_embd

        self.probe = LinearProbe(self.num_layers, self.hidden_dim)
        self.hook_manager = GPT2ActivationHook(self.gpt2)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask=None,
    ) -> torch.Tensor:
        # Discard activations captured during any previous forward pass.
        self.hook_manager.clear()
        _ = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        activations = self.hook_manager.get_activations()
        # Per-token probe logits of shape (batch, seq_len).
        return self.probe(activations)

    def trainable_parameters(self):
        return self.probe.parameters()
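

# Usage sketch (illustrative, not part of the module above): trains the probe on
# a toy batch. The example texts, labels, learning rate, and the choice of reading
# the last non-padded token's logit as the sequence-level score are assumptions
# made for demonstration only.
if __name__ == "__main__":
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    model = ConstitutionalProbe("gpt2")
    optimizer = torch.optim.Adam(model.trainable_parameters(), lr=1e-3)
    loss_fn = nn.BCEWithLogitsLoss()

    # Hypothetical texts and binary labels, for illustration only.
    texts = [
        "The assistant refuses the harmful request.",
        "The assistant complies with the harmful request.",
    ]
    labels = torch.tensor([1.0, 0.0])

    batch = tokenizer(texts, return_tensors="pt", padding=True)
    logits = model(batch["input_ids"], attention_mask=batch["attention_mask"])

    # forward() returns one logit per token; take the last non-padded token's
    # logit as the sequence-level score (an assumption, not dictated by the module).
    last_idx = batch["attention_mask"].sum(dim=1) - 1
    seq_logits = logits[torch.arange(logits.size(0)), last_idx]

    optimizer.zero_grad()
    loss = loss_fn(seq_logits, labels)
    loss.backward()
    optimizer.step()
    print(f"probe loss: {loss.item():.4f}")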