benchaffe commited on
Commit
c196b6f
·
verified ·
1 Parent(s): 7c820c5

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +10 -25
src/streamlit_app.py CHANGED
@@ -19,8 +19,6 @@ HOOK_POINT = "blocks.6.hook_resid_pre"
19
  DEVICE = "cpu" # Force CPU for free tier stability
20
 
21
  # --- CURATED INTERESTING FEATURES ---
22
- # Dictionary of known interesting features for this specific SAE
23
- # (Feature ID, Description, Suggested Steering Strength)
24
  INTERESTING_FEATURES = {
25
  "Select a feature...": (None, "Normal model behavior", 0),
26
  "The 'Love' Feature": (1876, "Fires on words like love, passion, heart", 60.0),
@@ -31,32 +29,31 @@ INTERESTING_FEATURES = {
31
  }
32
 
33
  # --- LOADER FUNCTIONS (CACHED) ---
 
34
  @st.cache_resource
35
  def load_resources():
36
- st.toast("Loading GPT-2 Small... (This may take 30s)", icon="⏳")
37
  model = HookedTransformer.from_pretrained(MODEL_NAME, device=DEVICE)
38
-
39
- st.toast("Loading Sparse Autoencoder...", icon="⏳")
40
  sae, _, _ = SAE.from_pretrained(release=SAE_RELEASE, sae_id=SAE_ID, device=DEVICE)
41
-
42
  return model, sae
43
 
44
- # Load resources immediately
 
45
  try:
46
- model, sae = load_resources()
 
47
  st.success("System Ready: GPT-2 Small + SAE Layer 6 Loaded")
48
  except Exception as e:
49
  st.error(f"Error loading models: {e}")
50
  st.stop()
51
 
52
  # --- MAIN LAYOUT ---
53
- col1, col2 = st.columns([1, 1.5]) # Make right column slightly wider for text
54
 
55
  # --- COLUMN 1: CONTROLS ---
56
  with col1:
57
  st.subheader("1. 🎛️ Control Panel")
58
 
59
- # Selection Dropdown
60
  selected_label = st.selectbox(
61
  "Choose a Concept to Inject:",
62
  list(INTERESTING_FEATURES.keys())
@@ -68,8 +65,6 @@ with col1:
68
 
69
  if feature_id is not None:
70
  st.write(f"**Internal Feature ID:** `{feature_id}`")
71
-
72
- # Slider for Strength
73
  steering_coeff = st.slider(
74
  "Injection Strength",
75
  min_value=-150.0,
@@ -89,25 +84,20 @@ with col2:
89
 
90
  if st.button("Generate Output", type="primary"):
91
 
92
- # Define the Steering Hook
93
  def steering_hook(resid_pre, hook):
94
- # resid_pre shape: [batch, pos, d_model]
95
  if feature_id is not None:
96
- # Get the decoder vector for the specific feature
97
  steering_vector = sae.W_dec[feature_id]
98
- # Inject the vector into the stream
99
  resid_pre = resid_pre + (steering_coeff * steering_vector)
100
  return resid_pre
101
 
102
  with st.spinner("Running Inference..."):
103
- # 1. Normal Generation (Control)
104
  st.markdown("### ⚪ Normal Output")
105
- # Clear hooks just in case
106
  model.reset_hooks()
107
  normal_out = model.generate(prompt, max_new_tokens=25, verbose=False, temperature=0.7)
108
  st.write(normal_out)
109
 
110
- # 2. Steered Generation (Test)
111
  if feature_id is not None:
112
  st.markdown(f"### 🔵 Steered Output ('{selected_label}')")
113
  with model.hooks(fwd_hooks=[(HOOK_POINT, steering_hook)]):
@@ -116,10 +106,5 @@ with col2:
116
  else:
117
  st.caption("Select a feature to see the steered output.")
118
 
119
- # --- FOOTER ---
120
  st.divider()
121
- st.markdown("""
122
- **How this works:** We use a Sparse Autoencoder (SAE) to decompose GPT-2's internal activations into interpretable features.
123
- When you select a feature, we mathematically add its vector to the model's residual stream during generation, forcing the model to "think" about that concept.
124
- *Built with `sae_lens` and `transformer_lens`.*
125
- """)
 
19
  DEVICE = "cpu" # Force CPU for free tier stability
20
 
21
  # --- CURATED INTERESTING FEATURES ---
 
 
22
  INTERESTING_FEATURES = {
23
  "Select a feature...": (None, "Normal model behavior", 0),
24
  "The 'Love' Feature": (1876, "Fires on words like love, passion, heart", 60.0),
 
29
  }
30
 
31
  # --- LOADER FUNCTIONS (CACHED) ---
32
+ # FIX: Removed st.toast from inside this cached function
33
  @st.cache_resource
34
  def load_resources():
35
+ # We rely on the caller to show the spinner/toast
36
  model = HookedTransformer.from_pretrained(MODEL_NAME, device=DEVICE)
 
 
37
  sae, _, _ = SAE.from_pretrained(release=SAE_RELEASE, sae_id=SAE_ID, device=DEVICE)
 
38
  return model, sae
39
 
40
+ # --- MAIN EXECUTION ---
41
+ # Move UI feedback here, OUTSIDE the cached function
42
  try:
43
+ with st.spinner("Loading GPT-2 Small & SAE (this may take 30s)..."):
44
+ model, sae = load_resources()
45
  st.success("System Ready: GPT-2 Small + SAE Layer 6 Loaded")
46
  except Exception as e:
47
  st.error(f"Error loading models: {e}")
48
  st.stop()
49
 
50
  # --- MAIN LAYOUT ---
51
+ col1, col2 = st.columns([1, 1.5])
52
 
53
  # --- COLUMN 1: CONTROLS ---
54
  with col1:
55
  st.subheader("1. 🎛️ Control Panel")
56
 
 
57
  selected_label = st.selectbox(
58
  "Choose a Concept to Inject:",
59
  list(INTERESTING_FEATURES.keys())
 
65
 
66
  if feature_id is not None:
67
  st.write(f"**Internal Feature ID:** `{feature_id}`")
 
 
68
  steering_coeff = st.slider(
69
  "Injection Strength",
70
  min_value=-150.0,
 
84
 
85
  if st.button("Generate Output", type="primary"):
86
 
 
87
  def steering_hook(resid_pre, hook):
 
88
  if feature_id is not None:
 
89
  steering_vector = sae.W_dec[feature_id]
 
90
  resid_pre = resid_pre + (steering_coeff * steering_vector)
91
  return resid_pre
92
 
93
  with st.spinner("Running Inference..."):
94
+ # 1. Normal Generation
95
  st.markdown("### ⚪ Normal Output")
 
96
  model.reset_hooks()
97
  normal_out = model.generate(prompt, max_new_tokens=25, verbose=False, temperature=0.7)
98
  st.write(normal_out)
99
 
100
+ # 2. Steered Generation
101
  if feature_id is not None:
102
  st.markdown(f"### 🔵 Steered Output ('{selected_label}')")
103
  with model.hooks(fwd_hooks=[(HOOK_POINT, steering_hook)]):
 
106
  else:
107
  st.caption("Select a feature to see the steered output.")
108
 
 
109
  st.divider()
110
+ st.caption("Built with transformer_lens and sae_lens.")