# SteerTheShip / features.py
# (web-export residue preserved as comments: commit 0764fc0 by benbatman,
#  "removing neuronpedia api key requirements")
"""Feature browsing, search, and description loading.
Provides utilities for exploring SAE features: loading pre-computed
descriptions from Neuronpedia, searching features semantically, and
finding features that activate on given text.
"""
import gzip
import io
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import numpy as np
import requests
import torch
from dotenv import find_dotenv, load_dotenv

from config import MODEL_CONFIG
# Load environment variables from the nearest .env file, overriding any
# already-set values (historically used for API keys; now optional).
load_dotenv(find_dotenv(), override=True)
# Local directory where fetched Neuronpedia description files are cached.
DESCRIPTIONS_CACHE_DIR = Path(__file__).parent / "cache"
# Public S3 bucket hosting Neuronpedia's bulk dataset exports.
NEURONPEDIA_S3_BASE = "https://neuronpedia-datasets.s3.us-east-1.amazonaws.com/v1"
# Base URL of the Neuronpedia web API.
NEURONPEDIA_API_BASE = "https://www.neuronpedia.org"
@dataclass
class FeatureInfo:
    """Information about a single SAE feature.

    Attributes:
        index: The feature's index within the SAE dictionary.
        description: Human-readable description (e.g. from Neuronpedia).
        max_activation: Maximum observed activation for this feature.
        mean_activation: Mean observed activation for this feature.
        top_examples: Top activating examples (list of dicts; schema defined
            by the caller that populates it).
    """

    index: int
    description: str
    max_activation: float = 0.0
    mean_activation: float = 0.0
    # Use field(default_factory=list) instead of the `= None` + __post_init__
    # workaround: this is the idiomatic fix for mutable defaults and gives the
    # annotation its true type.
    top_examples: list[dict] = field(default_factory=list)

    def __post_init__(self):
        # Backward compatibility: callers may still pass top_examples=None
        # explicitly; normalize that to an empty list as before.
        if self.top_examples is None:
            self.top_examples = []
def _cache_path(model_id: str, layer: int, sae_width: str) -> Path:
    """Build the filesystem location of the cached descriptions JSON file."""
    filename = f"{model_id}_layer{layer}_{sae_width}.json"
    return DESCRIPTIONS_CACHE_DIR / filename
