"""GPT-2 Brain Surgeon: a Streamlit dashboard for SAE feature steering.

Loads GPT-2 Small (via transformer_lens) plus a sparse autoencoder (SAE)
trained on the layer-6 residual stream (via sae_lens), then lets the user
add a chosen SAE decoder direction into the residual stream during
generation to "steer" the model's output toward or away from a concept.
"""
import streamlit as st
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE

# --- PAGE CONFIG ---
st.set_page_config(layout="wide", page_title="GPT-2 Brain Surgeon")
st.title("🧠 GPT-2 Brain Surgeon: Feature Steering")
st.markdown("""
**Mechanistic Interpretability Demo**
This dashboard analyzes the internal "neural activations" of GPT-2 Small and allows you to **steer** the model's output by manually activating specific "concepts" in its brain.
""")

# --- CONFIGURATION (GPT-2 SMALL ONLY) ---
MODEL_NAME = "gpt2-small"
SAE_RELEASE = "gpt2-small-res-jb"
SAE_ID = "blocks.6.hook_resid_pre"  # Middle layer (Layer 6)
HOOK_POINT = "blocks.6.hook_resid_pre"
DEVICE = "cpu"  # Force CPU for free tier stability

# --- CURATED INTERESTING FEATURES ---
# label -> (SAE feature index, human-readable description, default injection strength)
INTERESTING_FEATURES = {
    "Select a feature...": (None, "Normal model behavior", 0),
    "The 'Love' Feature": (1876, "Fires on words like love, passion, heart", 60.0),
    "The 'Space' Feature": (14833, "Fires on planets, orbit, galaxy", 80.0),
    "The 'Legal' Feature": (11874, "Fires on court, law, attorney, judge", 70.0),
    "The 'Code' Feature": (15339, "Fires on brackets, programming keywords", 50.0),
    "The 'Anger' Feature": (7412, "Fires on conflict, hate, arguments", 90.0),
}


# --- LOADER FUNCTIONS (CACHED) ---
# FIX: Removed st.toast from inside this cached function.
@st.cache_resource
def load_resources():
    """Load and cache GPT-2 Small plus its layer-6 residual-stream SAE.

    Deliberately free of Streamlit UI calls: ``st.cache_resource`` runs the
    body only on a cache miss, so any spinner/toast inside it would appear
    inconsistently across reruns. The caller shows the spinner instead.

    Returns:
        tuple: (HookedTransformer model, SAE) both on ``DEVICE``.
    """
    model = HookedTransformer.from_pretrained(MODEL_NAME, device=DEVICE)
    # NOTE(review): sae_lens < 5 returns (sae, cfg_dict, sparsity); newer
    # releases return only the SAE — confirm against the pinned version.
    sae, _, _ = SAE.from_pretrained(release=SAE_RELEASE, sae_id=SAE_ID, device=DEVICE)
    return model, sae


# --- MAIN EXECUTION ---
# UI feedback lives here, OUTSIDE the cached function (see load_resources).
try:
    with st.spinner("Loading GPT-2 Small & SAE (this may take 30s)..."):
        model, sae = load_resources()
    st.success("System Ready: GPT-2 Small + SAE Layer 6 Loaded")
except Exception as e:
    st.error(f"Error loading models: {e}")
    st.stop()

# --- MAIN LAYOUT ---
col1, col2 = st.columns([1, 1.5])

# --- COLUMN 1: CONTROLS ---
with col1:
    st.subheader("1. 🎛️ Control Panel")
    selected_label = st.selectbox(
        "Choose a Concept to Inject:",
        list(INTERESTING_FEATURES.keys()),
    )
    feature_id, description, default_coeff = INTERESTING_FEATURES[selected_label]
    st.info(f"**Description:** {description}")

    if feature_id is not None:
        st.write(f"**Internal Feature ID:** `{feature_id}`")
        steering_coeff = st.slider(
            "Injection Strength",
            min_value=-150.0,
            max_value=150.0,
            value=default_coeff,
            step=5.0,
        )
        st.caption("Positive = Force concept. Negative = Suppress concept.")
    else:
        # No feature selected: steering is a no-op.
        steering_coeff = 0.0

# --- COLUMN 2: EXPERIMENT ---
with col2:
    st.subheader("2. 🧪 Experiment")
    prompt = st.text_area("Enter a Prompt:", "I think that you are", height=100)

    if st.button("Generate Output", type="primary"):

        def steering_hook(resid_pre, hook):
            """Add the selected feature's decoder direction (scaled by the
            user-chosen coefficient) to the residual stream at this hook."""
            if feature_id is not None:
                steering_vector = sae.W_dec[feature_id]
                resid_pre = resid_pre + (steering_coeff * steering_vector)
            return resid_pre

        with st.spinner("Running Inference..."):
            # Gradients are never needed for generation; disabling them
            # reduces memory pressure on the CPU-only free tier.
            with torch.no_grad():
                # 1. Normal Generation (no hooks attached).
                st.markdown("### ⚪ Normal Output")
                model.reset_hooks()
                normal_out = model.generate(
                    prompt, max_new_tokens=25, verbose=False, temperature=0.7
                )
                st.write(normal_out)

                # 2. Steered Generation (hook active only inside this context).
                if feature_id is not None:
                    st.markdown(f"### 🔵 Steered Output ('{selected_label}')")
                    with model.hooks(fwd_hooks=[(HOOK_POINT, steering_hook)]):
                        steered_out = model.generate(
                            prompt, max_new_tokens=25, verbose=False, temperature=0.7
                        )
                    st.success(steered_out)
                else:
                    st.caption("Select a feature to see the steered output.")

st.divider()
st.caption("Built with transformer_lens and sae_lens.")