Spaces:
Sleeping
Sleeping
| """Feature browsing, search, and description loading. | |
| Provides utilities for exploring SAE features: loading pre-computed | |
| descriptions from Neuronpedia, searching features semantically, and | |
| finding features that activate on given text. | |
| """ | |
| import os | |
| import io | |
| import json | |
| import gzip | |
| from pathlib import Path | |
| from typing import Optional | |
| from dataclasses import dataclass | |
| import requests | |
| import torch | |
| import numpy as np | |
| from dotenv import load_dotenv, find_dotenv | |
| from config import MODEL_CONFIG | |
| load_dotenv(find_dotenv(), override=True) | |
| DESCRIPTIONS_CACHE_DIR = Path(__file__).parent / "cache" | |
| NEURONPEDIA_S3_BASE = "https://neuronpedia-datasets.s3.us-east-1.amazonaws.com/v1" | |
| NEURONPEDIA_API_BASE = "https://www.neuronpedia.org" | |
@dataclass
class FeatureInfo:
    """Information about a single SAE feature.

    Fix: the original omitted the ``@dataclass`` decorator even though the
    class is written in dataclass style (annotated fields with defaults plus
    ``__post_init__``). Without it, ``FeatureInfo(index=..., description=...)``
    raises ``TypeError`` and ``__post_init__`` is never invoked.
    """

    # Feature index within the SAE dictionary.
    index: int
    # Human-readable description (e.g. from Neuronpedia auto-interp).
    description: str
    # Largest observed activation value for this feature.
    max_activation: float = 0.0
    # Mean activation value for this feature.
    mean_activation: float = 0.0
    # Top activating examples; None sentinel avoids a shared mutable default.
    top_examples: Optional[list[dict]] = None

    def __post_init__(self):
        # Give each instance its own empty list rather than sharing one.
        if self.top_examples is None:
            self.top_examples = []
def _cache_path(model_id: str, layer: int, sae_width: str) -> Path:
    """Build the on-disk location of the cached descriptions JSON for one SAE."""
    filename = f"{model_id}_layer{layer}_{sae_width}.json"
    return DESCRIPTIONS_CACHE_DIR / filename
