Spaces: Running on Zero
| import gradio as gr | |
| import spaces | |
| import torch | |
# Global SAE state shared across Gradio callbacks; all populated lazily on
# first use so the app starts without downloading any weights.
gemma_scope_sae = None    # sae_lens SAE instance, set by load_gemma_scope_sae()
gemma_scope_layer = None  # layer number the loaded SAE was trained on
model = None              # HF causal LM, loaded on first analyze call
tokenizer = None          # tokenizer matching `model`
def load_gemma_scope_sae(layer_num=12):
    """Load the Gemma Scope residual-stream SAE for *layer_num*.

    On success, stores the SAE and its layer number in the module globals
    and returns a success message; on failure, returns an error message.
    The returned string is shown directly in the UI either way.
    """
    global gemma_scope_sae, gemma_scope_layer

    # Imported lazily so the app can start before sae_lens is needed.
    from sae_lens import SAE

    sae_id = f"layer_{layer_num}/width_16k/canonical"
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        gemma_scope_sae = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=sae_id,
            device=target_device,
        )
    except Exception as e:
        return f"Error loading SAE: {str(e)}"
    gemma_scope_layer = layer_num
    return f"Loaded SAE for layer {layer_num}: {sae_id}"
def analyze_prompt_features(prompt, top_k=10):
    """Analyze which SAE features activate for a given prompt.

    Lazily loads the Gemma 2 model/tokenizer and the Gemma Scope SAE on
    first call, runs the prompt through the model, encodes the residual
    stream at the SAE's layer, and returns a markdown table of the top-k
    features with Neuronpedia links. Returns an error/help message string
    on failure or empty input.
    """
    global model, tokenizer, gemma_scope_sae

    # Imported lazily so the app can start before transformers is needed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Guard: an empty prompt would tokenize to nothing useful.
    if not prompt or not prompt.strip():
        return "Please enter a prompt to analyze."

    top_k = max(1, int(top_k))

    # Load Gemma 2 model/tokenizer once and cache in module globals.
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            "stvlynn/Gemma-2-2b-Chinese-it",
            torch_dtype="auto",
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained("stvlynn/Gemma-2-2b-Chinese-it")

    if gemma_scope_sae is None:
        load_result = load_gemma_scope_sae()
        if "Error" in load_result:
            return load_result

    # BUG FIX: the original did `torch.Tensor([0]).cuda()` to pick a device,
    # which raises on any CPU-only machine. Select the device explicitly.
    # NOTE(review): on a ZeroGPU Space this handler likely needs the
    # @spaces.GPU decorator to be granted a GPU — confirm against the
    # Space's configuration before relying on CUDA here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Get model activations for the prompt.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is the embedding output, so the post-block residual
    # of layer N is hidden_states[N + 1]; clamp to the last available state.
    layer_idx = gemma_scope_layer + 1 if gemma_scope_layer is not None else 13
    layer_idx = min(layer_idx, len(outputs.hidden_states) - 1)
    hidden_state = outputs.hidden_states[layer_idx]

    # Run the residual stream through the SAE encoder.
    feature_acts = gemma_scope_sae.encode(hidden_state)

    # Average activations over the sequence dimension, then take top-k.
    mean_acts = feature_acts.mean(dim=1).squeeze()
    top_k = min(top_k, mean_acts.numel())  # guard: k may exceed feature count
    top_features = torch.topk(mean_acts, top_k)

    # Build a markdown table with Neuronpedia links.
    layer_num = gemma_scope_layer if gemma_scope_layer is not None else 12
    neuronpedia_base = f"https://www.neuronpedia.org/gemma-2-2b/{layer_num}-gemmascope-res-16k"
    results = ["## Top Activated Features\n"]
    results.append("| Feature | Activation | Neuronpedia Link |")
    results.append("|---------|------------|------------------|")
    for idx, val in zip(top_features.indices, top_features.values):
        feature_id = idx.item()
        activation = val.item()
        link = f"{neuronpedia_base}/{feature_id}"
        results.append(f"| {feature_id:5d} | {activation:8.2f} | [View Feature]({link}) |")
    results.append("")
    results.append("---")
    results.append("**How to use:** Click the links to see what concepts each feature represents.")
    return "\n".join(results)
