# llm-steering-gpt-2 / src/streamlit_app.py
import streamlit as st
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE
# --- PAGE CONFIG ---
# NOTE: st.set_page_config must be the first Streamlit command in the script.
st.set_page_config(layout="wide", page_title="GPT-2 Brain Surgeon")
st.title("🧠 GPT-2 Brain Surgeon: Feature Steering")
st.markdown("""
**Mechanistic Interpretability Demo** This dashboard analyzes the internal "neural activations" of GPT-2 Small and allows you to **steer** the model's output by manually activating specific "concepts" in its brain.
""")
# --- CONFIGURATION (GPT-2 SMALL ONLY) ---
MODEL_NAME = "gpt2-small"           # transformer_lens model identifier
SAE_RELEASE = "gpt2-small-res-jb"   # sae_lens pretrained SAE release name
SAE_ID = "blocks.6.hook_resid_pre"  # Middle layer (Layer 6)
HOOK_POINT = "blocks.6.hook_resid_pre"  # activation point where the steering vector is added
DEVICE = "cpu"  # Force CPU for free tier stability
# --- CURATED INTERESTING FEATURES ---
# Maps a human-readable label to a tuple of
# (SAE feature index or None, description shown in the UI, default slider strength).
# The first entry is a sentinel: feature index None means "no steering applied".
INTERESTING_FEATURES = {
    "Select a feature...": (None, "Normal model behavior", 0),
    "The 'Love' Feature": (1876, "Fires on words like love, passion, heart", 60.0),
    "The 'Space' Feature": (14833, "Fires on planets, orbit, galaxy", 80.0),
    "The 'Legal' Feature": (11874, "Fires on court, law, attorney, judge", 70.0),
    "The 'Code' Feature": (15339, "Fires on brackets, programming keywords", 50.0),
    "The 'Anger' Feature": (7412, "Fires on conflict, hate, arguments", 90.0),
}
# --- LOADER FUNCTIONS (CACHED) ---
# UI feedback (spinner/toast) is intentionally left to the caller: Streamlit
# discourages emitting UI elements from inside a cached function.
@st.cache_resource
def load_resources():
    """Download and cache GPT-2 Small plus its layer-6 residual-stream SAE.

    Decorated with st.cache_resource so the heavy downloads run once per
    server process and the loaded objects are shared across reruns.

    Returns:
        tuple: (HookedTransformer model, SAE) both placed on DEVICE.
    """
    # SAE.from_pretrained returns (sae, cfg_dict, sparsity); only the SAE is kept.
    sae, _, _ = SAE.from_pretrained(release=SAE_RELEASE, sae_id=SAE_ID, device=DEVICE)
    model = HookedTransformer.from_pretrained(MODEL_NAME, device=DEVICE)
    return model, sae
# --- MAIN EXECUTION ---
# Move UI feedback here, OUTSIDE the cached function
# (cached functions should not emit Streamlit UI elements themselves).
try:
    with st.spinner("Loading GPT-2 Small & SAE (this may take 30s)..."):
        model, sae = load_resources()
    st.success("System Ready: GPT-2 Small + SAE Layer 6 Loaded")
except Exception as e:
    # Broad catch is deliberate at this top-level boundary: any load failure
    # (network, disk, library version mismatch) is shown to the user and the
    # app run is halted instead of crashing with a traceback.
    st.error(f"Error loading models: {e}")
    st.stop()
# --- MAIN LAYOUT ---
# Two-column layout: controls on the left, experiment/output on the right.
col1, col2 = st.columns([1, 1.5])
# --- COLUMN 1: CONTROLS ---
with col1:
    st.subheader("1. 🎛️ Control Panel")
    selected_label = st.selectbox(
        "Choose a Concept to Inject:",
        list(INTERESTING_FEATURES.keys())
    )
    # Unpack the curated entry: (SAE feature index or None, blurb, slider default).
    feature_id, description, default_coeff = INTERESTING_FEATURES[selected_label]
    st.info(f"**Description:** {description}")
    if feature_id is not None:
        st.write(f"**Internal Feature ID:** `{feature_id}`")
        # Slider default comes from the curated table; all real features use
        # float defaults, matching the float min/max/step here.
        steering_coeff = st.slider(
            "Injection Strength",
            min_value=-150.0,
            max_value=150.0,
            value=default_coeff,
            step=5.0
        )
        st.caption("Positive = Force concept. Negative = Suppress concept.")
    else:
        # Sentinel "Select a feature..." entry: no steering is applied.
        steering_coeff = 0.0
# --- COLUMN 2: EXPERIMENT ---
with col2:
    st.subheader("2. 🧪 Experiment")
    prompt = st.text_area("Enter a Prompt:", "I think that you are", height=100)
    if st.button("Generate Output", type="primary"):
        def steering_hook(resid_pre, hook):
            # transformer_lens forward hook: the returned tensor replaces the
            # residual-stream activation at HOOK_POINT. Adds the selected SAE
            # decoder direction scaled by the slider value; the (d_model,)
            # vector broadcasts across batch and sequence positions, so every
            # token is steered.
            if feature_id is not None:
                steering_vector = sae.W_dec[feature_id]
                resid_pre = resid_pre + (steering_coeff * steering_vector)
            return resid_pre
        with st.spinner("Running Inference..."):
            # 1. Normal Generation — baseline with no hooks attached.
            st.markdown("### ⚪ Normal Output")
            model.reset_hooks()  # clear any stale hooks from a previous rerun
            normal_out = model.generate(prompt, max_new_tokens=25, verbose=False, temperature=0.7)
            st.write(normal_out)
            # 2. Steered Generation — same prompt, steering hook active.
            if feature_id is not None:
                st.markdown(f"### 🔵 Steered Output ('{selected_label}')")
                # model.hooks() attaches the hook only for the duration of this
                # context manager, so the model is left clean afterwards.
                with model.hooks(fwd_hooks=[(HOOK_POINT, steering_hook)]):
                    steered_out = model.generate(prompt, max_new_tokens=25, verbose=False, temperature=0.7)
                st.success(steered_out)
            else:
                st.caption("Select a feature to see the steered output.")
# --- FOOTER ---
st.divider()
st.caption("Built with transformer_lens and sae_lens.")