KwabsHug commited on
Commit
8e41789
·
verified ·
1 Parent(s): 13feaae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -189
app.py CHANGED
@@ -1,195 +1,195 @@
1
- import gradio as gr
2
- import spaces
3
- import torch
4
-
5
- # Global SAE state
6
- gemma_scope_sae = None
7
- gemma_scope_layer = None
8
- model = None
9
- tokenizer = None
10
-
11
def load_gemma_scope_sae(layer_num=12):
    """Load a Gemma Scope SAE for one residual-stream layer.

    Args:
        layer_num: Residual-stream layer whose canonical 16k-width SAE
            should be loaded (default 12).

    Returns:
        A human-readable status string: "Loaded SAE ..." on success, or
        "Error loading SAE: ..." on any failure. Callers check for the
        substring "Error" in the result.

    Side effects:
        Sets the module globals ``gemma_scope_sae`` and
        ``gemma_scope_layer`` on success.
    """
    global gemma_scope_sae, gemma_scope_layer

    layer_id = f"layer_{layer_num}/width_16k/canonical"

    try:
        # BUG FIX: import inside the try so a missing sae_lens install
        # surfaces as the same "Error loading SAE" message the caller
        # already handles, instead of an uncaught ImportError.
        from sae_lens import SAE

        gemma_scope_sae = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=layer_id,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
        gemma_scope_layer = layer_num
        return f"Loaded SAE for layer {layer_num}: {layer_id}"
    except Exception as e:
        return f"Error loading SAE: {str(e)}"
29
-
30
-
31
@spaces.GPU
def analyze_prompt_features(prompt, top_k=10):
    """Analyze which SAE features activate for a given prompt.

    Args:
        prompt: Text to tokenize and run through the Gemma 2 model.
        top_k: Number of strongest features to report (coerced to int).

    Returns:
        Markdown string: a table of the top activated SAE features with
        Neuronpedia links, or an error message if the SAE failed to load.
    """
    global model, tokenizer, gemma_scope_sae

    from transformers import AutoModelForCausalLM, AutoTokenizer

    top_k = int(top_k)

    # Lazily load the Gemma 2 model and tokenizer on first use.
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            "stvlynn/Gemma-2-2b-Chinese-it",
            torch_dtype="auto",
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained("stvlynn/Gemma-2-2b-Chinese-it")

    if gemma_scope_sae is None:
        load_result = load_gemma_scope_sae()
        if "Error" in load_result:
            return load_result

    # BUG FIX: the original allocated torch.Tensor([0]).cuda(), which
    # crashes on CPU-only hosts, and then called model.to(...), which can
    # conflict with device_map="auto" sharding. Use whatever device the
    # (possibly sharded) model already lives on.
    device = next(model.parameters()).device

    # Get model activations for the prompt.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is the embedding output, so layer N's residual
    # stream is at index N + 1; clamp to the last available state.
    layer_idx = gemma_scope_layer + 1 if gemma_scope_layer is not None else 13
    layer_idx = min(layer_idx, len(outputs.hidden_states) - 1)

    hidden_state = outputs.hidden_states[layer_idx]
    feature_acts = gemma_scope_sae.encode(hidden_state)

    # Mean over the sequence dimension, then take the strongest features.
    mean_acts = feature_acts.mean(dim=1).squeeze()
    top_k = min(top_k, mean_acts.numel())  # guard against oversized top_k
    top_features = torch.topk(mean_acts, top_k)

    # Build a markdown table with Neuronpedia links.
    layer_num = gemma_scope_layer if gemma_scope_layer is not None else 12
    neuronpedia_base = f"https://www.neuronpedia.org/gemma-2-2b/{layer_num}-gemmascope-res-16k"

    results = ["## Top Activated Features\n"]
    results.append("| Feature | Activation | Neuronpedia Link |")
    results.append("|---------|------------|------------------|")

    for idx, val in zip(top_features.indices, top_features.values):
        feature_id = idx.item()
        activation = val.item()
        link = f"{neuronpedia_base}/{feature_id}"
        results.append(f"| {feature_id:5d} | {activation:8.2f} | [View Feature]({link}) |")

    results.append("")
    results.append("---")
    results.append("**How to use:** Click the links to see what concepts each feature represents.")

    return "\n".join(results)
