import streamlit as st
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE
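# transformer_lens provides a hooked GPT-2 implementation whose internal
# activations can be read and modified mid-forward-pass; sae_lens loads
# pretrained sparse autoencoders (SAEs) that decompose those activations
# into interpretable features.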
# --- PAGE CONFIG ---
st.set_page_config(layout="wide", page_title="GPT-2 Brain Surgeon")
st.title("🧠 GPT-2 Brain Surgeon: Feature Steering")
st.markdown("""
**Mechanistic Interpretability Demo.** This dashboard analyzes the internal
"neural activations" of GPT-2 Small and lets you **steer** the model's output
by manually activating specific "concepts" in its brain.
""")
# --- CONFIGURATION (GPT-2 SMALL ONLY) ---
MODEL_NAME = "gpt2-small"
SAE_RELEASE = "gpt2-small-res-jb"
SAE_ID = "blocks.6.hook_resid_pre"  # Middle layer (layer 6)
HOOK_POINT = "blocks.6.hook_resid_pre"
DEVICE = "cpu"  # Force CPU for free-tier stability
# --- CURATED INTERESTING FEATURES ---
INTERESTING_FEATURES = {
    "Select a feature...": (None, "Normal model behavior", 0.0),
    "The 'Love' Feature": (1876, "Fires on words like love, passion, heart", 60.0),
    "The 'Space' Feature": (14833, "Fires on planets, orbit, galaxy", 80.0),
    "The 'Legal' Feature": (11874, "Fires on court, law, attorney, judge", 70.0),
    "The 'Code' Feature": (15339, "Fires on brackets, programming keywords", 50.0),
    "The 'Anger' Feature": (7412, "Fires on conflict, hate, arguments", 90.0),
}
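# These feature indices are specific to the "gpt2-small-res-jb" SAE weights
# and will not transfer to other SAEs. Per-feature dashboards for this release
# can be browsed on Neuronpedia if you want to curate your own list.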
# --- LOADER FUNCTIONS (CACHED) ---
# FIX: Removed st.toast from inside this cached function
@st.cache_resource
def load_resources():
    # We rely on the caller to show the spinner/toast
    model = HookedTransformer.from_pretrained(MODEL_NAME, device=DEVICE)
    sae, _, _ = SAE.from_pretrained(release=SAE_RELEASE, sae_id=SAE_ID, device=DEVICE)
    return model, sae
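# A minimal sketch (not wired into the UI) of how you might check which SAE
# features fire on a prompt, assuming the standard sae_lens `encode` API:
#
#     _, cache = model.run_with_cache("I love you")
#     feature_acts = sae.encode(cache[HOOK_POINT])  # [batch, seq, d_sae]
#     print(feature_acts[0, -1].topk(5))            # top features at last token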
# --- MAIN EXECUTION ---
# Move UI feedback here, OUTSIDE the cached function
try:
    with st.spinner("Loading GPT-2 Small & SAE (this may take 30s)..."):
        model, sae = load_resources()
    st.success("System Ready: GPT-2 Small + SAE Layer 6 Loaded")
except Exception as e:
    st.error(f"Error loading models: {e}")
    st.stop()
# --- MAIN LAYOUT ---
col1, col2 = st.columns([1, 1.5])
# --- COLUMN 1: CONTROLS ---
with col1:
    st.subheader("1. 🎛️ Control Panel")
    selected_label = st.selectbox(
        "Choose a Concept to Inject:",
        list(INTERESTING_FEATURES.keys()),
    )
    feature_id, description, default_coeff = INTERESTING_FEATURES[selected_label]
    st.info(f"**Description:** {description}")

    if feature_id is not None:
        st.write(f"**Internal Feature ID:** `{feature_id}`")
        steering_coeff = st.slider(
            "Injection Strength",
            min_value=-150.0,
            max_value=150.0,
            value=default_coeff,
            step=5.0,
        )
        st.caption("Positive = Force concept. Negative = Suppress concept.")
    else:
        steering_coeff = 0.0
# --- COLUMN 2: EXPERIMENT ---
with col2:
    st.subheader("2. 🧪 Experiment")
    prompt = st.text_area("Enter a Prompt:", "I think that you are", height=100)

    if st.button("Generate Output", type="primary"):
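        # The hook adds a scaled copy of the chosen feature's SAE decoder
        # direction to the residual stream at every token position:
        # resid_pre has shape [batch, seq, d_model] and sae.W_dec[feature_id]
        # has shape [d_model], so the addition broadcasts over batch and seq.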
        def steering_hook(resid_pre, hook):
            if feature_id is not None:
                steering_vector = sae.W_dec[feature_id]
                resid_pre = resid_pre + (steering_coeff * steering_vector)
            return resid_pre
        with st.spinner("Running Inference..."):
            # 1. Normal Generation
            st.markdown("### ⚪ Normal Output")
            model.reset_hooks()
            normal_out = model.generate(prompt, max_new_tokens=25, verbose=False, temperature=0.7)
            st.write(normal_out)

            # 2. Steered Generation
            if feature_id is not None:
                st.markdown(f"### 🔵 Steered Output ('{selected_label}')")
                with model.hooks(fwd_hooks=[(HOOK_POINT, steering_hook)]):
                    steered_out = model.generate(prompt, max_new_tokens=25, verbose=False, temperature=0.7)
                st.success(steered_out)
            else:
                st.caption("Select a feature to see the steered output.")

st.divider()
st.caption("Built with transformer_lens and sae_lens.")