Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from transformers import GPT2Model, GPT2Tokenizer, GPT2Config | |
| from torch.utils.data import Dataset, DataLoader | |
class GPT2ActivationHook:
    """Collects the hidden states emitted by every transformer block of a GPT-2 model.

    A forward hook is attached to each block in ``model.h``; after a forward
    pass, the captured per-block hidden states can be read back as a single
    tensor concatenated along the feature (last) dimension.
    """

    def __init__(self, model: "GPT2Model"):
        self.model = model
        # One tensor per transformer block, appended in forward order.
        self.activations: list[torch.Tensor] = []
        self._hooks = []
        self._register_hooks()

    def _register_hooks(self):
        # Attach the capture callback to every transformer block.
        for block in self.model.h:
            self._hooks.append(block.register_forward_hook(self._capture))

    def _capture(self, module, input, output):
        # GPT-2 blocks return a tuple; element 0 is the hidden-state tensor.
        self.activations.append(output[0])

    def get_activations(self) -> torch.Tensor:
        """Return all captured activations concatenated along the last dim."""
        return torch.cat(self.activations, dim=-1)

    def clear(self):
        """Discard previously captured activations (call before each forward)."""
        self.activations = []

    def remove(self):
        """Detach every registered hook from the model."""
        while self._hooks:
            self._hooks.pop().remove()
class LinearProbe(nn.Module):
    """Single affine layer mapping concatenated layer activations to a scalar logit."""

    def __init__(self, num_layers: int, hidden_dim: int):
        super().__init__()
        # The probe consumes all layers' hidden states stacked feature-wise,
        # so its input width is num_layers * hidden_dim.
        self.linear = nn.Linear(num_layers * hidden_dim, 1)

    def forward(self, activations: torch.Tensor) -> torch.Tensor:
        """Project activations to one logit each and drop the trailing unit dim."""
        return self.linear(activations).squeeze(-1)
class ConstitutionalProbe(nn.Module):
    """Frozen GPT-2 backbone with a trainable linear probe over its layer activations.

    The GPT-2 weights are frozen; only the probe head is trained
    (see ``trainable_parameters``). Hidden states from every transformer
    block are captured via forward hooks, concatenated along the feature
    dimension, and fed to a linear probe producing one logit per token
    position.

    NOTE(review): the backbone is left in whatever train/eval mode the
    parent module sets, so GPT-2 dropout is active while training the
    probe unless callers invoke ``.eval()`` on it — confirm intended.
    """

    def __init__(self, gpt2_model_name: str = "gpt2"):
        super().__init__()
        self.gpt2 = GPT2Model.from_pretrained(gpt2_model_name)
        # The backbone is a fixed feature extractor — freeze every weight.
        for param in self.gpt2.parameters():
            param.requires_grad = False
        config: GPT2Config = self.gpt2.config
        self.num_layers = config.n_layer
        self.hidden_dim = config.n_embd
        self.probe = LinearProbe(self.num_layers, self.hidden_dim)
        self.hook_manager = GPT2ActivationHook(self.gpt2)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask=None,
    ) -> torch.Tensor:
        """Run GPT-2, gather per-block hidden states, and return probe logits.

        Args:
            input_ids: token id tensor passed straight to GPT-2.
            attention_mask: optional mask forwarded to GPT-2 unchanged.

        Returns:
            Probe logits with the trailing unit dimension squeezed away
            (one logit per token position).
        """
        # Drop activations from any previous forward pass.
        self.hook_manager.clear()
        # Fix: the backbone is fully frozen, so no gradient graph through
        # GPT-2 is ever needed — no_grad avoids retaining its intermediate
        # tensors for backprop, cutting memory use during probe training.
        with torch.no_grad():
            _ = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        activations = self.hook_manager.get_activations()
        return self.probe(activations)

    def trainable_parameters(self):
        """Parameters of the probe head only (the GPT-2 backbone is frozen)."""
        return self.probe.parameters()