hbfreed's picture
Update app.py
0c9a7ea verified
raw
history blame
10 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoConfig
from pathlib import Path
import spaces
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import json
from model import SAE, SteerableOlmo2ForCausalLM
# --- Base model and tokenizer -------------------------------------------------
# Everything below is placed on the GPU when CUDA is available, else on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model_name = "allenai/OLMo-2-1124-7B-Instruct"

print("Loading model and tokenizer...")
# Weights are loaded in bfloat16 and then moved to the selected device.
model = SteerableOlmo2ForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Architecture config; num_hidden_layers is used below to choose the steering layer.
model_config = AutoConfig.from_pretrained(model_name)
# --- Sparse autoencoder (SAE) used for steering ------------------------------
print("Loading SAE from Hugging Face Hub...")
# Single source of truth for the SAE repository, so the weights and the config
# can never be fetched from different repos/versions by accident.
SAE_REPO_ID = "open-concept-steering/olmo2-7b-sae-65k-v1"
sae_weights_path = hf_hub_download(repo_id=SAE_REPO_ID, filename="sae_weights.safetensors")
sae_config_path = hf_hub_download(repo_id=SAE_REPO_ID, filename="sae_config.json")

with open(sae_config_path, "r", encoding="utf-8") as f:
    sae_config = json.load(f)

# Build the SAE directly on the target device in bfloat16.
sae = SAE(sae_config["input_size"], sae_config["hidden_size"]).to(device=device, dtype=torch.bfloat16)

# Read the checkpoint on the CPU: load_state_dict copies each tensor into the
# already device-resident bf16 parameters, so no second full copy of the SAE
# weights has to live on the accelerator. Drop the temporary dict afterwards so
# it is not kept alive by this module-level name for the app's lifetime.
sae_weights = load_file(sae_weights_path, device="cpu")
sae.load_state_dict(sae_weights)
del sae_weights

# Set up steering: attach the SAE at the layer just below the midpoint of the
# decoder stack (num_hidden_layers // 2 - 1).
steering_layer = model_config.num_hidden_layers // 2 - 1
model.set_sae_and_layer(sae, steering_layer)
# Steering features configuration.
# Each key is a label shown in the "Steering Type" dropdown. Per entry:
#   "feature": id passed to model.set_steering (None means no steering),
#   "default": base steering value, multiplied by the UI strength slider,
#   "name":    display name (not read by the UI code in this file).
STEERING_FEATURES = {
    "None": {"feature": None, "default": 0, "name": "No Steering"},
    "Batman / Bruce Wayne": {"feature": 758, "default": 9, "name": "🦸 Superhero/Batman"},
    # The emoji here was mis-encoded (UTF-8 bytes decoded as a Greek codepage).
    "Japan": {"feature": 29940, "default": 8, "name": "🗾 Japan"},
    "Baseball": {"feature": 65023, "default": 6, "name": "⚾ Baseball"},
}

# Pre-filled value of the "System Prompt" textbox.
default_system_prompt = "You are OLMo 2, a helpful and harmless AI Assistant built by the Allen Institute for AI."
def _build_chat_inputs(system_prompt, history, message):
    """Render one chat turn with the tokenizer's chat template and tokenize it.

    Args:
        system_prompt: Optional system message; skipped when empty.
        history: Prior turns as a list of dicts with "role" and "content".
        message: The new user message.

    Returns:
        The tokenizer's batch encoding (batch size 1) moved to ``device``.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # Copy only role/content so any extra keys on the Gradio history dicts are
    # not handed to the chat template.
    messages.extend({"role": msg["role"], "content": msg["content"]} for msg in history)
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered template already carries the special tokens it needs; letting
    # the tokenizer add its own again could duplicate them (e.g. a second BOS).
    return tokenizer(
        prompt,
        return_tensors="pt",
        return_attention_mask=True,
        add_special_tokens=False,
    ).to(device)


def _generate_reply(inputs):
    """Sample a reply for ``inputs`` and return only the newly generated text."""
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation for this causal LM; decode just
    # the continuation instead of searching the decoded prompt for role markers.
    prompt_len = inputs.input_ids.shape[-1]
    return tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True).strip()


@spaces.GPU
def generate_responses(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt):
    """Generate both unsteered and steered responses with conversation history.

    Args:
        message: New user message. Empty or whitespace-only input is ignored.
        history_unsteered: Messages-format history of the baseline chat.
        history_steered: Messages-format history of the steered chat.
        steering_type: Key into STEERING_FEATURES. "None", or a key with no
            feature, means no steering; the steered chat then mirrors the
            baseline reply.
        steering_strength: Multiplier applied to the feature's default value.
        system_prompt: Optional system message prepended to both conversations.

    Returns:
        (new unsteered history, new steered history, "") — the empty string
        clears the input textbox. The caller's history lists are not mutated.
    """
    history_unsteered = list(history_unsteered or [])
    history_steered = list(history_steered or [])
    if not message or not message.strip():
        return history_unsteered, history_steered, ""

    # Baseline: make sure no steering is active on the shared model.
    model.clear_steering()
    unsteered_reply = _generate_reply(_build_chat_inputs(system_prompt, history_unsteered, message))

    feature_config = STEERING_FEATURES.get(steering_type)
    feature = feature_config["feature"] if feature_config else None
    if feature is None:
        steered_reply = unsteered_reply
    else:
        steered_inputs = _build_chat_inputs(system_prompt, history_steered, message)
        model.set_steering(feature, feature_config["default"] * steering_strength)
        try:
            steered_reply = _generate_reply(steered_inputs)
        finally:
            # Clear steering even if generation fails, so the shared model is
            # never left steered.
            model.clear_steering()

    history_unsteered += [
        {"role": "user", "content": message},
        {"role": "assistant", "content": unsteered_reply},
    ]
    history_steered += [
        {"role": "user", "content": message},
        {"role": "assistant", "content": steered_reply},
    ]
    return history_unsteered, history_steered, ""
def clear_chats():
    """Reset both conversations.

    Returns:
        A pair of fresh empty lists: (unsteered history, steered history).
    """
    empty_unsteered = []
    empty_steered = []
    return empty_unsteered, empty_steered
# Create Gradio interface
with gr.Blocks(title="OLMo-2 Feature Steering Demo", theme=gr.themes.Default()) as demo:
    gr.Markdown("""
