File size: 1,336 Bytes
6dc7202
 
 
2bb16ab
6dc7202
 
 
2bb16ab
6dc7202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
from transformers import PreTrainedModel, GPT2Config, GPT2Model
from .configuration_query2sae import Query2SAEConfig  # <- relative import

class Query2SAEModel(PreTrainedModel):
    """
    HF-compatible wrapper for Query2SAE:
    - GPT-2 backbone is frozen (no gradients, pinned to eval mode)
    - MLP head maps the last real token's hidden state -> SAE space
    """
    config_class = Query2SAEConfig
    base_model_prefix = "query2sae"

    def __init__(self, config: Query2SAEConfig):
        super().__init__(config)
        # Build GPT-2 backbone (weights will be loaded by from_pretrained via state_dict)
        gpt2_cfg = GPT2Config.from_pretrained(config.backbone_name)
        self.backbone = GPT2Model(gpt2_cfg)

        # Freeze the backbone: no gradients flow into GPT-2.
        for p in self.backbone.parameters():
            p.requires_grad = False
        # requires_grad=False alone does NOT disable dropout; see train() below.
        self.backbone.eval()

        self.head = nn.Sequential(
            nn.Linear(self.backbone.config.hidden_size, config.head_hidden_dim),
            nn.ReLU(),
            nn.Linear(config.head_hidden_dim, config.sae_dim),
        )

        self.post_init()

    def train(self, mode: bool = True):
        """Switch train/eval mode but keep the frozen backbone in eval mode.

        Without this override, ``model.train()`` would re-enable dropout in
        the "frozen" GPT-2 and make its features stochastic during training.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Encode the query and project it into SAE space.

        Args:
            input_ids: token ids, shape (batch, seq_len).
            attention_mask: optional 0/1 mask, shape (batch, seq_len). When
                given, the hidden state of the LAST REAL token is pooled
                (position -1 would read a PAD embedding for right-padded
                batches); otherwise position -1 is used.

        Returns:
            dict with "logits": (batch, sae_dim) head output (grads flow
            through the head only; the backbone runs under no_grad).
        """
        with torch.no_grad():
            out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
            hidden = out.last_hidden_state  # (batch, seq_len, hidden)
            if attention_mask is not None:
                # Index of the last position where the mask is 1. argmax on the
                # flipped 0/1 mask finds the first 1 from the right, which is
                # correct for both left- and right-padded batches.
                seq_len = attention_mask.size(1)
                last_idx = seq_len - 1 - attention_mask.flip(dims=[1]).long().argmax(dim=1)
                batch_idx = torch.arange(hidden.size(0), device=hidden.device)
                last_hidden = hidden[batch_idx, last_idx]
            else:
                # No mask: assume no padding, take the final position.
                last_hidden = hidden[:, -1, :]
        logits = self.head(last_hidden)
        return {"logits": logits}