# app.py — KwabsHug's Space, "Update app.py" (commit 8e41789, verified)
import gradio as gr
import spaces
import torch
# Global SAE state
# Lazily-initialized singletons shared across Gradio event handlers:
gemma_scope_sae = None    # sae_lens SAE instance, set by load_gemma_scope_sae()
gemma_scope_layer = None  # layer number the SAE above was loaded for
model = None              # Gemma-2 chat model, loaded on first analysis call
tokenizer = None          # tokenizer paired with `model`
def load_gemma_scope_sae(layer_num=12):
    """Load the Gemma Scope SAE for one residual-stream layer.

    Stores the loaded SAE and its layer number in the module-level
    ``gemma_scope_sae`` / ``gemma_scope_layer`` globals so other handlers
    can reuse them.

    Args:
        layer_num: Residual-stream layer whose canonical 16k-width SAE
            should be loaded (default 12).

    Returns:
        A human-readable status string. On any failure the string starts
        with "Error" — callers rely on that substring to detect failure.
    """
    global gemma_scope_sae, gemma_scope_layer
    layer_id = f"layer_{layer_num}/width_16k/canonical"
    try:
        # Import inside the try so a missing/broken sae_lens install is
        # reported via the "Error ..." return value instead of raising;
        # analyze_prompt_features only checks for "Error" in the result.
        from sae_lens import SAE
        gemma_scope_sae = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=layer_id,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
        gemma_scope_layer = layer_num
        return f"Loaded SAE for layer {layer_num}: {layer_id}"
    except Exception as e:
        return f"Error loading SAE: {str(e)}"
