Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoConfig | |
| from pathlib import Path | |
| import spaces | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| import json | |
| from model import SAE, SteerableOlmo2ForCausalLM | |
# ---------------------------------------------------------------------------
# Model, tokenizer and sparse autoencoder (SAE) setup.
# Runs once at import time; everything below relies on the module-level
# `model`, `tokenizer` and `device` created here.
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "allenai/OLMo-2-1124-7B-Instruct"

print("Loading model and tokenizer...")
model = SteerableOlmo2ForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_config = AutoConfig.from_pretrained(model_name)

# Fetch the SAE weights and config from the Hugging Face Hub. The repo id is
# named once so the two downloads cannot drift apart.
print("Loading SAE from Hugging Face Hub...")
sae_repo_id = "open-concept-steering/olmo2-7b-sae-65k-v1"
sae_weights_path = hf_hub_download(
    repo_id=sae_repo_id,
    filename="sae_weights.safetensors",
)
sae_config_path = hf_hub_download(
    repo_id=sae_repo_id,
    filename="sae_config.json",
)

# Load the SAE tensors straight onto the target device.
sae_weights = load_file(sae_weights_path, device=device)
# JSON is UTF-8 by specification; do not depend on the platform default encoding.
with open(sae_config_path, "r", encoding="utf-8") as f:
    sae_config = json.load(f)

# Build the SAE on the same device and in the same dtype (bfloat16) as the
# language model, then load the downloaded weights into it.
sae = SAE(sae_config["input_size"], sae_config["hidden_size"]).to(
    device=device, dtype=torch.bfloat16
)
sae.load_state_dict(sae_weights)

# Attach the SAE at layer index `num_hidden_layers // 2 - 1`.
steering_layer = model_config.num_hidden_layers // 2 - 1
model.set_sae_and_layer(sae, steering_layer)
# Steering presets, keyed by the value offered in the "Steering Type" dropdown.
#   feature: SAE feature index handed to model.set_steering (None = no steering)
#   default: base steering value; multiplied by the UI strength slider
#   name:    human-readable label for the preset
STEERING_FEATURES = {
    "None": {"feature": None, "default": 0, "name": "No Steering"},
    "batman/bruce wayne": {"feature": 758, "default": 11, "name": "🦸 Superhero/Batman"},
    "japan": {"feature": 29940, "default": 13, "name": "🗾 Japan"},
    "baseball": {"feature": 65023, "default": 6, "name": "⚾ Baseball"},
}

# Initial value of the editable "System Prompt" textbox.
default_system_prompt = "You are OLMo 2, a helpful and harmless AI Assistant built by the Allen Institute for AI."
def _build_chat_messages(history, message, system_prompt):
    """Return the chat-template input: optional system prompt, prior turns, new user turn."""
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # Copy only role/content so any extra keys on the history dicts are not
    # forwarded to the chat template.
    messages.extend(
        {"role": turn["role"], "content": turn["content"]} for turn in history
    )
    messages.append({"role": "user", "content": message})
    return messages


