hbfreed committed
Commit 4e4764b · verified · 1 Parent(s): daea878

Upload 5 files

Files changed (4)
  1. app.py +282 -0
  2. gitattributes +35 -0
  3. model.py +124 -0
  4. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,282 @@
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoConfig
from pathlib import Path
import spaces
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import json
from model import SAE, SteerableOlmo2ForCausalLM

# Initialize model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "allenai/OLMo-2-1124-7B-Instruct"

print("Loading model and tokenizer...")
model = SteerableOlmo2ForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_config = AutoConfig.from_pretrained(model_name)

# Load SAE from Hugging Face Hub
print("Loading SAE from Hugging Face Hub...")

# Download SAE files from the model repository
sae_weights_path = hf_hub_download(
    repo_id="open-concept-steering/olmo2-7b-sae-65k-v1",
    filename="sae_weights.safetensors"
)
sae_config_path = hf_hub_download(
    repo_id="open-concept-steering/olmo2-7b-sae-65k-v1",
    filename="sae_config.json"
)

# Load SAE
sae_weights = load_file(sae_weights_path, device=device)
with open(sae_config_path, "r") as f:
    sae_config = json.load(f)

sae = SAE(sae_config['input_size'], sae_config['hidden_size']).to(device).to(torch.bfloat16)
sae.load_state_dict(sae_weights)

# Set up steering on a middle layer of the model
steering_layer = model_config.num_hidden_layers // 2 - 1
model.set_sae_and_layer(sae, steering_layer)

# Steering features configuration
STEERING_FEATURES = {
    "None": {"feature": None, "default": 0, "name": "No Steering"},
    "batman/bruce wayne": {"feature": 758, "default": 11, "name": "🦸 Superhero/Batman"},
    "japan": {"feature": 29940, "default": 13, "name": "🗾 Japan"},
    "baseball": {"feature": 65023, "default": 6, "name": "⚾ Baseball"}
}

default_system_prompt = "You are OLMo 2, a helpful and harmless AI Assistant built by the Allen Institute for AI."

@spaces.GPU
def generate_responses(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt):
    """Generate both unsteered and steered responses with conversation history"""

    if not message:
        return history_unsteered, history_steered, ""

    # Build messages for unsteered conversation
    messages_unsteered = []
    if system_prompt:
        messages_unsteered.append({"role": "system", "content": system_prompt})

    # Add conversation history
    for msg in history_unsteered:
        messages_unsteered.append({"role": msg["role"], "content": msg["content"]})

    # Add current message
    messages_unsteered.append({"role": "user", "content": message})

    # Format prompt for unsteered
    formatted_prompt_unsteered = tokenizer.apply_chat_template(
        messages_unsteered,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs_unsteered = tokenizer(
        formatted_prompt_unsteered,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True
    ).to(device)

    # Generate unsteered response
    model.clear_steering()
    with torch.inference_mode():
        outputs_unsteered = model.generate(
            input_ids=inputs_unsteered.input_ids,
            attention_mask=inputs_unsteered.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    full_response_unsteered = tokenizer.decode(outputs_unsteered[0], skip_special_tokens=False)
    unsteered_response = full_response_unsteered.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()

    # Update unsteered history
    history_unsteered.append({"role": "user", "content": message})
    history_unsteered.append({"role": "assistant", "content": unsteered_response})

    # Generate steered response
    if steering_type != "None":
        # Build messages for steered conversation
        messages_steered = []
        if system_prompt:
            messages_steered.append({"role": "system", "content": system_prompt})

        # Add conversation history
        for msg in history_steered:
            messages_steered.append({"role": msg["role"], "content": msg["content"]})

        # Add current message
        messages_steered.append({"role": "user", "content": message})

        # Format prompt for steered
        formatted_prompt_steered = tokenizer.apply_chat_template(
            messages_steered,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs_steered = tokenizer(
            formatted_prompt_steered,
            return_tensors="pt",
            padding=True,
            return_attention_mask=True
        ).to(device)

        # Apply steering
        feature_config = STEERING_FEATURES[steering_type]
        steering_value = feature_config["default"] * steering_strength
        model.set_steering(feature_config["feature"], steering_value)

        with torch.inference_mode():
            outputs_steered = model.generate(
                input_ids=inputs_steered.input_ids,
                attention_mask=inputs_steered.attention_mask,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        full_response_steered = tokenizer.decode(outputs_steered[0], skip_special_tokens=False)
        steered_response = full_response_steered.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
        model.clear_steering()
    else:
        steered_response = unsteered_response

    # Update steered history
    history_steered.append({"role": "user", "content": message})
    history_steered.append({"role": "assistant", "content": steered_response})

    return history_unsteered, history_steered, ""

def clear_chats():
    """Clear both chat histories"""
    return [], []

# Create Gradio interface
with gr.Blocks(title="OLMo-2 Feature Steering Demo", theme=gr.themes.Default()) as demo:
    gr.Markdown("""
    # 🎛️ OLMo-2 Feature Steering Demo

    This demo showcases how sparse autoencoders (SAEs) can steer OLMo-2's responses by manipulating specific features.
    Have a conversation and see how steering changes the model's behavior across multiple turns!
    """)

    with gr.Row():
        with gr.Column(scale=1):
            steering_type = gr.Dropdown(
                choices=list(STEERING_FEATURES.keys()),
                value="None",
                label="Steering Type",
                info="Choose a feature to steer the model's response"
            )

            steering_strength = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Steering Strength",
                info="Adjust the intensity of the steering effect (higher = more steering, very high values may cause gobbledygook)"
            )

            system_prompt = gr.Textbox(
                label="System Prompt",
                value=default_system_prompt,
                lines=3
            )

            clear_btn = gr.Button("🗑️ Clear Chats", variant="secondary")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🤖 Original OLMo")
            chatbot_unsteered = gr.Chatbot(
                label="Unsteered",
                height=500,
                show_copy_button=True,
                type="messages"
            )

        with gr.Column():
            gr.Markdown("### 🎯 Steered OLMo")
            chatbot_steered = gr.Chatbot(
                label="Steered",
                height=500,
                show_copy_button=True,
                type="messages"
            )

    with gr.Row():
        user_input = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here... (Enter to send, Shift+Enter for new line)",
            lines=2,
            scale=4
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # Example questions
    gr.Examples(
        examples=[
            "What's an interesting way to spend a weekend?",
            "Tell me about your favorite subject.",
            "What should I do with $5?",
            "How do you approach solving difficult problems?",
            "What's something that makes you excited?",
            "Tell me a story about adventure.",
            "What advice would you give to someone feeling stuck?"
        ],
        inputs=user_input,
        label="Example Questions"
    )

    # Handle submission
    def submit_message(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt):
        return generate_responses(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt)

    # Wire up the interface
    user_input.submit(
        fn=submit_message,
        inputs=[user_input, chatbot_unsteered, chatbot_steered, steering_type, steering_strength, system_prompt],
        outputs=[chatbot_unsteered, chatbot_steered, user_input]
    )

    submit_btn.click(
        fn=submit_message,
        inputs=[user_input, chatbot_unsteered, chatbot_steered, steering_type, steering_strength, system_prompt],
        outputs=[chatbot_unsteered, chatbot_steered, user_input]
    )

    clear_btn.click(
        fn=clear_chats,
        outputs=[chatbot_unsteered, chatbot_steered]
    )

    # Hide the strength slider when no steering feature is selected
    def update_slider_visibility(steering_type):
        return gr.update(visible=(steering_type != "None"))

    steering_type.change(
        fn=update_slider_visibility,
        inputs=steering_type,
        outputs=steering_strength
    )

if __name__ == "__main__":
    demo.launch()
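
For quick experiments outside the UI, the steering API can also be driven directly. Below is a minimal sketch of one steered generation; it assumes the objects built in app.py above (model, tokenizer, device) are in scope and that you are running locally with a GPU or CPU rather than inside the @spaces.GPU context. It reuses feature 29940 ("japan") with its default clamp value of 13 from STEERING_FEATURES; the prompt text is illustrative.

# Minimal sketch: one steered generation without the Gradio UI.
# Assumes `model`, `tokenizer`, and `device` from app.py are in scope.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Describe your ideal vacation."}],  # illustrative prompt
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to(device)

model.set_steering(29940, 13)  # clamp the "japan" feature at its default value
with torch.inference_mode():
    out = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
model.clear_steering()  # always reset so later generations are unsteered

print(tokenizer.decode(out[0], skip_special_tokens=True))
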
gitattributes ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
model.py ADDED
@@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Olmo2ForCausalLM


class SAE(nn.Module):
    def __init__(self, input_size, hidden_size, init_scale=0.1):
        super().__init__()

        # Store dimensions
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Encoder and decoder are plain linear maps with biases
        self.encode = nn.Linear(input_size, hidden_size, bias=True)
        self.decode = nn.Linear(hidden_size, input_size, bias=True)

        with torch.no_grad():
            # Random directions
            decoder_weights = torch.randn(input_size, hidden_size)
            # Normalize columns
            decoder_weights = decoder_weights / torch.linalg.vector_norm(decoder_weights, dim=0, keepdim=True)
            # Scale by random values between 0.05 and 1.0
            scales = torch.rand(hidden_size) * 0.95 + 0.05
            decoder_weights = decoder_weights * scales

            # Tie the encoder to the decoder transpose at init; constrain_weights
            # below renormalizes the decoder columns, so the random scales
            # survive only in the encoder initialization.
            self.decode.weight.data = decoder_weights
            self.encode.weight.data = decoder_weights.T.contiguous()
            self.encode.bias.data.zero_()  # zero the biases in place
            self.decode.bias.data.zero_()

        self.constrain_weights()

    @property
    def device(self):
        """Return the device the model parameters are on"""
        return next(self.parameters()).device

    def constrain_weights(self):
        """Constrain the decoder weights to have unit norm."""
        with torch.no_grad():
            decoder_norm = torch.linalg.vector_norm(self.decode.weight, dim=0, keepdim=True)
            self.decode.weight.data = self.decode.weight.data / decoder_norm

    def forward(self, x):
        features = F.relu(self.encode(x))
        reconstruction = self.decode(features)
        return reconstruction, features

    def get_decoder_norms(self):
        # Returns a 1-D tensor (hidden_size,) on the right device/dtype
        return torch.linalg.vector_norm(self.decode.weight, dim=0)

    @property
    def W_dec(self):
        """Return decoder weights for easier access during analysis"""
        return self.decode.weight

    def compute_loss(self, x, recon, feats, lambda_):
        # Reconstruction term: sum over the feature dim, mean over the batch
        recon_mse = (recon - x).pow(2).sum(-1).mean()

        # Sparsity term: L1 on feature activations, weighted by current decoder-column norms
        sparsity = (feats.abs() * self.get_decoder_norms()).sum(1).mean()

        return recon_mse + lambda_ * sparsity


class SteerableOlmo2ForCausalLM(Olmo2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.steering_layer = None
        self.sae = None
        self.steering_features = {}
        self.steering_hook = None
        self.sae_max = None

    def set_sae_and_layer(self, sae, layer):
        self.sae = sae
        self.steering_layer = layer
        self._register_steering_hook()

    def set_sae_max(self, sae_max):
        self.sae_max = sae_max

    def set_steering(self, feature_idx, value, *, as_multiple_of_max=False):
        if as_multiple_of_max and self.sae_max is not None:
            value = float(value) * float(self.sae_max[feature_idx])
        self.steering_features[feature_idx] = value

    def clear_steering(self):
        self.steering_features = {}

    @torch.no_grad()
    def _steering_hook_fn(self, module, input, output):
        if not self.steering_features or self.sae is None:
            return output

        # Encode the residual stream with the SAE's linear encoder and keep the
        # reconstruction error, so directions the SAE does not capture pass
        # through unchanged.
        hidden_states = output[0]
        feats = self.sae.encode(hidden_states)
        recon = self.sae.decode(feats)
        error = hidden_states - recon

        # Clamp the selected features to their steering values
        feats_steered = feats.clone()
        for idx, clamp_value in self.steering_features.items():
            feats_steered[..., idx] = clamp_value

        recon_steered = self.sae.decode(feats_steered)
        hidden_steered = recon_steered + error

        return (hidden_steered,) + output[1:]

    def _register_steering_hook(self):
        if self.steering_hook is not None:
            self.steering_hook.remove()
            self.steering_hook = None

        if self.steering_layer is not None:
            target_layer = self.model.layers[self.steering_layer]
            self.steering_hook = target_layer.register_forward_hook(self._steering_hook_fn)

    def remove_steering_hook(self):
        if self.steering_hook is not None:
            self.steering_hook.remove()
            self.steering_hook = None
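
Because decode is linear, the hook's clamp-and-restore edit amounts to a rank-1 shift of the residual stream along the clamped feature's decoder column, scaled by (clamp value - original activation). Below is a self-contained toy check of that arithmetic using a small randomly initialized SAE; the sizes are illustrative, not the demo's (the repository name suggests a 65k-feature SAE over the 7B model's hidden size).

# Toy check of the steering arithmetic in _steering_hook_fn.
import torch
from model import SAE

sae = SAE(input_size=16, hidden_size=64)  # tiny random SAE, just for the check
hidden = torch.randn(2, 5, 16)            # (batch, seq, d_model)

feats = sae.encode(hidden)                # linear encoder output, as in the hook
error = hidden - sae.decode(feats)        # reconstruction error to preserve

feats_steered = feats.clone()
feats_steered[..., 3] = 10.0              # clamp feature 3, like set_steering(3, 10.0)
hidden_steered = sae.decode(feats_steered) + error

# The edit equals (clamp - original activation) times decoder column 3.
delta = hidden_steered - hidden
expected = (10.0 - feats[..., 3]).unsqueeze(-1) * sae.decode.weight[:, 3]
assert torch.allclose(delta, expected, atol=1e-5)
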
requirements.txt ADDED
@@ -0,0 +1,6 @@
torch
transformers
gradio
spaces
safetensors
huggingface_hub