92
-
93
-
94
def fetch_neuronpedia_feature(feature_id, layer=12, width="16k"):
    """Query the Neuronpedia feature API and render the result.

    Args:
        feature_id: SAE feature index (coerced to int).
        layer: Residual-stream layer number (coerced to int).
        width: SAE width label used in the Neuronpedia model id.

    Returns:
        Formatted markdown on success, otherwise a short error string.
    """
    import requests

    feature_id = int(feature_id)
    layer = int(layer)

    api_url = f"https://www.neuronpedia.org/api/feature/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id}"

    try:
        response = requests.get(api_url, timeout=10)
        status = response.status_code
        if status == 200:
            payload = response.json()
            return format_neuronpedia_feature(payload, feature_id, layer, width)
        if status == 404:
            return f"Feature {feature_id} not found at layer {layer}"
        return f"API error: {status}"
    except requests.exceptions.Timeout:
        return "Request timed out - Neuronpedia may be slow"
    except Exception as e:
        return f"Error fetching feature: {str(e)}"
116
-
117
-
118
def format_neuronpedia_feature(data, feature_id, layer, width):
    """Format Neuronpedia feature data as markdown.

    Args:
        data: Parsed JSON dict from the Neuronpedia feature endpoint.
        feature_id: Numeric SAE feature index.
        layer: Residual-stream layer number.
        width: SAE width label (e.g. "16k").

    Returns:
        Markdown string with description, auto-interpretation, top
        activating examples, stats, and a link back to Neuronpedia.
    """
    results = []
    results.append(f"## Feature {feature_id} (Layer {layer}, {width} width)")
    results.append("")

    if data.get("description"):
        results.append(f"**Description:** {data['description']}")
        results.append("")

    if data.get("explanations") and len(data["explanations"]) > 0:
        explanation = data["explanations"][0].get("description", "")
        if explanation:
            results.append(f"**Auto-interpretation:** {explanation}")
            results.append("")

    if data.get("activations") and len(data["activations"]) > 0:
        results.append("### Top Activating Examples")
        results.append("")
        for i, act in enumerate(data["activations"][:5]):
            tokens = act.get("tokens", [])
            values = act.get("values", [])
            if tokens:
                # Bold the token with the highest activation value.
                max_idx = values.index(max(values)) if values else 0
                text_parts = []
                for j, tok in enumerate(tokens):
                    if j == max_idx:
                        text_parts.append(f"**{tok}**")
                    else:
                        text_parts.append(tok)
                text = "".join(text_parts)
                results.append(f"{i+1}. {text}")
        results.append("")

    results.append("### Feature Stats")
    results.append(f"- **Neuronpedia ID:** `gemma-2-2b_{layer}-gemmascope-res-{width}_{feature_id}`")
    # BUG FIX: explicit None checks so a legitimate 0.0 value is still
    # reported — the original truthiness test silently dropped zeros.
    if data.get("max_activation") is not None:
        results.append(f"- **Max Activation:** {data['max_activation']:.2f}")
    if data.get("frac_nonzero") is not None:
        results.append(f"- **Activation Frequency:** {data['frac_nonzero']*100:.2f}%")

    results.append("")
    results.append(f"[View on Neuronpedia](https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id})")

    return "\n".join(results)
163
-
164
-
165
# Build Gradio interface: two tabs, one for prompt analysis and one for
# direct feature lookup. Launches a server as a module-level side effect.
with gr.Blocks(title="SAE Feature Analyzer") as demo:
    gr.Markdown("# SAE Feature Analyzer")
    gr.Markdown("Analyze neural network features using Sparse Autoencoders (Gemma Scope)")

    # Tab 1: run a prompt through the model and list top SAE features.
    with gr.Tab("Analyze Prompt"):
        prompt_input = gr.Textbox(label="Prompt to Analyze", lines=3)
        # NOTE(review): layer_slider is defined but never passed to any
        # callback, so changing it has no effect — the SAE layer stays at
        # the default (12). Confirm whether it should be wired into the
        # click() inputs below (analyze_prompt_features would then need a
        # layer parameter).
        layer_slider = gr.Slider(0, 25, value=12, step=1, label="SAE Layer")
        topk_slider = gr.Slider(5, 50, value=10, step=5, label="Top K Features")
        analyze_btn = gr.Button("Analyze Features", variant="primary")
        analysis_output = gr.Markdown(label="Analysis Results")

        analyze_btn.click(
            fn=analyze_prompt_features,
            inputs=[prompt_input, topk_slider],
            outputs=[analysis_output]
        )

    # Tab 2: look up a single feature's metadata via the Neuronpedia API.
    with gr.Tab("Lookup Feature"):
        feature_id_input = gr.Number(label="Feature ID", value=0)
        layer_input = gr.Slider(0, 25, value=12, step=1, label="Layer")
        lookup_btn = gr.Button("Lookup Feature", variant="primary")
        lookup_output = gr.Markdown(label="Feature Details")

        lookup_btn.click(
            fn=fetch_neuronpedia_feature,
            inputs=[feature_id_input, layer_input],
            outputs=[lookup_output]
        )

demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ # Global SAE state
6
+ gemma_scope_sae = None
7
+ gemma_scope_layer = None
8
+ model = None
9
+ tokenizer = None
10
+
11
def load_gemma_scope_sae(layer_num=12):
    """Load a Gemma Scope SAE for one residual-stream layer.

    Args:
        layer_num: Residual-stream layer whose canonical 16k-width SAE
            should be loaded (default 12).

    Returns:
        A human-readable status string: "Loaded SAE ..." on success, or
        "Error loading SAE: ..." on any failure. Callers check for the
        substring "Error" in the result.

    Side effects:
        Sets the module globals ``gemma_scope_sae`` and
        ``gemma_scope_layer`` on success.
    """
    global gemma_scope_sae, gemma_scope_layer

    layer_id = f"layer_{layer_num}/width_16k/canonical"

    try:
        # BUG FIX: import inside the try so a missing sae_lens install
        # surfaces as the same "Error loading SAE" message the caller
        # already handles, instead of an uncaught ImportError.
        from sae_lens import SAE

        gemma_scope_sae = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=layer_id,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
        gemma_scope_layer = layer_num
        return f"Loaded SAE for layer {layer_num}: {layer_id}"
    except Exception as e:
        return f"Error loading SAE: {str(e)}"
29
+
30
+
31
@spaces.GPU
def analyze_prompt_features(prompt, top_k=10):
    """Analyze which SAE features activate for a given prompt.

    Args:
        prompt: Text to tokenize and run through the Gemma 2 model.
        top_k: Number of strongest features to report (coerced to int).

    Returns:
        Markdown string: a table of the top activated SAE features with
        Neuronpedia links, or an error message if the SAE failed to load.
    """
    global model, tokenizer, gemma_scope_sae

    from transformers import AutoModelForCausalLM, AutoTokenizer

    top_k = int(top_k)

    # Lazily load the Gemma 2 model and tokenizer on first use.
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            "stvlynn/Gemma-2-2b-Chinese-it",
            torch_dtype="auto",
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained("stvlynn/Gemma-2-2b-Chinese-it")

    if gemma_scope_sae is None:
        load_result = load_gemma_scope_sae()
        if "Error" in load_result:
            return load_result

    # BUG FIX: the original allocated torch.Tensor([0]).cuda(), which
    # crashes on CPU-only hosts, and then called model.to(...), which can
    # conflict with device_map="auto" sharding. Use whatever device the
    # (possibly sharded) model already lives on.
    device = next(model.parameters()).device

    # Get model activations for the prompt.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is the embedding output, so layer N's residual
    # stream is at index N + 1; clamp to the last available state.
    layer_idx = gemma_scope_layer + 1 if gemma_scope_layer is not None else 13
    layer_idx = min(layer_idx, len(outputs.hidden_states) - 1)

    hidden_state = outputs.hidden_states[layer_idx]
    feature_acts = gemma_scope_sae.encode(hidden_state)

    # Mean over the sequence dimension, then take the strongest features.
    mean_acts = feature_acts.mean(dim=1).squeeze()
    top_k = min(top_k, mean_acts.numel())  # guard against oversized top_k
    top_features = torch.topk(mean_acts, top_k)

    # Build a markdown table with Neuronpedia links.
    layer_num = gemma_scope_layer if gemma_scope_layer is not None else 12
    neuronpedia_base = f"https://www.neuronpedia.org/gemma-2-2b/{layer_num}-gemmascope-res-16k"

    results = ["## Top Activated Features\n"]
    results.append("| Feature | Activation | Neuronpedia Link |")
    results.append("|---------|------------|------------------|")

    for idx, val in zip(top_features.indices, top_features.values):
        feature_id = idx.item()
        activation = val.item()
        link = f"{neuronpedia_base}/{feature_id}"
        results.append(f"| {feature_id:5d} | {activation:8.2f} | [View Feature]({link}) |")

    results.append("")
    results.append("---")
    results.append("**How to use:** Click the links to see what concepts each feature represents.")

    return "\n".join(results)
