File size: 4,141 Bytes
941f520
af3019e
 
 
941f520
af3019e
7c820c5
af3019e
7c820c5
 
 
 
af3019e
7c820c5
3283876
7c820c5
 
 
 
ff8b3c8
7c820c5
 
 
 
 
 
 
 
 
af3019e
 
c196b6f
7c820c5
 
c196b6f
7c820c5
 
 
af3019e
c196b6f
 
af3019e
c196b6f
 
7c820c5
af3019e
7c820c5
af3019e
 
 
c196b6f
af3019e
7c820c5
af3019e
7c820c5
af3019e
7c820c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af3019e
7c820c5
af3019e
7c820c5
af3019e
7c820c5
af3019e
7c820c5
af3019e
 
7c820c5
 
 
af3019e
 
7c820c5
c196b6f
7c820c5
 
 
af3019e
 
c196b6f
7c820c5
 
 
 
 
 
 
af3019e
 
c196b6f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import streamlit as st
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE

# --- PAGE CONFIG ---
# Must run before any other st.* call; "wide" layout gives the two-column UI room.
st.set_page_config(layout="wide", page_title="GPT-2 Brain Surgeon")

# Header and one-paragraph explanation of what the dashboard does.
st.title("🧠 GPT-2 Brain Surgeon: Feature Steering")
st.markdown("""
**Mechanistic Interpretability Demo** This dashboard analyzes the internal "neural activations" of GPT-2 Small and allows you to **steer** the model's output by manually activating specific "concepts" in its brain.
""")

# --- CONFIGURATION (GPT-2 SMALL ONLY) ---
MODEL_NAME = "gpt2-small"
SAE_RELEASE = "gpt2-small-res-jb" 
SAE_ID = "blocks.6.hook_resid_pre" # Middle layer (Layer 6)
# Hook point where the steering vector is injected; matches SAE_ID so the
# SAE decoder directions live in the same activation space as the hook.
HOOK_POINT = "blocks.6.hook_resid_pre"
DEVICE = "cpu" # Force CPU for free tier stability

# --- CURATED INTERESTING FEATURES ---
# Maps a human-readable label -> (sae_feature_index_or_None, description,
# default_steering_coefficient). The placeholder entry uses None so the UI
# can detect "no feature selected". Dict insertion order is the selectbox order.
# NOTE(review): the feature indices (1876, 14833, ...) are hand-curated for
# the gpt2-small-res-jb layer-6 SAE — verify against the SAE release if changed.
INTERESTING_FEATURES = {
    "Select a feature...": (None, "Normal model behavior", 0),
    "The 'Love' Feature": (1876, "Fires on words like love, passion, heart", 60.0),
    "The 'Space' Feature": (14833, "Fires on planets, orbit, galaxy", 80.0),
    "The 'Legal' Feature": (11874, "Fires on court, law, attorney, judge", 70.0),
    "The 'Code' Feature": (15339, "Fires on brackets, programming keywords", 50.0),
    "The 'Anger' Feature": (7412, "Fires on conflict, hate, arguments", 90.0),
}

# --- LOADER FUNCTIONS (CACHED) ---
# FIX: Removed st.toast from inside this cached function
@st.cache_resource
def load_resources():
    """Load GPT-2 Small and its layer-6 residual SAE.

    Cached by st.cache_resource so the heavy downloads happen once per
    server process, not on every Streamlit rerun. UI feedback (spinner,
    toast) is deliberately left to the caller.

    Returns:
        tuple: (HookedTransformer model, SAE autoencoder), both on DEVICE.
    """
    language_model = HookedTransformer.from_pretrained(MODEL_NAME, device=DEVICE)
    # SAE.from_pretrained returns (sae, cfg_dict, sparsity); only the SAE is needed.
    autoencoder, _cfg, _sparsity = SAE.from_pretrained(
        release=SAE_RELEASE, sae_id=SAE_ID, device=DEVICE
    )
    return language_model, autoencoder