class FeatureStore:
    """Manages feature descriptions and metadata.

    Loads pre-computed feature descriptions (e.g. from Neuronpedia's
    auto-interpretability pipeline) and provides search and inspection
    helpers for browsing SAE features.
    """

    def __init__(self):
        # Feature index -> human-readable description string.
        # ("FeatureInfo" is a string forward reference so the class can be
        # defined independently of import order.)
        self.descriptions: dict[int, str] = {}
        # Feature index -> richer FeatureInfo record (populated by callers).
        self.feature_infos: dict[int, "FeatureInfo"] = {}
        # True once descriptions have been loaded from disk or Neuronpedia.
        self._loaded = False

    def load_descriptions(self, path: Optional[str] = None) -> None:
        """Load feature descriptions from a JSON file.

        Descriptions can come from Neuronpedia's auto-interpretability
        pipeline or be manually curated.

        Expected format: {"0": "description for feature 0", "1": "...", ...}
        """
        if path and Path(path).exists():
            with open(path) as f:
                data = json.load(f)
            # JSON object keys are strings; convert back to int indices.
            self.descriptions = {int(k): v for k, v in data.items()}
            self._loaded = True
            print(f"Loaded {len(self.descriptions)} feature descriptions.")
        else:
            print(
                "No description file found. Features will be identified by index only. "
                "Run fetch_descriptions() to download from Neuronpedia."
            )

    def load_or_fetch_descriptions(
        self,
        model_id: str = "gemma-2-2b",
        layer: int = 20,
        sae_width: str = "16k",
        limit: int = 0,
    ) -> None:
        """Load descriptions from cache, or fetch from Neuronpedia and cache them.

        This is the primary method called on startup. It checks for a local
        cache file first to avoid redundant API calls.

        Args:
            model_id: Neuronpedia model identifier.
            layer: Residual-stream layer of the SAE.
            sae_width: SAE dictionary width label (e.g. "16k").
            limit: If > 0, stop fetching once this many descriptions exist.
        """
        cache_file = _cache_path(model_id, layer, sae_width)
        if cache_file.exists():
            print(f"Loading cached descriptions from {cache_file}")
            self.load_descriptions(str(cache_file))
            return
        # Fix: was an f-string with no placeholders.
        print("No cached descriptions found. Fetching from Neuronpedia...")
        self.fetch_descriptions_from_neuronpedia(
            model_id=model_id,
            layer=layer,
            sae_width=sae_width,
            save_path=str(cache_file),
            limit=limit,
        )

    def fetch_descriptions_from_neuronpedia(
        self,
        model_id: str = "gemma-2-2b",
        layer: int = 20,
        sae_width: str = "16k",
        save_path: Optional[str] = None,
        limit: int = 0,
    ) -> None:
        """Fetch feature descriptions from Neuronpedia's S3 bulk export.

        Downloads gzipped JSONL batch files from the S3 dataset bucket.
        Each line contains a JSON object with 'index' and 'description' fields.

        Args:
            model_id: Neuronpedia model identifier.
            layer: Residual-stream layer of the SAE.
            sae_width: SAE dictionary width label.
            save_path: If given, cache the merged descriptions as JSON here.
            limit: If > 0, stop downloading once this many descriptions exist.
        """
        sae_source = f"{layer}-gemmascope-res-{sae_width}"
        s3_prefix = f"{NEURONPEDIA_S3_BASE}/{model_id}/{sae_source}/explanations"
        print(f"Fetching descriptions for {model_id}/{sae_source} from S3...")
        descriptions = {}
        # Download batch files from S3 (batch-0.jsonl.gz, batch-1.jsonl.gz, ...);
        # a 404 marks the end of the sequentially-numbered export.
        batch_idx = 0
        while True:
            url = f"{s3_prefix}/batch-{batch_idx}.jsonl.gz"
            try:
                resp = requests.get(url, timeout=30)
                if resp.status_code == 404:
                    break
                if resp.status_code != 200:
                    print(f" Warning: HTTP {resp.status_code} for batch-{batch_idx}")
                    break
                with gzip.GzipFile(fileobj=io.BytesIO(resp.content)) as gz:
                    for line in gz:
                        entry = json.loads(line)
                        idx = entry.get("index")
                        desc = entry.get("description", "")
                        if idx is not None and desc:
                            descriptions[int(idx)] = desc
                print(
                    f" Downloaded batch-{batch_idx} ({len(descriptions)} descriptions so far)"
                )
                if limit and len(descriptions) >= limit:
                    break
            except Exception as e:
                # Best-effort: keep whatever was fetched so far and stop.
                print(f" Warning: failed to fetch batch-{batch_idx}: {e}")
                break
            batch_idx += 1
        if not descriptions:
            print(
                " Could not fetch descriptions from Neuronpedia. "
                "The Feature Browser search will use index-only labels."
            )
            return
        self.descriptions = descriptions
        self._loaded = True
        print(f"Loaded {len(descriptions)} feature descriptions from Neuronpedia.")
        if save_path:
            save_dir = Path(save_path).parent
            save_dir.mkdir(parents=True, exist_ok=True)
            with open(save_path, "w") as f:
                # ensure_ascii=False keeps non-ASCII description text readable
                # in the cache file (still valid JSON for json.load).
                json.dump(
                    {str(k): v for k, v in descriptions.items()},
                    f,
                    indent=2,
                    ensure_ascii=False,
                )
            print(f"Cached descriptions to {save_path}")

    def get_description(self, feature_idx: int) -> str:
        """Get the description for a feature, or a default string."""
        return self.descriptions.get(
            feature_idx, f"Feature #{feature_idx} (no description available)"
        )

    def search_features(self, query: str, top_k: int = 20) -> list["FeatureInfo"]:
        """Search features by description text (simple substring matching)."""
        query_lower = query.lower()
        results = []
        for idx, desc in self.descriptions.items():
            if query_lower in desc.lower():
                info = self.feature_infos.get(
                    idx, FeatureInfo(index=idx, description=desc)
                )
                results.append(info)
        # Sort by index (stable ordering) and return top_k
        results.sort(key=lambda f: f.index)
        return results[:top_k]

    def find_active_features(
        self,
        feature_acts: torch.Tensor,
        str_tokens: list[str],
        top_k: int = 20,
    ) -> list[dict]:
        """Find the most active features for a given input.

        Args:
            feature_acts: SAE feature activations [batch, pos, n_features]
            str_tokens: Token strings for the input
            top_k: Number of top features to return

        Returns:
            List of dicts with feature index, max activation, description,
            and per-token activations. Empty if the input is only a BOS token.
        """
        # Skip position 0 (BOS token): it produces strong, input-independent
        # activations that would dominate the ranking.
        acts = feature_acts[0, 1:]  # [pos - 1, n_features]
        if acts.shape[0] == 0:
            # Robustness fix: input was only a BOS token; .max(dim=0) on an
            # empty tensor would raise.
            return []
        max_acts_per_feature = acts.max(dim=0).values  # [n_features]
        # Robustness fix: topk raises if k exceeds the number of features.
        k = min(top_k, max_acts_per_feature.shape[0])
        top_values, top_indices = max_acts_per_feature.topk(k)
        results = []
        for val, idx in zip(top_values.tolist(), top_indices.tolist()):
            if val <= 0:
                # Inactive features carry no signal.
                continue
            # Per-token activations for this feature (BOS excluded, as above).
            per_token = acts[:, idx].tolist()
            results.append(
                {
                    "feature_idx": idx,
                    "max_activation": val,
                    "description": self.get_description(idx),
                    "per_token_activations": per_token,
                    "str_tokens": str_tokens[1:],
                }
            )
        return results

    def get_feature_activation_histogram(
        self,
        feature_acts: torch.Tensor,
        feature_idx: int,
    ) -> dict:
        """Get activation distribution stats for a feature.

        Returns histogram data for plotting: raw activations, mean, max,
        sparsity (fraction of active positions), and counts.
        """
        # Per-token activations for this feature.
        # Skip position 0 (BOS token) as it produces strong, input-independent
        # activations.
        acts = feature_acts[0, 1:, feature_idx].cpu().numpy()
        if acts.size == 0:
            # Robustness fix: numpy max()/mean() raise/warn on empty arrays
            # (e.g. a BOS-only input); return neutral stats instead.
            return {
                "activations": [],
                "mean": 0.0,
                "max": 0.0,
                "sparsity": 0.0,
                "n_active": 0,
                "n_total": 0,
            }
        return {
            "activations": acts.tolist(),
            "mean": float(acts.mean()),
            "max": float(acts.max()),
            "sparsity": float((acts > 0).mean()),
            "n_active": int((acts > 0).sum()),
            "n_total": len(acts),
        }
# Global feature store instance, created lazily by get_feature_store().
_store: Optional[FeatureStore] = None
def get_feature_store(limit: int = 0) -> FeatureStore:
    """Return the process-wide FeatureStore singleton.

    The first call constructs the store and loads (or downloads)
    Neuronpedia descriptions for the configured model and layer; later
    calls return the already-initialized instance, so `limit` only has
    an effect on the first call.
    """
    global _store
    if _store is not None:
        return _store
    _store = FeatureStore()
    # Neuronpedia identifies models without the HF org prefix.
    neuronpedia_model = MODEL_CONFIG.model_name.replace("google/", "")
    _store.load_or_fetch_descriptions(
        model_id=neuronpedia_model,
        layer=MODEL_CONFIG.default_layer,
        sae_width="16k",
        limit=limit,
    )
    return _store
if __name__ == "__main__":
    # Smoke test: build/load the store and print one feature description.
    # Test out getting one feature description
    store = get_feature_store()
    print(store.get_description(0))