# 🎛️ OLMo-2 Feature Steering Demo
This demo showcases how sparse autoencoders (SAEs) can steer OLMo-2's responses by manipulating specific features.
Have a conversation and see how steering changes the model's behavior across multiple turns!
""")

    # Controls: steering feature, its strength, the shared system prompt.
    with gr.Row():
        with gr.Column(scale=1):
            steering_type = gr.Dropdown(
                choices=list(STEERING_FEATURES.keys()),
                value="None",
                label="Steering Type",
                info="Choose a feature to steer the model's response",
            )
            steering_strength = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Steering Strength",
                info="Adjust the intensity of the steering effect (higher = more steering, very high values may cause gobbledygook)",
                # The dropdown starts at "None", so start hidden; the
                # steering_type.change handler below reveals it on selection.
                visible=False,
            )
            system_prompt = gr.Textbox(
                label="System Prompt",
                value=default_system_prompt,
                lines=3,
            )
            clear_btn = gr.Button("🗑️ Clear Chats", variant="secondary")

    # Side-by-side chats: baseline model and steered model.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🤖 Original OLMo")
            chatbot_unsteered = gr.Chatbot(
                label="Unsteered",
                height=500,
                show_copy_button=True,
                type="messages",
            )
        with gr.Column():
            gr.Markdown("### 🎯 Steered OLMo")
            chatbot_steered = gr.Chatbot(
                label="Steered",
                height=500,
                show_copy_button=True,
                type="messages",
            )

    with gr.Row():
        user_input = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here... (Enter to send, Shift+Enter for new line)",
            lines=2,
            scale=4,
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # Example questions
    gr.Examples(
        examples=[
            "What's an interesting way to spend a weekend?",
            "Tell me about your favorite subject.",
            "What should I do with $5?",
            "How do you approach solving difficult problems?",
            "What's something that makes you excited?",
            "Tell me a story about adventure.",
            "What advice would you give to someone feeling stuck?",
        ],
        inputs=user_input,
        label="Example Questions",
    )

    # Handle submission
    def submit_message(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt):
        """Forward a chat submission to generate_responses unchanged."""
        return generate_responses(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt)

    # Enter in the textbox and the Send button share identical wiring.
    chat_inputs = [user_input, chatbot_unsteered, chatbot_steered, steering_type, steering_strength, system_prompt]
    chat_outputs = [chatbot_unsteered, chatbot_steered, user_input]
    user_input.submit(fn=submit_message, inputs=chat_inputs, outputs=chat_outputs)
    submit_btn.click(fn=submit_message, inputs=chat_inputs, outputs=chat_outputs)

    clear_btn.click(
        fn=clear_chats,
        outputs=[chatbot_unsteered, chatbot_steered],
    )

    # Update slider visibility based on steering selection
    def update_slider_visibility(steering_type):
        """Show the strength slider only when a steering feature is selected."""
        return gr.update(visible=(steering_type != "None"))

    steering_type.change(
        fn=update_slider_visibility,
        inputs=steering_type,
        outputs=steering_strength,
    )
if __name__ == "__main__":
    # Launch the Gradio app (`demo`) only when this file is executed as a script.
    demo.launch()