Spaces:
Sleeping
Sleeping
File size: 6,664 Bytes
import gradio as gr
import spaces
import torch
# Global SAE state
# Lazily-populated module globals shared across the Gradio callbacks below.
gemma_scope_sae = None    # loaded sae_lens SAE, or None until first use
gemma_scope_layer = None  # layer number the currently loaded SAE belongs to
model = None              # Gemma 2 causal LM, loaded on first analysis request
tokenizer = None          # tokenizer matching `model`
def load_gemma_scope_sae(layer_num=12):
    """Load the Gemma Scope SAE for one residual-stream layer.

    The loaded SAE is stored in the module-level ``gemma_scope_sae`` global
    (and its layer in ``gemma_scope_layer``); the return value is a
    human-readable status string for the UI.

    Args:
        layer_num: Transformer layer whose SAE to load. Coerced to ``int``
            so float values from Gradio sliders don't corrupt the SAE id
            (e.g. ``layer_12.0/...``).

    Returns:
        Success or error message as a string (errors are returned, not
        raised, so the Gradio callback can display them).
    """
    global gemma_scope_sae, gemma_scope_layer
    from sae_lens import SAE

    # BUG FIX: Gradio sliders deliver floats; the f-string id needs an int.
    layer_num = int(layer_num)
    layer_id = f"layer_{layer_num}/width_16k/canonical"
    try:
        gemma_scope_sae = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=layer_id,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        gemma_scope_layer = layer_num
        return f"Loaded SAE for layer {layer_num}: {layer_id}"
    except Exception as e:
        # Best-effort: surface the failure to the UI instead of crashing.
        return f"Error loading SAE: {str(e)}"
@spaces.GPU
def analyze_prompt_features(prompt, top_k=10):
    """Analyze which SAE features activate for a given prompt.

    Lazily loads the Gemma 2 model/tokenizer and the Gemma Scope SAE on
    first call, runs the prompt through the model, encodes the selected
    layer's hidden states with the SAE, and reports the top-k features
    (averaged over the sequence) as a markdown table with Neuronpedia links.

    Args:
        prompt: Text to analyze.
        top_k: Number of top features to report (coerced to int for the
            Gradio slider).

    Returns:
        Markdown string with the results, or an error message if the SAE
        failed to load.
    """
    global model, tokenizer, gemma_scope_sae
    from transformers import AutoModelForCausalLM, AutoTokenizer

    top_k = int(top_k)

    # Load Gemma 2 model if needed
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            "stvlynn/Gemma-2-2b-Chinese-it",
            torch_dtype="auto",
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained("stvlynn/Gemma-2-2b-Chinese-it")
    if gemma_scope_sae is None:
        load_result = load_gemma_scope_sae()
        if "Error" in load_result:
            return load_result

    # BUG FIX: the original did `torch.Tensor([0]).cuda()`, which raises on
    # CPU-only hosts even though every other code path guards with
    # torch.cuda.is_available(). Pick the device explicitly instead.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Get model activations
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Run through SAE. hidden_states[0] is the embedding output, so the
    # residual stream after layer N is at index N + 1.
    layer_idx = gemma_scope_layer + 1 if gemma_scope_layer is not None else 13
    if layer_idx >= len(outputs.hidden_states):
        layer_idx = len(outputs.hidden_states) - 1
    hidden_state = outputs.hidden_states[layer_idx]
    feature_acts = gemma_scope_sae.encode(hidden_state)

    # Get top activated features (mean activation across the sequence)
    top_features = torch.topk(feature_acts.mean(dim=1).squeeze(), top_k)

    # Build results with Neuronpedia links
    layer_num = gemma_scope_layer if gemma_scope_layer is not None else 12
    neuronpedia_base = f"https://www.neuronpedia.org/gemma-2-2b/{layer_num}-gemmascope-res-16k"
    results = ["## Top Activated Features\n"]
    results.append("| Feature | Activation | Neuronpedia Link |")
    results.append("|---------|------------|------------------|")
    for idx, val in zip(top_features.indices, top_features.values):
        feature_id = idx.item()
        activation = val.item()
        link = f"{neuronpedia_base}/{feature_id}"
        results.append(f"| {feature_id:5d} | {activation:8.2f} | [View Feature]({link}) |")
    results.append("")
    results.append("---")
    results.append("**How to use:** Click the links to see what concepts each feature represents.")
    return "\n".join(results)
