mksethi commited on
Commit
6dc7202
·
verified ·
1 Parent(s): f0533ae

Add config + custom code for Query2SAE

Browse files
Files changed (3) hide show
  1. config.json +9 -9
  2. configuration_query2sae.py +16 -0
  3. model_query2sae.py +37 -0
config.json CHANGED
@@ -1,11 +1,11 @@
1
  {
2
- "model_type": "query2sae",
3
- "backbone_name": "gpt2",
4
- "head_hidden_dim": 128,
5
- "sae_dim": 1024,
6
- "auto_map": {
7
- "AutoConfig": "my_package.my_configuration.Query2SAEConfig",
8
- "AutoModel": "my_package.my_modeling.Query2SAEModel"
9
- }
10
  }
11
-
 
1
  {
2
+ "model_type": "query2sae",
3
+ "backbone_name": "gpt2",
4
+ "head_hidden_dim": 128,
5
+ "sae_dim": 1024,
6
+ "architectures": ["Query2SAEModel"],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_query2sae.Query2SAEConfig",
9
+ "AutoModel": "modeling_query2sae.Query2SAEModel"
10
  }
11
+ }
configuration_query2sae.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class Query2SAEConfig(PretrainedConfig):
    """Configuration for the Query2SAE model.

    Holds the identifier of the (frozen) GPT-2 backbone and the two
    dimensions of the MLP head that projects hidden states into the
    SAE feature space.
    """

    model_type = "query2sae"

    def __init__(
        self,
        backbone_name: str = "gpt2",
        head_hidden_dim: int = 128,
        sae_dim: int = 1024,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Coerce the dimensions to int so values round-tripped through
        # JSON (or passed as strings) remain usable as layer sizes.
        # NOTE(review): sae_dim must match the real SAE dictionary size —
        # confirm 1024 against the trained SAE before relying on it.
        self.sae_dim = int(sae_dim)
        self.head_hidden_dim = int(head_hidden_dim)
        self.backbone_name = backbone_name
model_query2sae.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, GPT2Config, GPT2Model
4
+ from configuration_query2sae import Query2SAEConfig
5
+
6
class Query2SAEModel(PreTrainedModel):
    """
    HF-compatible wrapper for Query2SAE:
      - GPT-2 backbone is frozen (no grads flow into it)
      - MLP head maps the final-token hidden state -> SAE space

    forward() returns ``{"logits": Tensor}`` of shape (batch, sae_dim).
    """

    config_class = Query2SAEConfig
    base_model_prefix = "query2sae"

    def __init__(self, config: Query2SAEConfig):
        super().__init__(config)
        # Build the GPT-2 backbone skeleton; actual weights are loaded by
        # from_pretrained via the checkpoint's state_dict.
        # NOTE(review): from_pretrained here may hit the Hub at construction
        # time — confirm offline/cache behavior is acceptable.
        gpt2_cfg = GPT2Config.from_pretrained(config.backbone_name)
        self.backbone = GPT2Model(gpt2_cfg)

        # Freeze the backbone: only the head is trainable.
        # NOTE(review): freezing params does not disable backbone dropout
        # when the wrapper is in .train() mode — confirm that is intended.
        for p in self.backbone.parameters():
            p.requires_grad = False

        # Two-layer MLP head: hidden_size -> head_hidden_dim -> sae_dim.
        self.head = nn.Sequential(
            nn.Linear(self.backbone.config.hidden_size, config.head_hidden_dim),
            nn.ReLU(),
            nn.Linear(config.head_hidden_dim, config.sae_dim),
        )

        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Encode the input and project the last real token into SAE space.

        Args:
            input_ids: token ids, shape (batch, seq).
            attention_mask: 1 for real tokens, 0 for padding; optional.

        Returns:
            dict with "logits": (batch, sae_dim) head output.
        """
        # Backbone is frozen, so skip building its autograd graph entirely.
        with torch.no_grad():
            out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden = out.last_hidden_state  # (batch, seq, hidden_size)

        if attention_mask is not None:
            # Bug fix: with right-padded batches, position -1 is a padding
            # token, so pooling [:, -1, :] reads pad representations.
            # Gather the last *real* token per sequence instead. Behavior is
            # unchanged for unpadded input (mask of all ones) or mask=None.
            last_idx = attention_mask.long().sum(dim=1).clamp(min=1) - 1  # (batch,)
            batch_idx = torch.arange(hidden.size(0), device=hidden.device)
            last_hidden = hidden[batch_idx, last_idx]
        else:
            last_hidden = hidden[:, -1, :]

        logits = self.head(last_hidden)
        return {"logits": logits}