Spaces:
Sleeping
Sleeping
| """ | |
| Model Manager - TransformerLens wrapper | |
| Provides a singleton HookedTransformer instance for the backend. | |
| Supports GPT-2-small (ungated) and meta-llama/Llama-3.2-{1B,3B}-Instruct | |
| (gated; set HF_TOKEN in the environment to load them). | |
| """ | |
| import os | |
| from typing import Optional, Dict, Any | |
| import torch | |
| from transformer_lens import HookedTransformer | |
# Module-level singleton state. It is shared by every request handled by this
# process, so model weights are loaded once instead of on each request.
_hf_login_done: bool = False  # flipped to True once HuggingFace login() succeeds
_model_name: Optional[str] = None  # name given to load_model() for the cached model
_model: Optional[HookedTransformer] = None  # the cached model, or None before any load
def get_device() -> str:
    """
    Pick the best available torch device for inference.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        One of "cuda", "mps" or "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in PyTorch >= 1.12; on older builds the
    # plain attribute access would raise AttributeError instead of falling
    # back to the CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
def _ensure_hf_login() -> None:
    """
    Log into HuggingFace once per process if HF_TOKEN is set.

    Needed for gated models (Llama-3.2-*). Does nothing when a previous call
    already logged in successfully, or when HF_TOKEN is unset or empty.

    Login is best-effort. Any failure (huggingface_hub missing, network error,
    bad token) is logged as a warning and then ignored, so ungated loads such
    as GPT-2 still work. For a gated model, the subsequent from_pretrained
    call will raise its own authentication error. The done-flag is only set
    on success, so a failed login is retried on the next call.
    """
    global _hf_login_done
    if _hf_login_done:
        return
    token = os.environ.get("HF_TOKEN")
    if not token:
        return
    try:
        from huggingface_hub import login

        login(token=token, add_to_git_credential=False)
    except Exception as exc:  # deliberate best-effort boundary, see docstring
        import logging

        # Log instead of passing silently: otherwise a later 401 from a gated
        # model gives no hint that the login step itself failed.
        logging.getLogger(__name__).warning(
            "HuggingFace login with HF_TOKEN failed (%s: %s); gated models "
            "may fail to load.",
            type(exc).__name__,
            exc,
        )
        return
    _hf_login_done = True
def _model_summary(status: str, name: str, device: str) -> Dict[str, Any]:
    """Build the load_model() response from the currently cached model's config."""
    cfg = _model.cfg
    return {
        "status": status,
        "model_name": name,
        "device": device,
        "n_layers": cfg.n_layers,
        "d_model": cfg.d_model,
        "n_heads": cfg.n_heads,
        "d_vocab": cfg.d_vocab,
    }


def load_model(name: str = "gpt2-small") -> Dict[str, Any]:
    """
    Load a HookedTransformer model, reusing the cached one when possible.

    Names the frontend is expected to use (not validated here):
      - "gpt2-small" (default, ungated)
      - "meta-llama/Llama-3.2-1B-Instruct" (gated; needs HF_TOKEN)
      - "meta-llama/Llama-3.2-3B-Instruct" (gated; needs HF_TOKEN)

    Args:
        name: Model name passed to HookedTransformer.from_pretrained.

    Returns:
        The model's runtime config, so the frontend can adapt layer dropdowns
        and hooks to the loaded model. Keys: "status" ("loaded" or
        "already_loaded"), "model_name", "device", "n_layers", "d_model",
        "n_heads", "d_vocab".

    Raises:
        Whatever HookedTransformer.from_pretrained raises (e.g. for an unknown
        name or a gated model without access). In that case the previously
        loaded model is left in place.
    """
    global _model, _model_name
    if _model is not None and _model_name == name:
        # Cache hit: include the device recorded in the model config so this
        # payload has the same keys as a fresh load.
        return _model_summary("already_loaded", name, str(_model.cfg.device))

    # Gated models (Llama) require an HF token; log in once if present.
    _ensure_hf_login()
    device = get_device()
    # Assign only after from_pretrained returns. If it raises (e.g. a gated
    # model without HF_TOKEN), the previous model stays usable. The trade-off
    # is that the old and new weights are resident together while loading.
    _model = HookedTransformer.from_pretrained(name, device=device)
    _model_name = name
    return _model_summary("loaded", name, device)
def get_model() -> HookedTransformer:
    """
    Return the process-wide model cached by load_model().

    Raises:
        RuntimeError: if no model has been loaded yet.
    """
    if _model is not None:
        return _model
    raise RuntimeError("No model loaded. Call load_model() first.")
def get_model_name() -> Optional[str]:
    """
    Report which model is currently cached.

    Returns:
        The name that was passed to load_model() for the currently loaded
        model, or None if nothing has been loaded yet. A load_model() call
        that raised does not change this value.
    """
    current_name = _model_name
    return current_name
def run_with_cache(prompt: str):
    """
    Run a gradient-free forward pass on ``prompt`` and capture all activations.

    Args:
        prompt: Text to feed to the currently loaded model.

    Returns:
        A ``(tokens, logits, cache)`` tuple:
          - tokens: the prompt's string tokens, from ``model.to_str_tokens``.
          - logits: the model output, from ``model.run_with_cache``.
          - cache: every intermediate activation indexed by hook name, e.g.
            cache["blocks.0.hook_resid_post"] — residual after layer 0
            cache["blocks.0.attn.hook_pattern"] — attention pattern at layer 0

    Raises:
        RuntimeError: if no model has been loaded yet (from get_model()).
    """
    model = get_model()
    tokens = model.to_str_tokens(prompt)
    # Inference only. With autograd enabled, the returned logits would keep
    # the whole autograd graph (and its saved intermediates) alive for as long
    # as the caller holds them, which wastes memory on the larger models.
    with torch.no_grad():
        logits, cache = model.run_with_cache(prompt)
    return tokens, logits, cache