"""Gradio app for exploring Gemma Scope SAE features on Gemma-2-2b.

Two tabs:
  * "Analyze Prompt" -- run a prompt through the model, encode the chosen
    layer's residual stream with a Gemma Scope SAE, and list the top-k
    activated features with Neuronpedia links.
  * "Lookup Feature" -- fetch one feature's metadata from the Neuronpedia
    API and render it as markdown.
"""

import gradio as gr
import spaces
import torch

# Lazily-initialized global state: loaded on first request so the app starts
# quickly and only pays for model/SAE downloads when actually needed.
gemma_scope_sae = None    # sae_lens SAE instance, or None until loaded
gemma_scope_layer = None  # layer number the loaded SAE belongs to
model = None              # transformers causal LM, or None until loaded
tokenizer = None          # matching tokenizer, or None until loaded


def load_gemma_scope_sae(layer_num=12):
    """Load the canonical 16k-width Gemma Scope residual SAE for a layer.

    Updates the module-level ``gemma_scope_sae`` / ``gemma_scope_layer``
    globals as a side effect.

    Args:
        layer_num: Residual-stream layer whose SAE should be loaded.

    Returns:
        A human-readable status string; errors are reported in the string
        rather than raised (the caller checks for the "Error" prefix).
    """
    global gemma_scope_sae, gemma_scope_layer
    from sae_lens import SAE

    layer_id = f"layer_{layer_num}/width_16k/canonical"
    try:
        # NOTE(review): some sae_lens versions return a (sae, cfg, sparsity)
        # tuple from SAE.from_pretrained rather than the SAE alone --
        # confirm against the pinned sae_lens version.
        gemma_scope_sae = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=layer_id,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        gemma_scope_layer = layer_num
        return f"Loaded SAE for layer {layer_num}: {layer_id}"
    except Exception as e:
        return f"Error loading SAE: {str(e)}"


@spaces.GPU
def analyze_prompt_features(prompt, top_k=10, layer=12):
    """Return a markdown table of the SAE features most active on ``prompt``.

    Args:
        prompt: Text to run through the model.
        top_k: Number of top features to report (clamped to the feature count).
        layer: SAE layer to analyze. The SAE is (re)loaded when it does not
            match the currently loaded layer. (New parameter with a default
            matching the old hard-coded behavior, so existing callers are
            unaffected.)

    Returns:
        Markdown string with one row per feature and a Neuronpedia link,
        or an error message from SAE loading.
    """
    global model, tokenizer, gemma_scope_sae
    from transformers import AutoModelForCausalLM, AutoTokenizer

    top_k = int(top_k)
    layer = int(layer)

    # Load the Gemma 2 model and tokenizer on first use.
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            "stvlynn/Gemma-2-2b-Chinese-it",
            torch_dtype="auto",
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained("stvlynn/Gemma-2-2b-Chinese-it")

    # (Re)load the SAE when missing or when a different layer is requested.
    # BUGFIX: the requested layer was previously ignored -- the UI slider
    # existed but the SAE was always loaded for the default layer 12.
    if gemma_scope_sae is None or gemma_scope_layer != layer:
        load_result = load_gemma_scope_sae(layer)
        if "Error" in load_result:
            return load_result

    # BUGFIX: the old code did torch.Tensor([0]).cuda() + model.to(...),
    # which crashes on CPU-only hosts and fights device_map="auto".
    # Use the device the model was actually placed on.
    device = model.device

    # Get model activations for the prompt.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is the embedding output, so the residual stream after
    # layer L lives at index L + 1; clamp to the last available state.
    layer_idx = gemma_scope_layer + 1 if gemma_scope_layer is not None else 13
    if layer_idx >= len(outputs.hidden_states):
        layer_idx = len(outputs.hidden_states) - 1
    hidden_state = outputs.hidden_states[layer_idx]

    # Encode through the SAE and rank features by mean activation over the
    # sequence dimension.
    feature_acts = gemma_scope_sae.encode(hidden_state)
    mean_acts = feature_acts.mean(dim=1).squeeze()
    top_k = min(top_k, mean_acts.numel())  # guard: top_k > feature count
    top_features = torch.topk(mean_acts, top_k)

    # Build the markdown table with Neuronpedia links.
    layer_num = gemma_scope_layer if gemma_scope_layer is not None else 12
    neuronpedia_base = f"https://www.neuronpedia.org/gemma-2-2b/{layer_num}-gemmascope-res-16k"
    results = ["## Top Activated Features\n"]
    results.append("| Feature | Activation | Neuronpedia Link |")
    results.append("|---------|------------|------------------|")
    for idx, val in zip(top_features.indices, top_features.values):
        feature_id = idx.item()
        activation = val.item()
        link = f"{neuronpedia_base}/{feature_id}"
        results.append(f"| {feature_id:5d} | {activation:8.2f} | [View Feature]({link}) |")

    results.append("")
    results.append("---")
    results.append("**How to use:** Click the links to see what concepts each feature represents.")

    return "\n".join(results)


