Spaces:
Sleeping
Sleeping
File size: 4,141 Bytes
941f520 af3019e 941f520 af3019e 7c820c5 af3019e 7c820c5 af3019e 7c820c5 3283876 7c820c5 ff8b3c8 7c820c5 af3019e c196b6f 7c820c5 c196b6f 7c820c5 af3019e c196b6f af3019e c196b6f 7c820c5 af3019e 7c820c5 af3019e c196b6f af3019e 7c820c5 af3019e 7c820c5 af3019e 7c820c5 af3019e 7c820c5 af3019e 7c820c5 af3019e 7c820c5 af3019e 7c820c5 af3019e 7c820c5 af3019e 7c820c5 c196b6f 7c820c5 af3019e c196b6f 7c820c5 af3019e c196b6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import streamlit as st
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE
# --- PAGE CONFIG ---
# Wide layout so the two-column (controls / experiment) view has room.
st.set_page_config(layout="wide", page_title="GPT-2 Brain Surgeon")
st.title("🧠 GPT-2 Brain Surgeon: Feature Steering")
# Intro blurb rendered at the top of the page.
st.markdown("""
**Mechanistic Interpretability Demo** This dashboard analyzes the internal "neural activations" of GPT-2 Small and allows you to **steer** the model's output by manually activating specific "concepts" in its brain.
""")
# --- CONFIGURATION (GPT-2 SMALL ONLY) ---
MODEL_NAME = "gpt2-small"
# Pretrained SAE release id for the residual stream of gpt2-small.
SAE_RELEASE = "gpt2-small-res-jb"
SAE_ID = "blocks.6.hook_resid_pre" # Middle layer (Layer 6)
# Must match SAE_ID so the decoder directions align with the hooked activations.
HOOK_POINT = "blocks.6.hook_resid_pre"
DEVICE = "cpu" # Force CPU for free tier stability
# --- CURATED INTERESTING FEATURES ---
# label -> (SAE feature index or None, human-readable description,
#           default injection strength for the slider)
INTERESTING_FEATURES = {
"Select a feature...": (None, "Normal model behavior", 0),
"The 'Love' Feature": (1876, "Fires on words like love, passion, heart", 60.0),
"The 'Space' Feature": (14833, "Fires on planets, orbit, galaxy", 80.0),
"The 'Legal' Feature": (11874, "Fires on court, law, attorney, judge", 70.0),
"The 'Code' Feature": (15339, "Fires on brackets, programming keywords", 50.0),
"The 'Anger' Feature": (7412, "Fires on conflict, hate, arguments", 90.0),
}
# --- LOADER FUNCTIONS (CACHED) ---
# FIX: Removed st.toast from inside this cached function
@st.cache_resource
def load_resources():
    """Load GPT-2 Small and its layer-6 residual-stream SAE (cached across reruns).

    UI feedback (spinner/toast) is the caller's responsibility — Streamlit
    widgets must not execute inside a cached function.

    Returns:
        tuple: (HookedTransformer model, SAE autoencoder), both on DEVICE.
    """
    # The two loads are independent; order does not matter.
    autoencoder, _, _ = SAE.from_pretrained(
        release=SAE_RELEASE, sae_id=SAE_ID, device=DEVICE
    )
    language_model = HookedTransformer.from_pretrained(MODEL_NAME, device=DEVICE)
    return language_model, autoencoder
# --- MAIN EXECUTION ---
# UI feedback lives here, OUTSIDE the cached loader, so no Streamlit
# widget ever runs inside @st.cache_resource.
try:
    with st.spinner("Loading GPT-2 Small & SAE (this may take 30s)..."):
        model, sae = load_resources()
    st.success("System Ready: GPT-2 Small + SAE Layer 6 Loaded")
except Exception as e:
    # Top-level boundary: surface the failure in the UI and halt the script.
    st.error(f"Error loading models: {e}")
    st.stop()
# --- MAIN LAYOUT ---
# Left: controls. Right (wider): the experiment itself.
col1, col2 = st.columns([1, 1.5])

# --- COLUMN 1: CONTROLS ---
with col1:
    st.subheader("1. 🎛️ Control Panel")

    selected_label = st.selectbox(
        "Choose a Concept to Inject:",
        list(INTERESTING_FEATURES.keys()),
    )
    feature_id, description, default_coeff = INTERESTING_FEATURES[selected_label]
    st.info(f"**Description:** {description}")

    if feature_id is None:
        # Placeholder entry selected -> steering disabled.
        steering_coeff = 0.0
    else:
        st.write(f"**Internal Feature ID:** `{feature_id}`")
        steering_coeff = st.slider(
            "Injection Strength",
            min_value=-150.0,
            max_value=150.0,
            value=default_coeff,
            step=5.0,
        )
        st.caption("Positive = Force concept. Negative = Suppress concept.")
# --- COLUMN 2: EXPERIMENT ---
with col2:
    st.subheader("2. 🧪 Experiment")
    prompt = st.text_area("Enter a Prompt:", "I think that you are", height=100)

    if st.button("Generate Output", type="primary"):

        def steering_hook(resid_pre, hook):
            # Add the selected SAE decoder direction, scaled by the slider
            # value, to the residual stream (applied at every token position).
            if feature_id is None:
                return resid_pre
            direction = sae.W_dec[feature_id]
            return resid_pre + steering_coeff * direction

        with st.spinner("Running Inference..."):
            # 1. Baseline generation with no hooks attached.
            st.markdown("### ⚪ Normal Output")
            model.reset_hooks()
            normal_out = model.generate(
                prompt, max_new_tokens=25, verbose=False, temperature=0.7
            )
            st.write(normal_out)

            # 2. Steered generation: same prompt, hook active at layer 6.
            if feature_id is None:
                st.caption("Select a feature to see the steered output.")
            else:
                st.markdown(f"### 🔵 Steered Output ('{selected_label}')")
                with model.hooks(fwd_hooks=[(HOOK_POINT, steering_hook)]):
                    steered_out = model.generate(
                        prompt, max_new_tokens=25, verbose=False, temperature=0.7
                    )
                st.success(steered_out)
# Page footer. NOTE(review): original indentation was lost in extraction —
# rendered at page level here; confirm it was not meant to sit inside col2.
st.divider()
st.caption("Built with transformer_lens and sae_lens.")