KokosDev committed
Commit 2ab7f0b · 1 Parent(s): bc702e3

Add YAML metadata to README

Files changed (1)
  1. README.md +35 -18
README.md CHANGED
@@ -1,3 +1,21 @@
---
language: en
license: apache-2.0
tags:
- interpretability
- mechanistic-interpretability
- vision-language
- llava
- sparse-autoencoders
- circuit-tracer
- cross-layer-transcoders
base_model: llava-hf/llava-1.5-7b-hf
datasets:
- liuhaotian/llava-instruct-150k
- nlphuji/flickr30k
pipeline_tag: image-to-text
---

# LLaVA-1.5-7B Cross-Layer Transcoders (CLTs)

## Overview
@@ -15,7 +33,7 @@ This repository contains **Cross-Layer Transcoders (CLTs)** trained on [llava-hf

## Architecture

```
Input (MLP hidden state): [batch, seq_len, 4096]

Transcoder Encoder: LayerNorm + Linear(4096 → 8192) + ReLU
@@ -25,7 +43,7 @@ Input (MLP hidden state): [batch, seq_len, 4096]
Transcoder Decoder: Linear(8192 → 4096)

Output (MLP reconstruction): [batch, seq_len, 4096]
```
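
In PyTorch terms this corresponds to roughly the following module (a minimal sketch inferred from the shapes above; the class and parameter names are illustrative and may not match the released checkpoints exactly):

```python
import torch
import torch.nn as nn

class CrossLayerTranscoder(nn.Module):
    """Minimal sketch of one per-layer transcoder (illustrative, not the training code)."""

    def __init__(self, hidden_dim: int = 4096, feature_dim: int = 8192):
        super().__init__()
        # Encoder: LayerNorm + Linear(4096 → 8192) + ReLU
        self.norm = nn.LayerNorm(hidden_dim)
        self.encoder = nn.Linear(hidden_dim, feature_dim)
        # Decoder: Linear(8192 → 4096)
        self.decoder = nn.Linear(feature_dim, hidden_dim)

    def forward(self, mlp_hidden: torch.Tensor):
        # mlp_hidden: [batch, seq_len, 4096]
        features = torch.relu(self.encoder(self.norm(mlp_hidden)))  # [batch, seq_len, 8192], sparse
        reconstruction = self.decoder(features)                     # [batch, seq_len, 4096]
        return features, reconstruction
```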

**Parameters per layer:**
- Hidden dim: 4096
@@ -37,7 +55,7 @@ Output (MLP reconstruction): [batch, seq_len, 4096]

## Training Details

- **Model**: `llava-hf/llava-1.5-7b-hf`
- **Dataset**: ~45K multimodal samples (Flickr30K + instruction tasks)
- **Steps per layer**: 5,000
- **Learning rate**: 3e-4 (AdamW)
@@ -61,24 +79,24 @@ Output (MLP reconstruction): [batch, seq_len, 4096]

Each layer has two files:

### 1. `transcoder_L{layer}.pt`
Contains the trained transcoder model and training metadata.

```python
checkpoint = torch.load('transcoder_L5.pt')
# Keys: 'layer', 'hidden_dim', 'feature_dim', 'state_dict', 'training_metadata', 'mlp_to_clt_mapping'
```
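
The `hidden_dim` and `feature_dim` entries let you size a module before loading `state_dict`. A minimal sketch using the `CrossLayerTranscoder` class from the Architecture section (the saved parameter names must line up with your module's attributes, so adjust if they differ):

```python
checkpoint = torch.load('transcoder_L5.pt', map_location='cpu')

# Size the module from the checkpoint metadata rather than hard-coding shapes.
transcoder = CrossLayerTranscoder(
    hidden_dim=checkpoint['hidden_dim'],    # 4096
    feature_dim=checkpoint['feature_dim'],  # 8192
)

# strict=False reports any name mismatches instead of raising;
# inspect checkpoint['state_dict'].keys() and rename if needed.
missing, unexpected = transcoder.load_state_dict(checkpoint['state_dict'], strict=False)
print('missing:', missing, 'unexpected:', unexpected)
transcoder.eval()
```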

### 2. `mapping_L{layer}.pt`
Contains MLP→CLT mapping and decoder weights for analysis.

```python
mapping = torch.load('mapping_L5.pt')
# Keys: 'layer', 'mlp_to_clt_mapping', 'decoder_weights', 'hidden_dim', 'feature_dim', 'description'

# mlp_to_clt_mapping: [4096, 8192] - which MLP neurons correlate with each CLT feature
# decoder_weights: [4096, 8192] - CLT → MLP reconstruction weights
```
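
The two matrices play different roles: `mlp_to_clt_mapping` is a correlation-style map, while `decoder_weights` holds the MLP-space direction each CLT feature writes during reconstruction. A short, hedged example comparing the two for one feature (the feature index is arbitrary):

```python
import torch
import torch.nn.functional as F

mapping = torch.load('mapping_L5.pt', map_location='cpu')
mlp_to_clt = mapping['mlp_to_clt_mapping']      # [4096, 8192]
decoder_weights = mapping['decoder_weights']    # [4096, 8192]

feature_idx = 1234  # arbitrary CLT feature, for illustration only

# Direction in MLP space that this feature contributes to the reconstruction.
decoder_direction = decoder_weights[:, feature_idx]   # [4096]
print('decoder direction norm:', decoder_direction.norm().item())

# Compare it with the feature's MLP-correlation profile.
corr_profile = mlp_to_clt[:, feature_idx]              # [4096]
print('cosine(decoder, correlation):',
      F.cosine_similarity(decoder_direction, corr_profile, dim=0).item())
```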

---

@@ -86,7 +104,7 @@ mapping = torch.load('mapping_L5.pt')

### 1. Load a Transcoder

```python
import torch
import torch.nn as nn

@@ -121,13 +139,13 @@ with torch.no_grad():

# features: [batch, seq_len, 8192] - sparse interpretable features
# reconstruction: [batch, seq_len, 4096] - reconstructed MLP output
```
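
To get a real `[batch, seq_len, 4096]` activation to feed the transcoder, one option is to capture a layer's MLP activation from LLaVA with a forward hook during an ordinary forward pass. The sketch below is hedged: it reuses the `model.model.layers[...]` module path from the hook example later in this README (which may differ across `transformers` versions), uses a placeholder image path, and assumes the transcoder reads the MLP's input and reconstructs its output (check `training_metadata` in the checkpoint to confirm which activation was used):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = 'llava-hf/llava-1.5-7b-hf'
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map='auto'
)
processor = AutoProcessor.from_pretrained(model_id)

layer_idx = 10
captured = {}

def capture_mlp(module, inputs, output):
    captured['mlp_in'] = inputs[0].detach()   # activation flowing into the MLP  [batch, seq_len, 4096]
    captured['mlp_out'] = output.detach()     # activation the MLP produces      [batch, seq_len, 4096]

# Module path mirrors the replacement-hook example below; adjust if your
# transformers version nests the language model differently.
handle = model.model.layers[layer_idx].mlp.register_forward_hook(capture_mlp)

image = Image.open('example.jpg')  # placeholder image path
prompt = 'USER: <image>\nDescribe the image. ASSISTANT:'
inputs = processor(text=prompt, images=image, return_tensors='pt').to(model.device)
with torch.no_grad():
    model(**inputs)
handle.remove()

# Run the transcoder loaded in the snippet above on the captured activation.
with torch.no_grad():
    features, reconstruction = transcoder(captured['mlp_in'].float().cpu())
```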

### 2. Use MLP→CLT Mapping

The mapping shows which MLP neurons are correlated with each CLT feature:

```python
mapping_data = torch.load('mapping_L10.pt', map_location='cpu')
mlp_to_clt = mapping_data['mlp_to_clt_mapping']  # [4096, 8192]

@@ -140,13 +158,13 @@ print(f"Top MLP neurons for feature {feature_idx}: {top_mlp_neurons.indices}")
mlp_neuron_idx = 567
top_clt_features = mlp_to_clt[mlp_neuron_idx, :].topk(k=10)
print(f"Top CLT features for MLP neuron {mlp_neuron_idx}: {top_clt_features.indices}")
```

### 3. Replacement Model (Full Integration)

For direct integration into LLaVA (replace MLPs with CLTs):

```python
from transformers import LlavaForConditionalGeneration

# Load LLaVA
@@ -169,7 +187,7 @@ def replace_mlp_with_clt(module, input, output):
    return reconstruction

model.model.layers[layer_idx].mlp.register_forward_hook(replace_mlp_with_clt)
```
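
Before generating with the MLP replaced, it is worth checking how closely the transcoder reproduces the original MLP output on real inputs. A minimal sketch, reusing the `captured` activations and `transcoder` from the capture-hook example in section 1 (the variance-explained figure is a rough estimate pooled over all positions and dimensions):

```python
import torch
import torch.nn.functional as F

original = captured['mlp_out'].float().cpu()      # [batch, seq_len, 4096]
with torch.no_grad():
    _, reconstruction = transcoder(captured['mlp_in'].float().cpu())

mse = F.mse_loss(reconstruction, original).item()
mean_cos = F.cosine_similarity(reconstruction, original, dim=-1).mean().item()
var_explained = 1.0 - (reconstruction - original).var().item() / original.var().item()

print(f'MSE: {mse:.6f}  mean cosine: {mean_cos:.4f}  variance explained: {var_explained:.3f}')
```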

---

@@ -201,7 +219,7 @@ This work extends Anthropic's Circuit-Tracer methodology to multimodal vision-la

If you use these transcoders in your research, please cite:

```bibtex
@misc{llava15_clts_2025,
  title={Cross-Layer Transcoders for LLaVA-1.5-7B},
  author={Koko's Dev},
@@ -209,7 +227,7 @@ If you use these transcoders in your research, please cite:
  publisher={HuggingFace Hub},
  howpublished={\url{https://huggingface.co/KokosDev/llava15-7b-clt}}
}
```

---

@@ -224,4 +242,3 @@ These transcoders are released under the same license as the base model (Apache
- **Base Model**: [LLaVA-1.5-7B](https://huggingface.co/llava-hf/llava-1.5-7b-hf)
- **Methodology**: Inspired by Anthropic's Circuit-Tracer and sparse autoencoder research
- **Training Data**: Flickr30K, instruction-following datasets
 