neuroscope-api / model.py
lymnal's picture
sync: Wave 1+2+3 backend + 6 techniques + populated refusal/over-refusal data
ffb6dc9 verified
Raw
History Blame Contribute Delete
3.64 kB
"""
Model Manager - TransformerLens wrapper
Provides a singleton HookedTransformer instance for the backend.
Supports GPT-2-small (ungated) and meta-llama/Llama-3.2-{1B,3B}-Instruct
(gated; set HF_TOKEN in the environment to load them).
"""
import os
from typing import Optional, Dict, Any
import torch
from transformer_lens import HookedTransformer
# Process-wide singleton state (avoids reloading the model on each request).
# _model holds the currently loaded HookedTransformer, or None before any
# successful load_model() call.
_model: Optional[HookedTransformer] = None
# Name that was passed to load_model() for the model currently held in _model.
_model_name: Optional[str] = None
# Flipped to True after a successful huggingface_hub login so that the login
# is not repeated on later loads in the same process.
_hf_login_done: bool = False
def get_device() -> str:
    """
    Pick the device string to load the model onto.

    Preference order is CUDA, then Apple MPS, then CPU. The MPS check is
    only consulted when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        return "cuda"
    # No CUDA: fall back to MPS when present, otherwise plain CPU.
    return "mps" if torch.backends.mps.is_available() else "cpu"
def _ensure_hf_login() -> None:
    """
    Log into HuggingFace once per process if HF_TOKEN is set.

    Required for gated models (Llama-3.2-*). No-op if a previous call
    already logged in successfully or if HF_TOKEN is not set.

    Login is best-effort: any failure (network error, bad token, missing
    huggingface_hub package) is logged as a warning and otherwise ignored,
    so ungated models such as GPT-2 still load. The downstream
    from_pretrained call is expected to raise if the requested model is
    gated. The "done" flag is only set on success, so a failed login is
    attempted again on the next call.
    """
    global _hf_login_done
    if _hf_login_done:
        return

    token = os.environ.get("HF_TOKEN")
    if not token:
        return

    try:
        # Imported lazily so the module still works without huggingface_hub
        # when no token is configured.
        from huggingface_hub import login

        login(token=token, add_to_git_credential=False)
    except Exception as exc:
        # Deliberately non-fatal, but no longer silent: leave a trace so a
        # later authorization error on a gated model can be tied to this.
        import logging

        logging.getLogger(__name__).warning(
            "HuggingFace login failed (%s: %s); gated models may fail to load.",
            type(exc).__name__,
            exc,
        )
    else:
        _hf_login_done = True
def load_model(name: str = "gpt2-small") -> Dict[str, Any]:
    """
    Load a HookedTransformer model and make it the process-wide instance.

    Supported names:
    - "gpt2-small" (default, ungated)
    - "meta-llama/Llama-3.2-1B-Instruct" (gated; needs HF_TOKEN)
    - "meta-llama/Llama-3.2-3B-Instruct" (gated; needs HF_TOKEN)

    If a model with the same name is already loaded, nothing is reloaded.

    Returns the model's runtime config so the frontend can adapt layer
    dropdowns / hooks to the loaded model. The dict always has the keys
    "status" ("loaded" or "already_loaded"), "model_name", "device",
    "n_layers", "d_model", "n_heads" and "d_vocab".

    Any exception raised by HookedTransformer.from_pretrained propagates;
    in that case the previously loaded model (if any) is left in place.
    """
    global _model, _model_name

    # Same value that is passed to from_pretrained below, so it is also
    # valid to report for a model loaded earlier in this process.
    device = get_device()

    def _runtime_config(status: str) -> Dict[str, Any]:
        # Single place that builds the response, so both code paths return
        # the same shape.
        cfg = _model.cfg
        return {
            "status": status,
            "model_name": name,
            "device": device,
            "n_layers": cfg.n_layers,
            "d_model": cfg.d_model,
            "n_heads": cfg.n_heads,
            "d_vocab": cfg.d_vocab,
        }

    if _model is not None and _model_name == name:
        return _runtime_config("already_loaded")

    # Gated models (Llama) require an HF token; log in once if present.
    _ensure_hf_login()

    # Assign the globals only after a successful load so a failed load does
    # not leave a half-updated (model, name) pair.
    _model = HookedTransformer.from_pretrained(name, device=device)
    _model_name = name
    return _runtime_config("loaded")
def get_model() -> HookedTransformer:
    """
    Return the currently loaded model.

    Raises:
        RuntimeError: if no model has been loaded yet.
    """
    current = _model
    if current is not None:
        return current
    raise RuntimeError("No model loaded. Call load_model() first.")
def get_model_name() -> Optional[str]:
    """
    Return the name of the currently loaded model.

    This is the name recorded by load_model() after a successful load, or
    None if nothing has been loaded yet.
    """
    return _model_name
def run_with_cache(prompt: str):
    """
    Run the loaded model on a prompt and capture its activations.

    Returns a (tokens, logits, cache) tuple:
    - tokens: result of the model's to_str_tokens(prompt)
    - logits, cache: result of the model's run_with_cache(prompt)

    The cache holds intermediate activations indexed by hook name, e.g.:
        cache["blocks.0.hook_resid_post"]   — residual after layer 0
        cache["blocks.0.attn.hook_pattern"] — attention pattern at layer 0

    Raises RuntimeError (via get_model) if no model is loaded.
    """
    hooked = get_model()
    str_tokens = hooked.to_str_tokens(prompt)
    output = hooked.run_with_cache(prompt)
    logits, activation_cache = output
    return str_tokens, logits, activation_cache