def fetch_neuronpedia_feature(feature_id, layer=12, width="16k"):
    """Look up one SAE feature via the public Neuronpedia API.

    Args:
        feature_id: Numeric feature index within the SAE (coerced to int).
        layer: Transformer layer of the SAE (coerced to int).
        width: SAE dictionary width label, e.g. "16k".

    Returns:
        Markdown describing the feature, or an error message string —
        errors are returned rather than raised so the UI can show them.
    """
    import requests

    feature_id = int(feature_id)
    layer = int(layer)
    api_url = (
        f"https://www.neuronpedia.org/api/feature/gemma-2-2b/"
        f"{layer}-gemmascope-res-{width}/{feature_id}"
    )
    try:
        response = requests.get(api_url, timeout=10)
        status = response.status_code
        if status == 200:
            return format_neuronpedia_feature(
                response.json(), feature_id, layer, width
            )
        if status == 404:
            return f"Feature {feature_id} not found at layer {layer}"
        return f"API error: {status}"
    except requests.exceptions.Timeout:
        return "Request timed out - Neuronpedia may be slow"
    except Exception as e:
        return f"Error fetching feature: {str(e)}"
def format_neuronpedia_feature(data, feature_id, layer, width):
    """Format Neuronpedia feature data as markdown.

    Args:
        data: Parsed JSON dict from the Neuronpedia feature endpoint.
            All keys are optional; ones used here: ``description``,
            ``explanations``, ``activations``, ``max_activation``,
            ``frac_nonzero``.
        feature_id: Numeric feature index within the SAE.
        layer: Transformer layer the SAE belongs to.
        width: SAE dictionary width label, e.g. "16k".

    Returns:
        A markdown string summarizing the feature.
    """
    results = []
    results.append(f"## Feature {feature_id} (Layer {layer}, {width} width)")
    results.append("")

    if data.get("description"):
        results.append(f"**Description:** {data['description']}")
        results.append("")

    # Auto-interp explanation, if Neuronpedia has one for this feature.
    explanations = data.get("explanations") or []
    if explanations:
        explanation = explanations[0].get("description", "")
        if explanation:
            results.append(f"**Auto-interpretation:** {explanation}")
            results.append("")

    activations = data.get("activations") or []
    if activations:
        results.append("### Top Activating Examples")
        results.append("")
        for i, act in enumerate(activations[:5]):
            tokens = act.get("tokens", [])
            values = act.get("values", [])
            if tokens:
                # Bold the token with the highest activation value.
                max_idx = values.index(max(values)) if values else 0
                text_parts = []
                for j, tok in enumerate(tokens):
                    text_parts.append(f"**{tok}**" if j == max_idx else tok)
                text = "".join(text_parts)
                results.append(f"{i+1}. {text}")
        results.append("")

    results.append("### Feature Stats")
    results.append(f"- **Neuronpedia ID:** `gemma-2-2b_{layer}-gemmascope-res-{width}_{feature_id}`")
    # BUG FIX: the original used truthiness (`if data.get(...)`), which
    # silently dropped legitimate 0 / 0.0 stats; check presence instead.
    if data.get("max_activation") is not None:
        results.append(f"- **Max Activation:** {data['max_activation']:.2f}")
    if data.get("frac_nonzero") is not None:
        results.append(f"- **Activation Frequency:** {data['frac_nonzero']*100:.2f}%")
    results.append("")
    results.append(f"[View on Neuronpedia](https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id})")
    return "\n".join(results)
# Build Gradio interface: two tabs — live prompt analysis and a
# Neuronpedia feature lookup.
with gr.Blocks(title="SAE Feature Analyzer") as demo:
    gr.Markdown("# SAE Feature Analyzer")
    gr.Markdown("Analyze neural network features using Sparse Autoencoders (Gemma Scope)")
    with gr.Tab("Analyze Prompt"):
        prompt_input = gr.Textbox(label="Prompt to Analyze", lines=3)
        # NOTE(review): layer_slider is rendered but is NOT in the click
        # inputs below, so moving it has no effect — analyze_prompt_features
        # always uses the default layer. Confirm whether it should be wired
        # through (would require the callback to accept a layer argument).
        layer_slider = gr.Slider(0, 25, value=12, step=1, label="SAE Layer")
        topk_slider = gr.Slider(5, 50, value=10, step=5, label="Top K Features")
        analyze_btn = gr.Button("Analyze Features", variant="primary")
        analysis_output = gr.Markdown(label="Analysis Results")
        analyze_btn.click(
            fn=analyze_prompt_features,
            inputs=[prompt_input, topk_slider],
            outputs=[analysis_output]
        )
    with gr.Tab("Lookup Feature"):
        feature_id_input = gr.Number(label="Feature ID", value=0)
        layer_input = gr.Slider(0, 25, value=12, step=1, label="Layer")
        lookup_btn = gr.Button("Lookup Feature", variant="primary")
        lookup_output = gr.Markdown(label="Feature Details")
        lookup_btn.click(
            fn=fetch_neuronpedia_feature,
            inputs=[feature_id_input, layer_input],
            outputs=[lookup_output]
        )
demo.launch()