mksethi committed on
Commit
f0533ae
·
verified ·
1 Parent(s): a886b64

Add config + custom code for Query2SAE

Browse files
config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "query2sae",
3
+ "backbone_name": "gpt2",
4
+ "head_hidden_dim": 128,
5
+ "sae_dim": 1024,
6
+ "auto_map": {
7
+ "AutoConfig": "my_package.my_configuration.Query2SAEConfig",
8
+ "AutoModel": "my_package.my_modeling.Query2SAEModel"
9
+ }
10
+ }
11
+
my_package/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # keeps the package importable on the Hub
2
+ from .my_configuration import Query2SAEConfig
3
+ from .my_modeling import Query2SAEModel
my_package/my_configuration.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class Query2SAEConfig(PretrainedConfig):
    """HF configuration for Query2SAE.

    Holds the backbone checkpoint name plus the dimensions of the MLP head
    that maps backbone hidden states to SAE feature activations. `sae_dim`
    should be set to the actual SAE feature count of your trained SAE.
    """

    model_type = "query2sae"

    def __init__(
        self,
        backbone_name: str = "gpt2",
        head_hidden_dim: int = 128,
        sae_dim: int = 1024,
        **kwargs,
    ):
        # Let PretrainedConfig consume any generic HF kwargs first.
        super().__init__(**kwargs)
        self.backbone_name = backbone_name
        # Coerce to int so values round-tripped through config.json stay well-typed.
        self.sae_dim = int(sae_dim)
        self.head_hidden_dim = int(head_hidden_dim)
my_package/my_modeling.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, GPT2Config, GPT2Model
4
+ from .my_configuration import Query2SAEConfig
5
+
6
class Query2SAEModel(PreTrainedModel):
    """
    Hugging Face-compatible wrapper around your Query2SAE.
    - Freezes the GPT-2 backbone
    - Adds a small MLP head to predict SAE features
    - Saves/loads with save_pretrained()/from_pretrained()
    """
    config_class = Query2SAEConfig
    base_model_prefix = "query2sae"

    def __init__(self, config: Query2SAEConfig):
        super().__init__(config)

        # Build the GPT-2 backbone from its config only; weights are filled in
        # later by from_pretrained(). NOTE(review): from_pretrained on the
        # *config* may still hit the Hub/cache to fetch config.json — confirm
        # this is acceptable at construction time.
        gpt2_cfg = GPT2Config.from_pretrained(config.backbone_name)
        self.backbone = GPT2Model(gpt2_cfg)

        # Freeze backbone parameters — only the head is trainable.
        for p in self.backbone.parameters():
            p.requires_grad = False

        # Head maps hidden_size -> head_hidden_dim -> sae_dim
        self.head = nn.Sequential(
            nn.Linear(self.backbone.config.hidden_size, config.head_hidden_dim),
            nn.ReLU(),
            nn.Linear(config.head_hidden_dim, config.sae_dim),
        )

        # Initialize head weights the HF way
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Predict SAE feature logits from the last real token of each sequence.

        Returns a dict with:
          - "logits": [B, sae_dim] head output
          - "last_hidden_state": [B, T, hidden_size] backbone states
        """
        # no grad through backbone (keeps it frozen and faster)
        with torch.no_grad():
            out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)

        hidden = out.last_hidden_state  # [B, T, hidden_size]
        if attention_mask is not None:
            # BUGFIX: with right-padded batches, position -1 is a pad token.
            # Select the last *real* (unmasked) token per sequence instead.
            # clamp(min=0) guards against a degenerate all-zero mask row.
            last_idx = (attention_mask.long().sum(dim=1) - 1).clamp(min=0)  # [B]
            batch_idx = torch.arange(hidden.size(0), device=hidden.device)
            last_hidden = hidden[batch_idx, last_idx]  # [B, hidden_size]
        else:
            # No mask: keep the original behavior (final position).
            last_hidden = hidden[:, -1, :]  # [B, hidden_size]

        logits = self.head(last_hidden)  # [B, sae_dim]
        return {"logits": logits, "last_hidden_state": hidden}

    # Optional helpers for HF-style naming consistency
    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, value):
        return self.backbone.set_input_embeddings(value)