# --- MAIN EXECUTION ---
# Move UI feedback here, OUTSIDE the cached function
# Binds module-level `model` and `sae` used by the experiment column below.
# Any load failure (network, bad release id, OOM) is surfaced in the UI and
# halts the script so the rest of the page never runs with missing models.
try:
    with st.spinner("Loading GPT-2 Small & SAE (this may take 30s)..."):
        model, sae = load_resources()
    st.success("System Ready: GPT-2 Small + SAE Layer 6 Loaded")
except Exception as e:
    st.error(f"Error loading models: {e}")
    st.stop()

# --- MAIN LAYOUT ---
# Two columns: controls on the left, the generation experiment (wider) on the right.
col1, col2 = st.columns([1, 1.5]) 

# --- COLUMN 1: CONTROLS ---
# Binds `selected_label`, `feature_id`, and `steering_coeff`, all read by
# the experiment column below on the same Streamlit rerun.
with col1:
    st.subheader("1. 🎛️ Control Panel")
    
    selected_label = st.selectbox(
        "Choose a Concept to Inject:",
        list(INTERESTING_FEATURES.keys())
    )
    
    # Unpack the curated entry: (feature index or None, blurb, default strength).
    feature_id, description, default_coeff = INTERESTING_FEATURES[selected_label]
    
    st.info(f"**Description:** {description}")
    
    if feature_id is not None:
        st.write(f"**Internal Feature ID:** `{feature_id}`")
        # Slider seeded with the feature's curated default; negative values
        # subtract the decoder direction (suppression) per the caption below.
        steering_coeff = st.slider(
            "Injection Strength", 
            min_value=-150.0, 
            max_value=150.0, 
            value=default_coeff,
            step=5.0
        )
        st.caption("Positive = Force concept. Negative = Suppress concept.")
    else:
        # Placeholder entry selected: no steering, keep a defined value so
        # the right-hand column can reference steering_coeff unconditionally.
        steering_coeff = 0.0

# --- COLUMN 2: EXPERIMENT ---
# Runs two generations from the same prompt: a baseline, then (if a feature
# is selected) one with the SAE decoder direction added to the residual stream.
with col2:
    st.subheader("2. 🧪 Experiment")
    
    prompt = st.text_area("Enter a Prompt:", "I think that you are", height=100)
    
    if st.button("Generate Output", type="primary"):
        
        def steering_hook(resid_pre, hook):
            """Add the scaled SAE decoder direction for `feature_id` to the
            residual stream at HOOK_POINT. Closes over `feature_id`,
            `steering_coeff`, and `sae` from the enclosing script."""
            # Defensive: the hook is only attached when feature_id is not
            # None (see the guard below), so this check is a safety net.
            if feature_id is not None:
                steering_vector = sae.W_dec[feature_id]
                # Out-of-place add: returns a new tensor rather than mutating
                # the activation in place.
                resid_pre = resid_pre + (steering_coeff * steering_vector)
            return resid_pre

        with st.spinner("Running Inference..."):
            # 1. Normal Generation
            st.markdown("### ⚪ Normal Output")
            # Clear any hooks left over from a previous rerun before the baseline.
            model.reset_hooks()
            normal_out = model.generate(prompt, max_new_tokens=25, verbose=False, temperature=0.7)
            st.write(normal_out)
            
            # 2. Steered Generation
            if feature_id is not None:
                st.markdown(f"### 🔵 Steered Output ('{selected_label}')")
                # Hook is scoped to this context manager; it detaches on exit.
                with model.hooks(fwd_hooks=[(HOOK_POINT, steering_hook)]):
                    steered_out = model.generate(prompt, max_new_tokens=25, verbose=False, temperature=0.7)
                    st.success(steered_out)
            else:
                st.caption("Select a feature to see the steered output.")

# --- FOOTER ---
st.divider()
st.caption("Built with transformer_lens and sae_lens.")