def fetch_neuronpedia_feature(feature_id, layer=12, width="16k"):
    """Fetch feature data from Neuronpedia API.

    Returns formatted markdown on success, otherwise a human-readable
    error message string for display in the UI.
    """
    import requests

    feature_id = int(feature_id)
    layer = int(layer)
    api_url = (
        "https://www.neuronpedia.org/api/feature/"
        f"gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id}"
    )
    try:
        response = requests.get(api_url, timeout=10)
        if response.status_code == 200:
            return format_neuronpedia_feature(
                response.json(), feature_id, layer, width
            )
        if response.status_code == 404:
            return f"Feature {feature_id} not found at layer {layer}"
        return f"API error: {response.status_code}"
    except requests.exceptions.Timeout:
        return "Request timed out - Neuronpedia may be slow"
    except Exception as e:
        return f"Error fetching feature: {str(e)}"
def format_neuronpedia_feature(data, feature_id, layer, width):
    """Format Neuronpedia feature data as markdown.

    Parameters
    ----------
    data : dict
        JSON payload from the Neuronpedia feature API. Recognized keys:
        "description", "explanations" (list of dicts with "description"),
        "activations" (list of dicts with "tokens"/"values"),
        "max_activation", "frac_nonzero".
    feature_id, layer, width :
        Identifiers used in the headings, the Neuronpedia ID, and the link.

    Returns
    -------
    str
        Markdown text describing the feature.
    """
    results = [f"## Feature {feature_id} (Layer {layer}, {width} width)", ""]

    if data.get("description"):
        results.append(f"**Description:** {data['description']}")
        results.append("")

    explanations = data.get("explanations") or []
    if explanations:
        explanation = explanations[0].get("description", "")
        if explanation:
            results.append(f"**Auto-interpretation:** {explanation}")
            results.append("")

    activations = data.get("activations") or []
    if activations:
        results.append("### Top Activating Examples")
        results.append("")
        for i, act in enumerate(activations[:5]):
            tokens = act.get("tokens", [])
            values = act.get("values", [])
            if tokens:
                # Bold the token at the position of the highest activation.
                max_idx = values.index(max(values)) if values else 0
                text = "".join(
                    f"**{tok}**" if j == max_idx else tok
                    for j, tok in enumerate(tokens)
                )
                results.append(f"{i+1}. {text}")
        results.append("")

    results.append("### Feature Stats")
    results.append(f"- **Neuronpedia ID:** `gemma-2-2b_{layer}-gemmascope-res-{width}_{feature_id}`")
    # BUG FIX: the original truthiness checks silently dropped legitimate
    # 0 / 0.0 values for these stats; compare against None instead.
    if data.get("max_activation") is not None:
        results.append(f"- **Max Activation:** {data['max_activation']:.2f}")
    if data.get("frac_nonzero") is not None:
        results.append(f"- **Activation Frequency:** {data['frac_nonzero']*100:.2f}%")
    results.append("")
    results.append(f"[View on Neuronpedia](https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-res-{width}/{feature_id})")
    return "\n".join(results)
# ---- Gradio interface ----
with gr.Blocks(title="SAE Feature Analyzer") as demo:
    gr.Markdown("# SAE Feature Analyzer")
    gr.Markdown("Analyze neural network features using Sparse Autoencoders (Gemma Scope)")

    with gr.Tab("Analyze Prompt"):
        prompt_input = gr.Textbox(label="Prompt to Analyze", lines=3)
        # NOTE(review): this slider is not wired into the click handler
        # below, so moving it has no effect — analyze_prompt_features always
        # uses its default layer. Confirm whether it should be an input.
        layer_slider = gr.Slider(0, 25, value=12, step=1, label="SAE Layer")
        topk_slider = gr.Slider(5, 50, value=10, step=5, label="Top K Features")
        analyze_btn = gr.Button("Analyze Features", variant="primary")
        analysis_output = gr.Markdown(label="Analysis Results")
        analyze_btn.click(
            analyze_prompt_features,
            inputs=[prompt_input, topk_slider],
            outputs=[analysis_output],
        )

    with gr.Tab("Lookup Feature"):
        feature_id_input = gr.Number(label="Feature ID", value=0)
        layer_input = gr.Slider(0, 25, value=12, step=1, label="Layer")
        lookup_btn = gr.Button("Lookup Feature", variant="primary")
        lookup_output = gr.Markdown(label="Feature Details")
        lookup_btn.click(
            fetch_neuronpedia_feature,
            inputs=[feature_id_input, layer_input],
            outputs=[lookup_output],
        )

demo.launch()