def fetch_neuronpedia_feature(feature_id, layer=12, width="16k"):
    """Fetch a single feature's data from the Neuronpedia API.

    Args:
        feature_id: Numeric SAE feature index.
        layer: Residual-stream layer of the SAE.
        width: SAE dictionary width label (e.g. "16k").

    Returns:
        Formatted markdown on success, otherwise an error description
        (never raises; network failures are reported as strings).
    """
    import requests

    feature_id = int(feature_id)
    layer = int(layer)
    api_url = f"https://www.neuronpedia.org/api/feature/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id}"

    try:
        response = requests.get(api_url, timeout=10)
        if response.status_code == 200:
            data = response.json()
            return format_neuronpedia_feature(data, feature_id, layer, width)
        elif response.status_code == 404:
            return f"Feature {feature_id} not found at layer {layer}"
        else:
            return f"API error: {response.status_code}"
    except requests.exceptions.Timeout:
        return "Request timed out - Neuronpedia may be slow"
    except Exception as e:
        return f"Error fetching feature: {str(e)}"


def format_neuronpedia_feature(data, feature_id, layer, width):
    """Format Neuronpedia feature data as markdown.

    Args:
        data: Parsed JSON dict from the Neuronpedia feature endpoint
            (keys used: "description", "explanations", "activations",
            "max_activation", "frac_nonzero" -- all optional).
        feature_id: Feature index, used for headings and links.
        layer: SAE layer, used for headings and links.
        width: SAE width label, used for headings and links.

    Returns:
        A markdown string summarizing the feature.
    """
    results = []
    results.append(f"## Feature {feature_id} (Layer {layer}, {width} width)")
    results.append("")

    if data.get("description"):
        results.append(f"**Description:** {data['description']}")
        results.append("")

    # First auto-interpretation explanation, when present.
    if data.get("explanations") and len(data["explanations"]) > 0:
        explanation = data["explanations"][0].get("description", "")
        if explanation:
            results.append(f"**Auto-interpretation:** {explanation}")
            results.append("")

    # Show up to 5 top-activating examples, bolding the max-activation token.
    if data.get("activations") and len(data["activations"]) > 0:
        results.append("### Top Activating Examples")
        results.append("")
        for i, act in enumerate(data["activations"][:5]):
            tokens = act.get("tokens", [])
            values = act.get("values", [])
            if tokens:
                max_idx = values.index(max(values)) if values else 0
                text_parts = []
                for j, tok in enumerate(tokens):
                    if j == max_idx:
                        text_parts.append(f"**{tok}**")
                    else:
                        text_parts.append(tok)
                text = "".join(text_parts)
                results.append(f"{i+1}. {text}")
        results.append("")

    results.append("### Feature Stats")
    results.append(f"- **Neuronpedia ID:** `gemma-2-2b_{layer}-gemmascope-res-{width}_{feature_id}`")
    if data.get("max_activation"):
        results.append(f"- **Max Activation:** {data['max_activation']:.2f}")
    if data.get("frac_nonzero"):
        results.append(f"- **Activation Frequency:** {data['frac_nonzero']*100:.2f}%")

    results.append("")
    results.append(f"[View on Neuronpedia](https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id})")

    return "\n".join(results)


# Build Gradio interface
with gr.Blocks(title="SAE Feature Analyzer") as demo:
    gr.Markdown("# SAE Feature Analyzer")
    gr.Markdown("Analyze neural network features using Sparse Autoencoders (Gemma Scope)")

    with gr.Tab("Analyze Prompt"):
        prompt_input = gr.Textbox(label="Prompt to Analyze", lines=3)
        layer_slider = gr.Slider(0, 25, value=12, step=1, label="SAE Layer")
        topk_slider = gr.Slider(5, 50, value=10, step=5, label="Top K Features")
        analyze_btn = gr.Button("Analyze Features", variant="primary")
        analysis_output = gr.Markdown(label="Analysis Results")

        # BUGFIX: layer_slider was previously not wired into the click
        # handler, so changing the layer had no effect.
        analyze_btn.click(
            fn=analyze_prompt_features,
            inputs=[prompt_input, topk_slider, layer_slider],
            outputs=[analysis_output]
        )

    with gr.Tab("Lookup Feature"):
        feature_id_input = gr.Number(label="Feature ID", value=0)
        layer_input = gr.Slider(0, 25, value=12, step=1, label="Layer")
        lookup_btn = gr.Button("Lookup Feature", variant="primary")
        lookup_output = gr.Markdown(label="Feature Details")

        lookup_btn.click(
            fn=fetch_neuronpedia_feature,
            inputs=[feature_id_input, layer_input],
            outputs=[lookup_output]
        )

demo.launch()