class FeatureStore:
    """Manages feature descriptions and metadata.

    Loads pre-computed feature descriptions and provides search
    functionality for browsing features, plus helpers for inspecting
    which features activate on a given input.
    """

    def __init__(self):
        # Feature index -> human-readable description text.
        self.descriptions: dict[int, str] = {}
        # Feature index -> richer FeatureInfo records (populated lazily).
        self.feature_infos: dict[int, "FeatureInfo"] = {}
        # True once descriptions have been loaded from disk or the network.
        self._loaded = False

    def load_descriptions(self, path: Optional[str] = None) -> None:
        """Load feature descriptions from a JSON file.

        Descriptions can come from Neuronpedia's auto-interpretability
        pipeline or be manually curated.

        Expected format: {"0": "description for feature 0", "1": "...", ...}

        A missing or nonexistent path is tolerated: a notice is printed and
        the store keeps working with index-only labels.
        """
        if path and Path(path).exists():
            with open(path) as f:
                data = json.load(f)
            # JSON object keys are strings; convert back to int indices.
            self.descriptions = {int(k): v for k, v in data.items()}
            self._loaded = True
            print(f"Loaded {len(self.descriptions)} feature descriptions.")
        else:
            print(
                "No description file found. Features will be identified by index only. "
                "Run fetch_descriptions() to download from Neuronpedia."
            )

    def load_or_fetch_descriptions(
        self,
        model_id: str = "gemma-2-2b",
        layer: int = 20,
        sae_width: str = "16k",
        limit: int = 0,
    ) -> None:
        """Load descriptions from cache, or fetch from Neuronpedia and cache them.

        This is the primary method called on startup. It checks for a local
        cache file first to avoid redundant API calls.

        Args:
            model_id: Neuronpedia model identifier (e.g. "gemma-2-2b").
            layer: Residual-stream layer of the SAE.
            sae_width: SAE dictionary size label (e.g. "16k").
            limit: If > 0, stop fetching once at least this many
                descriptions have been downloaded.
        """
        cache_file = _cache_path(model_id, layer, sae_width)
        if cache_file.exists():
            print(f"Loading cached descriptions from {cache_file}")
            self.load_descriptions(str(cache_file))
            return
        # Fix: was an f-string with no placeholders.
        print("No cached descriptions found. Fetching from Neuronpedia...")
        self.fetch_descriptions_from_neuronpedia(
            model_id=model_id,
            layer=layer,
            sae_width=sae_width,
            save_path=str(cache_file),
            limit=limit,
        )

    def fetch_descriptions_from_neuronpedia(
        self,
        model_id: str = "gemma-2-2b",
        layer: int = 20,
        sae_width: str = "16k",
        save_path: Optional[str] = None,
        limit: int = 0,
    ) -> None:
        """Fetch feature descriptions from Neuronpedia's S3 bulk export.

        Downloads gzipped JSONL batch files from the S3 dataset bucket.
        Each line contains a JSON object with 'index' and 'description' fields.
        Download errors are non-fatal (best effort): whatever was fetched so
        far is kept, and an empty result leaves index-only labels in place.

        Args:
            model_id: Neuronpedia model identifier.
            layer: Residual-stream layer of the SAE.
            sae_width: SAE dictionary size label.
            save_path: If given, cache the fetched descriptions to this JSON file.
            limit: If > 0, stop after at least this many descriptions.
        """
        sae_source = f"{layer}-gemmascope-res-{sae_width}"
        s3_prefix = f"{NEURONPEDIA_S3_BASE}/{model_id}/{sae_source}/explanations"
        print(f"Fetching descriptions for {model_id}/{sae_source} from S3...")
        descriptions = {}
        # Download batch files from S3 (batch-0.jsonl.gz, batch-1.jsonl.gz, ...)
        # until a 404 signals there are no more batches.
        batch_idx = 0
        while True:
            url = f"{s3_prefix}/batch-{batch_idx}.jsonl.gz"
            try:
                resp = requests.get(url, timeout=30)
                if resp.status_code == 404:
                    # Past the last batch: done.
                    break
                if resp.status_code != 200:
                    print(f" Warning: HTTP {resp.status_code} for batch-{batch_idx}")
                    break
                with gzip.GzipFile(fileobj=io.BytesIO(resp.content)) as gz:
                    for line in gz:
                        entry = json.loads(line)
                        idx = entry.get("index")
                        desc = entry.get("description", "")
                        if idx is not None and desc:
                            descriptions[int(idx)] = desc
                print(
                    f" Downloaded batch-{batch_idx} ({len(descriptions)} descriptions so far)"
                )
                if limit and len(descriptions) >= limit:
                    break
            except Exception as e:
                # Best-effort: keep whatever was fetched before the failure.
                print(f" Warning: failed to fetch batch-{batch_idx}: {e}")
                break
            batch_idx += 1
        if not descriptions:
            print(
                " Could not fetch descriptions from Neuronpedia. "
                "The Feature Browser search will use index-only labels."
            )
            return
        self.descriptions = descriptions
        self._loaded = True
        print(f"Loaded {len(descriptions)} feature descriptions from Neuronpedia.")
        if save_path:
            save_dir = Path(save_path).parent
            save_dir.mkdir(parents=True, exist_ok=True)
            with open(save_path, "w") as f:
                json.dump({str(k): v for k, v in descriptions.items()}, f, indent=2)
            print(f"Cached descriptions to {save_path}")

    def get_description(self, feature_idx: int) -> str:
        """Get the description for a feature, or a default placeholder string."""
        return self.descriptions.get(
            feature_idx, f"Feature #{feature_idx} (no description available)"
        )

    def search_features(self, query: str, top_k: int = 20) -> list["FeatureInfo"]:
        """Search features by description text (simple substring matching).

        Args:
            query: Case-insensitive substring to match against descriptions.
            top_k: Maximum number of matches to return.

        Returns:
            Matching FeatureInfo records, sorted by feature index.
        """
        query_lower = query.lower()
        results = []
        for idx, desc in self.descriptions.items():
            if query_lower in desc.lower():
                info = self.feature_infos.get(
                    idx, FeatureInfo(index=idx, description=desc)
                )
                results.append(info)
        # Sort by index (stable ordering) and return top_k
        results.sort(key=lambda f: f.index)
        return results[:top_k]

    def find_active_features(
        self,
        feature_acts: torch.Tensor,
        str_tokens: list[str],
        top_k: int = 20,
    ) -> list[dict]:
        """Find the most active features for a given input.

        Args:
            feature_acts: SAE feature activations [batch, pos, n_features]
            str_tokens: Token strings for the input
            top_k: Number of top features to return

        Returns:
            List of dicts with feature index, max activation, description,
            and per-token activations. Features with no positive activation
            are omitted.
        """
        # Fix: a BOS-only input would leave nothing after slicing off
        # position 0, and max(dim=0) on an empty tensor raises.
        if feature_acts.shape[1] <= 1:
            return []
        # Max activation across all positions for each feature.
        # Skip position 0 (BOS token) as it produces strong, input-independent activations.
        max_acts_per_feature = feature_acts[0, 1:].max(dim=0).values  # [n_features]
        # Get top-k features by maximum activation. Fix: clamp k so topk()
        # cannot raise when top_k exceeds the number of features.
        k = min(top_k, max_acts_per_feature.shape[0])
        top_values, top_indices = max_acts_per_feature.topk(k)
        results = []
        for val, idx in zip(top_values.tolist(), top_indices.tolist()):
            if val <= 0:
                # Feature never fired on this input; skip it.
                continue
            # Per-token activations for this feature.
            # Skip position 0 (BOS token) as it produces strong, input-independent activations.
            per_token = feature_acts[0, 1:, idx].tolist()
            results.append(
                {
                    "feature_idx": idx,
                    "max_activation": val,
                    "description": self.get_description(idx),
                    "per_token_activations": per_token,
                    "str_tokens": str_tokens[1:],
                }
            )
        return results

    def get_feature_activation_histogram(
        self,
        feature_acts: torch.Tensor,
        feature_idx: int,
    ) -> dict:
        """Get activation distribution stats for a feature.

        Args:
            feature_acts: SAE feature activations [batch, pos, n_features]
            feature_idx: Index of the feature to summarize.

        Returns:
            Histogram data for plotting: raw activations, mean, max,
            sparsity (fraction of tokens with positive activation),
            and active/total counts.
        """
        # Per-token activations for this feature.
        # Skip position 0 (BOS token) as it produces strong, input-independent activations.
        acts = feature_acts[0, 1:, feature_idx].cpu().numpy()
        if acts.size == 0:
            # Fix: for a single-token (BOS-only) input the slice is empty and
            # .max() would raise ValueError; return a well-formed empty summary.
            return {
                "activations": [],
                "mean": 0.0,
                "max": 0.0,
                "sparsity": 0.0,
                "n_active": 0,
                "n_total": 0,
            }
        return {
            "activations": acts.tolist(),
            "mean": float(acts.mean()),
            "max": float(acts.max()),
            "sparsity": float((acts > 0).mean()),
            "n_active": int((acts > 0).sum()),
            "n_total": len(acts),
        }
# Global feature store instance (process-wide singleton).
# None until get_feature_store() is first called; set lazily there.
_store: Optional[FeatureStore] = None
def get_feature_store(limit: int = 0) -> FeatureStore:
    """Return the process-wide FeatureStore, creating it on first use.

    The first call constructs the singleton and loads (or fetches from
    Neuronpedia) descriptions for the configured model and layer; later
    calls return the same instance without reloading.
    """
    global _store
    # Guard clause: reuse the existing singleton if one was already built.
    if _store is not None:
        return _store
    # Assign before loading so repeated calls never rebuild the store.
    _store = FeatureStore()
    model_id = MODEL_CONFIG.model_name.replace("google/", "")
    _store.load_or_fetch_descriptions(
        model_id=model_id,
        layer=MODEL_CONFIG.default_layer,
        sae_width="16k",
        limit=limit,
    )
    return _store
if __name__ == "__main__":
    # Smoke test: build the global store (loading or fetching descriptions)
    # and print the description for feature 0.
    store = get_feature_store()
    print(store.get_description(0))