@spaces.GPU
def analyze_prompt_features(prompt, top_k=10):
    """Run `prompt` through Gemma-2-2b and report the top-k SAE features.

    Lazily loads the chat model/tokenizer and the SAE on first call,
    captures the hidden state for the SAE's layer, encodes it with the
    SAE, and returns a markdown table linking each of the `top_k`
    most-activated features to its Neuronpedia page.

    Returns:
        Markdown string; on SAE-load failure the error string from
        load_gemma_scope_sae() is returned instead.
    """
    global model, tokenizer, gemma_scope_sae
    from transformers import AutoModelForCausalLM, AutoTokenizer
    # Gradio sliders deliver floats; topk needs an int.
    top_k = int(top_k)
    # Load Gemma 2 model if needed (cached in module globals across calls)
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            "stvlynn/Gemma-2-2b-Chinese-it",
            torch_dtype="auto",
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained("stvlynn/Gemma-2-2b-Chinese-it")
    if gemma_scope_sae is None:
        # Always loads the default layer (12): the UI's "SAE Layer" slider
        # is not passed into this function, so it has no effect here.
        load_result = load_gemma_scope_sae()
        if "Error" in load_result:
            return load_result
    # Materialize a tiny CUDA tensor just to obtain the GPU device handle.
    # NOTE(review): .cuda() assumes a GPU is available (ZeroGPU via
    # @spaces.GPU); this line would raise on a CPU-only host.
    zero = torch.Tensor([0]).cuda()
    model.to(zero.device)
    # Get model activations
    inputs = tokenizer(prompt, return_tensors="pt").to(zero.device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    # Run through SAE.
    # The +1 offset presumably maps layer N to hidden_states[N+1] because
    # hidden_states[0] is the embedding output — TODO confirm against the
    # transformers docs. Out-of-range indices are clamped to the last entry.
    layer_idx = gemma_scope_layer + 1 if gemma_scope_layer is not None else 13
    if layer_idx >= len(outputs.hidden_states):
        layer_idx = len(outputs.hidden_states) - 1
    hidden_state = outputs.hidden_states[layer_idx]
    feature_acts = gemma_scope_sae.encode(hidden_state)
    # Get top activated features, averaging activations over the sequence dim.
    top_features = torch.topk(feature_acts.mean(dim=1).squeeze(), top_k)
    # Build results with Neuronpedia links
    layer_num = gemma_scope_layer if gemma_scope_layer is not None else 12
    neuronpedia_base = f"https://www.neuronpedia.org/gemma-2-2b/{layer_num}-gemmascope-res-16k"
    results = ["## Top Activated Features\n"]
    results.append("| Feature | Activation | Neuronpedia Link |")
    results.append("|---------|------------|------------------|")
    for idx, val in zip(top_features.indices, top_features.values):
        feature_id = idx.item()
        activation = val.item()
        link = f"{neuronpedia_base}/{feature_id}"
        results.append(f"| {feature_id:5d} | {activation:8.2f} | [View Feature]({link}) |")
    results.append("")
    results.append("---")
    results.append("**How to use:** Click the links to see what concepts each feature represents.")
    return "\n".join(results)
def fetch_neuronpedia_feature(feature_id, layer=12, width="16k"):
    """Fetch feature data from Neuronpedia API."""
    import requests

    # Gradio Number/Slider widgets hand us floats; the URL path needs ints.
    feature_id, layer = int(feature_id), int(layer)
    api_url = (
        f"https://www.neuronpedia.org/api/feature/gemma-2-2b/"
        f"{layer}-gemmascope-res-{width}/{feature_id}"
    )
    try:
        response = requests.get(api_url, timeout=10)
        status = response.status_code
        if status == 200:
            return format_neuronpedia_feature(response.json(), feature_id, layer, width)
        if status == 404:
            return f"Feature {feature_id} not found at layer {layer}"
        return f"API error: {status}"
    except requests.exceptions.Timeout:
        return "Request timed out - Neuronpedia may be slow"
    except Exception as e:
        return f"Error fetching feature: {str(e)}"
def format_neuronpedia_feature(data, feature_id, layer, width):
    """Format Neuronpedia feature data as markdown."""
    lines = [f"## Feature {feature_id} (Layer {layer}, {width} width)", ""]

    description = data.get("description")
    if description:
        lines += [f"**Description:** {description}", ""]

    explanations = data.get("explanations")
    if explanations:
        explanation = explanations[0].get("description", "")
        if explanation:
            lines += [f"**Auto-interpretation:** {explanation}", ""]

    activations = data.get("activations")
    if activations:
        lines += ["### Top Activating Examples", ""]
        # Show at most five examples, bolding the peak-activation token.
        for i, act in enumerate(activations[:5]):
            tokens = act.get("tokens", [])
            values = act.get("values", [])
            if not tokens:
                continue
            peak = values.index(max(values)) if values else 0
            rendered = "".join(
                f"**{tok}**" if j == peak else tok
                for j, tok in enumerate(tokens)
            )
            lines.append(f"{i+1}. {rendered}")
        lines.append("")

    lines.append("### Feature Stats")
    lines.append(f"- **Neuronpedia ID:** `gemma-2-2b_{layer}-gemmascope-res-{width}_{feature_id}`")
    max_act = data.get("max_activation")
    if max_act:
        lines.append(f"- **Max Activation:** {max_act:.2f}")
    frac = data.get("frac_nonzero")
    if frac:
        lines.append(f"- **Activation Frequency:** {frac*100:.2f}%")
    lines.append("")
    lines.append(f"[View on Neuronpedia](https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id})")
    return "\n".join(lines)
# Build the Gradio interface: two tabs, one for live prompt analysis and
# one for looking up a single feature via the Neuronpedia API.
with gr.Blocks(title="SAE Feature Analyzer") as demo:
    gr.Markdown("# SAE Feature Analyzer")
    gr.Markdown("Analyze neural network features using Sparse Autoencoders (Gemma Scope)")

    with gr.Tab("Analyze Prompt"):
        prompt_box = gr.Textbox(label="Prompt to Analyze", lines=3)
        # NOTE(review): this slider is never wired into
        # analyze_prompt_features below, so the analysis always uses the
        # default SAE layer (12) regardless of its value.
        sae_layer_slider = gr.Slider(0, 25, value=12, step=1, label="SAE Layer")
        topk_select = gr.Slider(5, 50, value=10, step=5, label="Top K Features")
        run_analysis = gr.Button("Analyze Features", variant="primary")
        analysis_md = gr.Markdown(label="Analysis Results")
        run_analysis.click(
            fn=analyze_prompt_features,
            inputs=[prompt_box, topk_select],
            outputs=[analysis_md],
        )

    with gr.Tab("Lookup Feature"):
        feature_number = gr.Number(label="Feature ID", value=0)
        lookup_layer = gr.Slider(0, 25, value=12, step=1, label="Layer")
        run_lookup = gr.Button("Lookup Feature", variant="primary")
        lookup_md = gr.Markdown(label="Feature Details")
        run_lookup.click(
            fn=fetch_neuronpedia_feature,
            inputs=[feature_number, lookup_layer],
            outputs=[lookup_md],
        )

demo.launch()