def _generate_reply(messages):
    """Sample one assistant reply for *messages* with the module-level model and tokenizer."""
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    ).to(device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The output sequence starts with the prompt tokens. Decode only what was
    # generated after them, instead of decoding everything and splitting on the
    # literal "<|assistant|>" / "<|endoftext|>" text, which would truncate any
    # reply that happens to contain those strings.
    prompt_length = inputs.input_ids.shape[1]
    new_token_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()


def generate_responses(
    message,
    history_unsteered: list,
    history_steered: list,
    steering_type: str,
    steering_strength: float,
    system_prompt: str,
):
    """Generate the unsteered and the steered reply for one user turn.

    The unsteered reply is always generated with steering cleared. The steered
    reply is generated with the selected SAE feature active; when
    ``steering_type`` is ``"None"`` the unsteered reply is reused instead of
    sampling a second time.

    Args:
        message: The new user message. A falsy value (empty string or None)
            makes the call a no-op.
        history_unsteered: Prior turns of the unsteered chat as dicts with
            ``"role"`` and ``"content"`` keys. Extended in place.
        history_steered: Prior turns of the steered chat, same format.
            Extended in place.
        steering_type: A key of ``STEERING_FEATURES``.
        steering_strength: Multiplier applied to the preset's ``"default"``
            steering value.
        system_prompt: Optional system prompt; skipped when empty.

    Returns:
        A tuple ``(history_unsteered, history_steered, "")``. The trailing
        empty string is the new value for the message input box.

    Raises:
        KeyError: If ``steering_type`` is not a key of ``STEERING_FEATURES``.
    """
    if not message:
        return history_unsteered, history_steered, ""

    user_turn = {"role": "user", "content": message}

    # Baseline reply: make sure no steering is left over from an earlier call.
    model.clear_steering()
    unsteered_response = _generate_reply(
        _build_chat_messages(history_unsteered, message, system_prompt)
    )
    history_unsteered.append(dict(user_turn))
    history_unsteered.append({"role": "assistant", "content": unsteered_response})

    if steering_type != "None":
        steered_messages = _build_chat_messages(history_steered, message, system_prompt)
        feature_config = STEERING_FEATURES[steering_type]
        steering_value = feature_config["default"] * steering_strength
        model.set_steering(feature_config["feature"], steering_value)
        try:
            steered_response = _generate_reply(steered_messages)
        finally:
            # Reset the shared model even if generation fails, so steering
            # never leaks into a later request.
            model.clear_steering()
    else:
        steered_response = unsteered_response

    history_steered.append(dict(user_turn))
    history_steered.append({"role": "assistant", "content": steered_response})

    return history_unsteered, history_steered, ""
def clear_chats():
    """Return a fresh empty history for each of the two chat panes."""
    fresh_unsteered = []
    fresh_steered = []
    return fresh_unsteered, fresh_steered
# ---------------------------------------------------------------------------
# Gradio interface: steering controls on top, the unsteered and steered chats
# side by side, and a shared message box underneath.
# ---------------------------------------------------------------------------
with gr.Blocks(title="OLMo-2 Feature Steering Demo", theme=gr.themes.Default()) as demo:
    # NOTE(review): the emoji in the heading was mis-encoded in the source and
    # could not be recovered unambiguously; the control-knobs glyph is a best
    # guess. Confirm against the deployed app.
    gr.Markdown(
        "# 🎛️ OLMo-2 Feature Steering Demo\n"
        "This demo showcases how sparse autoencoders (SAEs) can steer OLMo-2's responses by manipulating specific features.\n"
        "Have a conversation and see how steering changes the model's behavior across multiple turns!\n"
    )

    with gr.Row():
        with gr.Column(scale=1):
            steering_type = gr.Dropdown(
                choices=list(STEERING_FEATURES.keys()),
                value="None",
                label="Steering Type",
                info="Choose a feature to steer the model's response",
            )
            steering_strength = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Steering Strength",
                info="Adjust the intensity of the steering effect (higher = more steering, very high values may cause gobbledygook)",
            )
            system_prompt = gr.Textbox(
                label="System Prompt",
                value=default_system_prompt,
                lines=3,
            )
            clear_btn = gr.Button("🗑️ Clear Chats", variant="secondary")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🤖 Original OLMo")
            chatbot_unsteered = gr.Chatbot(
                label="Unsteered",
                height=500,
                show_copy_button=True,
                type="messages",
            )
        with gr.Column():
            gr.Markdown("### 🎯 Steered OLMo")
            chatbot_steered = gr.Chatbot(
                label="Steered",
                height=500,
                show_copy_button=True,
                type="messages",
            )

    with gr.Row():
        # NOTE(review): for a Textbox with lines > 1 Gradio appears to submit on
        # Shift+Enter and insert a newline on Enter, the opposite of what this
        # placeholder says. Confirm against the installed Gradio version and
        # adjust the hint (or use lines=1) if so.
        user_input = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here... (Enter to send, Shift+Enter for new line)",
            lines=2,
            scale=4,
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # Example questions that pre-fill the message box.
    gr.Examples(
        examples=[
            "What's an interesting way to spend a weekend?",
            "Tell me about your favorite subject.",
            "What should I do with $5?",
            "How do you approach solving difficult problems?",
            "What's something that makes you excited?",
            "Tell me a story about adventure.",
            "What advice would you give to someone feeling stuck?",
        ],
        inputs=user_input,
        label="Example Questions",
    )

    def submit_message(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt):
        """Event handler for sending a message; delegates to generate_responses."""
        return generate_responses(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt)

    # Both ways of sending a message use the same inputs and outputs.
    generation_inputs = [
        user_input,
        chatbot_unsteered,
        chatbot_steered,
        steering_type,
        steering_strength,
        system_prompt,
    ]
    generation_outputs = [chatbot_unsteered, chatbot_steered, user_input]

    # generate_responses toggles steering on the single shared model, so two
    # generations must never overlap. Each listener gets its own queue by
    # default; a shared concurrency_id makes Enter and the Send button share
    # one queue that processes a single request at a time.
    user_input.submit(
        fn=submit_message,
        inputs=generation_inputs,
        outputs=generation_outputs,
        concurrency_limit=1,
        concurrency_id="generation",
    )
    submit_btn.click(
        fn=submit_message,
        inputs=generation_inputs,
        outputs=generation_outputs,
        concurrency_limit=1,
        concurrency_id="generation",
    )
    clear_btn.click(
        fn=clear_chats,
        outputs=[chatbot_unsteered, chatbot_steered],
    )

    def update_slider_visibility(steering_type):
        """Show the strength slider only when a steering feature is selected."""
        return gr.update(visible=(steering_type != "None"))

    steering_type.change(
        fn=update_slider_visibility,
        inputs=steering_type,
        outputs=steering_strength,
    )

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()