File size: 6,664 Bytes
8e41789
 
 
13feaae
8e41789
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13feaae
 
8e41789
 
 
 
13feaae
 
8e41789
 
13feaae
8e41789
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import gradio as gr
import spaces
import torch

# Global SAE state, shared across Gradio callbacks so expensive objects are
# loaded once and reused between requests.
gemma_scope_sae = None    # sae_lens SAE instance; None until load_gemma_scope_sae succeeds
gemma_scope_layer = None  # layer number the currently loaded SAE was trained on
model = None              # Gemma 2 causal LM; lazily loaded in analyze_prompt_features
tokenizer = None          # tokenizer paired with `model`; loaded alongside it

def load_gemma_scope_sae(layer_num=12):
    """Load the Gemma Scope residual-stream SAE for one layer.

    On success, stores the SAE and its layer number in the module globals
    and returns a status message. On failure, returns a message beginning
    with "Error" (callers check for that substring).
    """
    global gemma_scope_sae, gemma_scope_layer

    from sae_lens import SAE

    sae_id = f"layer_{layer_num}/width_16k/canonical"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        loaded = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=sae_id,
            device=device,
        )
    except Exception as e:
        return f"Error loading SAE: {str(e)}"

    gemma_scope_sae = loaded
    gemma_scope_layer = layer_num
    return f"Loaded SAE for layer {layer_num}: {sae_id}"


@spaces.GPU
def analyze_prompt_features(prompt, top_k=10):
    """Analyze which SAE features activate for a given prompt.

    Lazily loads the Gemma 2 model/tokenizer and the Gemma Scope SAE, runs
    the prompt through the model with hidden states enabled, encodes the
    selected residual-stream state with the SAE, and returns a markdown
    table of the top-k features with Neuronpedia links. Returns the SAE
    loader's error string if loading fails.
    """
    global model, tokenizer, gemma_scope_sae

    from transformers import AutoModelForCausalLM, AutoTokenizer

    top_k = int(top_k)

    # Lazily load Gemma 2 model on first call.
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            "stvlynn/Gemma-2-2b-Chinese-it",
            torch_dtype="auto",
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained("stvlynn/Gemma-2-2b-Chinese-it")

    if gemma_scope_sae is None:
        load_result = load_gemma_scope_sae()
        if "Error" in load_result:
            return load_result

    # BUGFIX: the original did `torch.Tensor([0]).cuda()`, which raises on
    # CPU-only hosts even though the SAE loader falls back to CPU. Pick the
    # device defensively instead.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Get model activations for the prompt.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # hidden_states[i + 1] is the residual stream after layer i, which is the
    # hook point the "res" SAEs correspond to; clamp to the available range.
    layer_idx = gemma_scope_layer + 1 if gemma_scope_layer is not None else 13
    if layer_idx >= len(outputs.hidden_states):
        layer_idx = len(outputs.hidden_states) - 1

    hidden_state = outputs.hidden_states[layer_idx]
    feature_acts = gemma_scope_sae.encode(hidden_state)

    # Pool over sequence positions, then take the strongest features.
    pooled = feature_acts.mean(dim=1).squeeze()
    # BUGFIX: torch.topk raises if k exceeds the number of features.
    top_k = min(top_k, pooled.numel())
    top_features = torch.topk(pooled, top_k)

    # Build results with Neuronpedia links.
    layer_num = gemma_scope_layer if gemma_scope_layer is not None else 12
    neuronpedia_base = f"https://www.neuronpedia.org/gemma-2-2b/{layer_num}-gemmascope-res-16k"

    results = ["## Top Activated Features\n"]
    results.append("| Feature | Activation | Neuronpedia Link |")
    results.append("|---------|------------|------------------|")

    for idx, val in zip(top_features.indices, top_features.values):
        feature_id = idx.item()
        activation = val.item()
        link = f"{neuronpedia_base}/{feature_id}"
        results.append(f"| {feature_id:5d} | {activation:8.2f} | [View Feature]({link}) |")

    results.append("")
    results.append("---")
    results.append("**How to use:** Click the links to see what concepts each feature represents.")

    return "\n".join(results)


