File size: 2,241 Bytes
6578134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import numpy as np
from transformers import GPT2Model, GPT2Tokenizer, GPT2Config
from torch.utils.data import Dataset, DataLoader


class GPT2ActivationHook:
    """Record the output of every block in ``model.h`` via forward hooks.

    A forward hook is attached to each element of ``model.h`` on
    construction. Every time a hooked block runs, its output hidden state is
    appended to ``self.activations`` (one tensor per block call, in call
    order). Call ``clear()`` before a forward pass to start from an empty
    buffer and ``remove()`` to detach the hooks.
    """

    def __init__(self, model: "GPT2Model"):
        """Attach capture hooks to every block of ``model``.

        Args:
            model: Object exposing an iterable ``h`` of ``nn.Module`` blocks.
        """
        self.model = model
        self.activations: list[torch.Tensor] = []
        self._hooks = []
        self._register_hooks()

    def _register_hooks(self):
        """Register one forward hook per block, replacing any existing ones."""
        # Detach previous handles first so that calling this twice does not
        # make every block append its output twice.
        self.remove()
        for block in self.model.h:
            self._hooks.append(block.register_forward_hook(self._capture))

    def _capture(self, module, input, output):
        """Forward-hook callback: store the block's output hidden state."""
        # Blocks that return a tuple carry the hidden state in slot 0. A block
        # returning a bare tensor must be stored whole; indexing it with [0]
        # would silently drop everything but the first batch element.
        hidden = output[0] if isinstance(output, (tuple, list)) else output
        self.activations.append(hidden)

    def get_activations(self) -> torch.Tensor:
        """Return all captured tensors concatenated along the last dimension.

        Raises:
            RuntimeError: If nothing has been captured since the last
                ``clear()`` (i.e. no hooked block has run).
        """
        if not self.activations:
            raise RuntimeError(
                "No activations captured; run a forward pass of the hooked "
                "model before calling get_activations()."
            )
        return torch.cat(self.activations, dim=-1)

    def clear(self):
        """Forget all captured activations."""
        # Rebind rather than clear in place so a list previously handed out
        # through ``self.activations`` is left untouched.
        self.activations = []

    def remove(self):
        """Detach every registered hook from the model."""
        for handle in self._hooks:
            handle.remove()
        self._hooks = []
    

class LinearProbe(nn.Module):
    """Single affine read-out producing one logit per feature vector.

    The expected feature width is ``num_layers * hidden_dim``; the layer maps
    that last dimension to size 1, which is then squeezed away.
    """

    def __init__(self, num_layers: int, hidden_dim: int):
        """Build the ``(num_layers * hidden_dim) -> 1`` linear layer."""
        super().__init__()
        self.linear = nn.Linear(num_layers * hidden_dim, 1)

    def forward(self, activations: torch.Tensor) -> torch.Tensor:
        """Return logits with the trailing singleton dimension removed.

        Args:
            activations: Tensor whose last dimension is
                ``num_layers * hidden_dim``.

        Returns:
            Tensor shaped like ``activations`` minus its last dimension.
        """
        return self.linear(activations).squeeze(-1)
    
class ConstitutionalProbe(nn.Module):
    """Frozen GPT-2 backbone with a trainable linear probe on its activations.

    The backbone's parameters have ``requires_grad`` disabled; only the
    ``LinearProbe`` is meant to be optimised (see ``trainable_parameters``).
    Forward hooks on every GPT-2 block collect the per-block outputs, which
    are concatenated along the feature dimension and fed to the probe.
    """

    def __init__(self, gpt2_model_name: str = "gpt2"):
        """Load the pretrained backbone, freeze it and build the probe.

        Args:
            gpt2_model_name: Name or path passed to
                ``GPT2Model.from_pretrained``.
        """
        super().__init__()
        self.gpt2 = GPT2Model.from_pretrained(gpt2_model_name)

        for param in self.gpt2.parameters():
            param.requires_grad = False
        # A frozen feature extractor must be deterministic: keep its dropout
        # layers disabled regardless of the wrapper's training flag.
        self.gpt2.eval()

        config: GPT2Config = self.gpt2.config
        self.num_layers = config.n_layer
        self.hidden_dim = config.n_embd

        self.probe = LinearProbe(self.num_layers, self.hidden_dim)
        self.hook_manager = GPT2ActivationHook(self.gpt2)

    def train(self, mode: bool = True) -> "ConstitutionalProbe":
        """Set the training flag on the probe while keeping GPT-2 in eval mode.

        ``nn.Module.train`` propagates ``mode`` to every child, which would
        switch the frozen backbone's dropout on and make the probe's input
        features stochastic. The backbone is therefore put back into eval
        mode after the default propagation.
        """
        super().train(mode)
        self.gpt2.eval()
        return self

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask=None,
    ) -> torch.Tensor:
        """Run the backbone and score the captured activations.

        Args:
            input_ids: Token ids forwarded to the GPT-2 backbone.
            attention_mask: Optional mask forwarded unchanged to the backbone.

        Returns:
            The probe's output for the concatenated block activations
            (presumably one logit per token position, i.e. ``(batch, seq)`` --
            confirm against the hook outputs).
        """
        # Start from an empty buffer so only this pass's activations are used.
        self.hook_manager.clear()

        # The backbone's own return value is not needed; the hooks capture
        # the per-block outputs as a side effect of this call.
        self.gpt2(input_ids=input_ids, attention_mask=attention_mask)

        activations = self.hook_manager.get_activations()
        return self.probe(activations)

    def trainable_parameters(self):
        """Return an iterator over the probe's parameters (backbone excluded)."""
        return self.probe.parameters()