from __future__ import annotations

# Loaders for SAE checkpoints listed in a ``manifest.json``.
#
# Checkpoints can be loaded from a local directory tree (``load_checkpoint``)
# or downloaded from a Hugging Face Hub repository (``load_from_hub``). Each
# manifest entry is looked up by its ``relative_path`` or ``tag``; the
# directory it points at is expected to hold ``config.json`` and ``best.pt``,
# and the root is expected to hold ``sae.py`` providing
# ``build_sae_from_config``.

import importlib.util
import json
import sys
from pathlib import Path
from typing import Any, Callable

import torch

DEFAULT_REPO_ID = "jackyoung27/matrix-sae"


def _read_json(path: str | Path) -> Any:
    """Parse a JSON file, decoding it as UTF-8 regardless of the platform locale."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


def _root(root: str | Path | None = None) -> Path:
    """Return the resolved checkpoint root directory.

    Falls back to the directory containing this file when *root* is falsy.
    """
    return Path(root).resolve() if root else Path(__file__).resolve().parent


def _ensure_on_path(directory: str | Path) -> None:
    """Prepend *directory* to ``sys.path`` unless it is already present."""
    entry = str(directory)
    if entry not in sys.path:
        sys.path.insert(0, entry)


def load_manifest(root: str | Path | None = None) -> dict[str, Any]:
    """Read and parse ``manifest.json`` from the checkpoint root.

    Args:
        root: Checkpoint root directory; defaults to this file's directory.

    Returns:
        The parsed manifest. Callers in this module read its ``"checkpoints"`` list.
    """
    return _read_json(_root(root) / "manifest.json")


def _find(spec: str, m: dict[str, Any]) -> dict[str, Any]:
    """Return the first manifest entry whose ``relative_path`` or ``tag`` equals *spec*.

    Raises:
        KeyError: if no entry matches *spec*.
    """
    for e in m["checkpoints"]:
        if e["relative_path"] == spec or e["tag"] == spec:
            return e
    raise KeyError(spec)


def list_checkpoints(
    root: str | Path | None = None,
    group: str | None = None,
    sae_type: str | None = None,
) -> list[dict[str, Any]]:
    """List manifest entries, optionally filtered.

    Args:
        root: Checkpoint root directory; defaults to this file's directory.
        group: If truthy, keep only entries whose ``"group"`` equals it.
        sae_type: If truthy, keep only entries whose ``entry["parsed"]["sae_type"]``
            equals it.

    Returns:
        The matching manifest entries, in manifest order.
    """
    entries = load_manifest(root)["checkpoints"]
    if group:
        entries = [e for e in entries if e["group"] == group]
    if sae_type:
        entries = [e for e in entries if e["parsed"].get("sae_type") == sae_type]
    return entries


def _torch_load_cpu(path: str | Path) -> Any:
    """Load a checkpoint file onto the CPU."""
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. This is only safe for checkpoints from a trusted
    # source, including files downloaded by ``load_from_hub``.
    return torch.load(path, map_location="cpu", weights_only=False)


def _build_model(build_fn: Callable[..., Any], cfg: Any, ckpt: Any, device: Any) -> Any:
    """Build an SAE from *cfg*, load ``ckpt["model_state_dict"]``, move it to *device*, and put it in eval mode."""
    state_dict = ckpt["model_state_dict"]
    # The builder also receives the state dict (presumably to size the model;
    # confirm in sae.py). The weights are then loaded explicitly either way.
    model = build_fn(cfg, state_dict=state_dict)
    model.load_state_dict(state_dict)
    return model.to(device).eval()


def load_checkpoint(spec: str, device: Any = "cpu", root: str | Path | None = None):
    """Load a checkpoint from a local checkpoint root.

    Args:
        spec: The entry's ``relative_path`` or ``tag`` in the manifest.
        device: Target device for the returned model.
        root: Checkpoint root containing ``manifest.json`` and ``sae.py``;
            defaults to this file's directory.

    Returns:
        ``(model, cfg, entry, ckpt)``: the model in eval mode on *device*, the
        parsed ``config.json``, the manifest entry, and the raw loaded ``best.pt``.

    Raises:
        KeyError: if *spec* matches no manifest entry.
    """
    r = _root(root)
    e = _find(spec, load_manifest(r))
    # Make the root importable so ``sae`` (and anything it imports) resolves.
    _ensure_on_path(r)
    from sae import build_sae_from_config

    d = r / e["relative_path"]
    cfg = _read_json(d / "config.json")
    ckpt = _torch_load_cpu(d / "best.pt")
    return _build_model(build_sae_from_config, cfg, ckpt, device), cfg, e, ckpt


def load_from_hub(
    spec: str,
    repo_id: str = DEFAULT_REPO_ID,
    device: Any = "cpu",
    revision: str | None = None,
    cache_dir: str | Path | None = None,
):
    """Download a checkpoint (plus ``sae.py``) from the Hugging Face Hub and load it.

    Args:
        spec: The entry's ``relative_path`` or ``tag`` in the repo's manifest.
        repo_id: Hub repository to download from.
        device: Target device for the returned model.
        revision: Optional Hub revision (branch, tag or commit) passed to the downloader.
        cache_dir: Optional download cache directory.

    Returns:
        ``(model, cfg, entry, ckpt)``, with the same shape as ``load_checkpoint``.

    Raises:
        KeyError: if *spec* matches no manifest entry.
        ImportError: if the downloaded ``sae.py`` cannot be turned into a module.
    """
    from huggingface_hub import hf_hub_download

    def _download(filename: str) -> str:
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            revision=revision,
            cache_dir=str(cache_dir) if cache_dir else None,
        )

    e = _find(spec, _read_json(_download("manifest.json")))
    rel = e["relative_path"]
    cfg_path = _download(f"{rel}/config.json")
    ckpt_path = _download(f"{rel}/best.pt")
    sae_path = _download("sae.py")

    # Put the downloaded snapshot dir on sys.path, presumably so that sae.py's
    # own imports and any ``sae`` references inside the pickled checkpoint resolve.
    _ensure_on_path(Path(sae_path).parent)
    module_spec = importlib.util.spec_from_file_location("sae", sae_path)
    if module_spec is None or module_spec.loader is None:
        # Raise explicitly: an assert would vanish under ``python -O``.
        raise ImportError(f"cannot load SAE module from {sae_path}")
    sae_module = importlib.util.module_from_spec(module_spec)
    # NOTE(review): the module is not registered in sys.modules before
    # exec_module. If sae.py defines dataclasses with string annotations, this
    # can fail. The importlib recipe registers it first; that was left out here
    # so a locally imported ``sae`` is not clobbered. Confirm this is intended.
    module_spec.loader.exec_module(sae_module)

    cfg = _read_json(cfg_path)
    ckpt = _torch_load_cpu(ckpt_path)
    return _build_model(sae_module.build_sae_from_config, cfg, ckpt, device), cfg, e, ckpt