def fetch_neuronpedia_feature(feature_id, layer=12, width="16k"):
    """Fetch feature data from Neuronpedia API."""
    import requests

    feature_id = int(feature_id)
    layer = int(layer)

    api_url = (
        "https://www.neuronpedia.org/api/feature/"
        f"gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id}"
    )

    # The broad try also covers JSON decoding and formatting, so any
    # malformed response is reported rather than raised to the UI.
    try:
        response = requests.get(api_url, timeout=10)
        status = response.status_code
        if status == 200:
            payload = response.json()
            return format_neuronpedia_feature(payload, feature_id, layer, width)
        if status == 404:
            return f"Feature {feature_id} not found at layer {layer}"
        return f"API error: {status}"
    except requests.exceptions.Timeout:
        return "Request timed out - Neuronpedia may be slow"
    except Exception as e:
        return f"Error fetching feature: {str(e)}"


def format_neuronpedia_feature(data, feature_id, layer, width):
    """Format Neuronpedia feature data as markdown.

    Args:
        data: parsed JSON dict from the Neuronpedia feature endpoint.
        feature_id: numeric feature index within the SAE.
        layer: model layer the SAE was trained on.
        width: SAE width label (e.g. "16k") used in Neuronpedia IDs/URLs.

    Returns:
        A markdown string with description, auto-interpretation, top
        activating examples, and summary stats.
    """
    results = []
    results.append(f"## Feature {feature_id} (Layer {layer}, {width} width)")
    results.append("")

    if data.get("description"):
        results.append(f"**Description:** {data['description']}")
        results.append("")

    # First explanation entry is treated as the auto-interpretation, if any.
    if data.get("explanations") and len(data["explanations"]) > 0:
        explanation = data["explanations"][0].get("description", "")
        if explanation:
            results.append(f"**Auto-interpretation:** {explanation}")
            results.append("")

    if data.get("activations") and len(data["activations"]) > 0:
        results.append("### Top Activating Examples")
        results.append("")
        for i, act in enumerate(data["activations"][:5]):
            tokens = act.get("tokens", [])
            values = act.get("values", [])
            if tokens:
                # Bold the token with the highest activation value.
                max_idx = values.index(max(values)) if values else 0
                text_parts = []
                for j, tok in enumerate(tokens):
                    if j == max_idx:
                        text_parts.append(f"**{tok}**")
                    else:
                        text_parts.append(tok)
                text = "".join(text_parts)
                results.append(f"{i+1}. {text}")
        results.append("")

    results.append("### Feature Stats")
    results.append(f"- **Neuronpedia ID:** `gemma-2-2b_{layer}-gemmascope-res-{width}_{feature_id}`")
    # BUGFIX: truthiness checks (`if data.get("max_activation"):`) silently
    # dropped legitimate 0.0 values; test for presence explicitly.
    if data.get("max_activation") is not None:
        results.append(f"- **Max Activation:** {data['max_activation']:.2f}")
    if data.get("frac_nonzero") is not None:
        results.append(f"- **Activation Frequency:** {data['frac_nonzero']*100:.2f}%")

    results.append("")
    results.append(f"[View on Neuronpedia](https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id})")

    return "\n".join(results)


# Build Gradio interface: two tabs — live prompt analysis against the local
# model + SAE, and a Neuronpedia feature lookup over HTTP.
with gr.Blocks(title="SAE Feature Analyzer") as demo:
    gr.Markdown("# SAE Feature Analyzer")
    gr.Markdown("Analyze neural network features using Sparse Autoencoders (Gemma Scope)")

    with gr.Tab("Analyze Prompt"):
        prompt_input = gr.Textbox(label="Prompt to Analyze", lines=3)
        layer_slider = gr.Slider(0, 25, value=12, step=1, label="SAE Layer")
        topk_slider = gr.Slider(5, 50, value=10, step=5, label="Top K Features")
        analyze_btn = gr.Button("Analyze Features", variant="primary")
        analysis_output = gr.Markdown(label="Analysis Results")

        # BUGFIX: layer_slider was created but never wired to anything, so
        # the "SAE Layer" control had no effect. Re-load the SAE when it
        # changes and surface the loader's status message in the output.
        layer_slider.change(
            fn=load_gemma_scope_sae,
            inputs=[layer_slider],
            outputs=[analysis_output]
        )

        analyze_btn.click(
            fn=analyze_prompt_features,
            inputs=[prompt_input, topk_slider],
            outputs=[analysis_output]
        )

    with gr.Tab("Lookup Feature"):
        feature_id_input = gr.Number(label="Feature ID", value=0)
        layer_input = gr.Slider(0, 25, value=12, step=1, label="Layer")
        lookup_btn = gr.Button("Lookup Feature", variant="primary")
        lookup_output = gr.Markdown(label="Feature Details")

        lookup_btn.click(
            fn=fetch_neuronpedia_feature,
            inputs=[feature_id_input, layer_input],
            outputs=[lookup_output]
        )

demo.launch()