92
+
93
+
94
def fetch_neuronpedia_feature(feature_id, layer=12, width="16k"):
    """Query the Neuronpedia feature API and render the result.

    Args:
        feature_id: SAE feature index (coerced to int).
        layer: Residual-stream layer number (coerced to int).
        width: SAE width label used in the Neuronpedia model id.

    Returns:
        Formatted markdown on success, otherwise a short error string.
    """
    import requests

    feature_id = int(feature_id)
    layer = int(layer)

    api_url = f"https://www.neuronpedia.org/api/feature/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id}"

    try:
        response = requests.get(api_url, timeout=10)
        status = response.status_code
        if status == 200:
            payload = response.json()
            return format_neuronpedia_feature(payload, feature_id, layer, width)
        if status == 404:
            return f"Feature {feature_id} not found at layer {layer}"
        return f"API error: {status}"
    except requests.exceptions.Timeout:
        return "Request timed out - Neuronpedia may be slow"
    except Exception as e:
        return f"Error fetching feature: {str(e)}"
116
+
117
+
118
def format_neuronpedia_feature(data, feature_id, layer, width):
    """Format Neuronpedia feature data as markdown.

    Args:
        data: Parsed JSON dict from the Neuronpedia feature endpoint.
        feature_id: Numeric SAE feature index.
        layer: Residual-stream layer number.
        width: SAE width label (e.g. "16k").

    Returns:
        Markdown string with description, auto-interpretation, top
        activating examples, stats, and a link back to Neuronpedia.
    """
    results = []
    results.append(f"## Feature {feature_id} (Layer {layer}, {width} width)")
    results.append("")

    if data.get("description"):
        results.append(f"**Description:** {data['description']}")
        results.append("")

    if data.get("explanations") and len(data["explanations"]) > 0:
        explanation = data["explanations"][0].get("description", "")
        if explanation:
            results.append(f"**Auto-interpretation:** {explanation}")
            results.append("")

    if data.get("activations") and len(data["activations"]) > 0:
        results.append("### Top Activating Examples")
        results.append("")
        for i, act in enumerate(data["activations"][:5]):
            tokens = act.get("tokens", [])
            values = act.get("values", [])
            if tokens:
                # Bold the token with the highest activation value.
                max_idx = values.index(max(values)) if values else 0
                text_parts = []
                for j, tok in enumerate(tokens):
                    if j == max_idx:
                        text_parts.append(f"**{tok}**")
                    else:
                        text_parts.append(tok)
                text = "".join(text_parts)
                results.append(f"{i+1}. {text}")
        results.append("")

    results.append("### Feature Stats")
    results.append(f"- **Neuronpedia ID:** `gemma-2-2b_{layer}-gemmascope-res-{width}_{feature_id}`")
    # BUG FIX: explicit None checks so a legitimate 0.0 value is still
    # reported — the original truthiness test silently dropped zeros.
    if data.get("max_activation") is not None:
        results.append(f"- **Max Activation:** {data['max_activation']:.2f}")
    if data.get("frac_nonzero") is not None:
        results.append(f"- **Activation Frequency:** {data['frac_nonzero']*100:.2f}%")

    results.append("")
    results.append(f"[View on Neuronpedia](https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id})")

    return "\n".join(results)
163
+
164
+
165
# Build Gradio interface: two tabs, one for prompt analysis and one for
# direct feature lookup. Launches a server as a module-level side effect.
with gr.Blocks(title="SAE Feature Analyzer") as demo:
    gr.Markdown("# SAE Feature Analyzer")
    gr.Markdown("Analyze neural network features using Sparse Autoencoders (Gemma Scope)")

    # Tab 1: run a prompt through the model and list top SAE features.
    with gr.Tab("Analyze Prompt"):
        prompt_input = gr.Textbox(label="Prompt to Analyze", lines=3)
        # NOTE(review): layer_slider is defined but never passed to any
        # callback, so changing it has no effect — the SAE layer stays at
        # the default (12). Confirm whether it should be wired into the
        # click() inputs below (analyze_prompt_features would then need a
        # layer parameter).
        layer_slider = gr.Slider(0, 25, value=12, step=1, label="SAE Layer")
        topk_slider = gr.Slider(5, 50, value=10, step=5, label="Top K Features")
        analyze_btn = gr.Button("Analyze Features", variant="primary")
        analysis_output = gr.Markdown(label="Analysis Results")

        analyze_btn.click(
            fn=analyze_prompt_features,
            inputs=[prompt_input, topk_slider],
            outputs=[analysis_output]
        )

    # Tab 2: look up a single feature's metadata via the Neuronpedia API.
    with gr.Tab("Lookup Feature"):
        feature_id_input = gr.Number(label="Feature ID", value=0)
        layer_input = gr.Slider(0, 25, value=12, step=1, label="Layer")
        lookup_btn = gr.Button("Lookup Feature", variant="primary")
        lookup_output = gr.Markdown(label="Feature Details")

        lookup_btn.click(
            fn=fetch_neuronpedia_feature,
            inputs=[feature_id_input, layer_input],
            outputs=[lookup_output]
        )

demo.launch()