# matrix-sae / loader.py
# Author: JackYoung27 — uploaded with huggingface_hub (commit 7771b1f, verified).
from __future__ import annotations
import json
from pathlib import Path
import torch
# Hugging Face Hub repository that load_from_hub downloads from when no
# repo_id is passed.
DEFAULT_REPO_ID = "jackyoung27/matrix-sae"
def _root(root=None):
    """Resolve the base directory that holds manifest.json and checkpoints.

    Args:
        root: Explicit base directory. Any falsy value (``None``, ``""``)
            selects the directory containing this file.

    Returns:
        An absolute, resolved ``Path``.
    """
    if not root:
        return Path(__file__).resolve().parent
    return Path(root).resolve()
def load_manifest(root=None):
    """Read and parse ``manifest.json`` from the checkpoint root.

    Args:
        root: Directory containing ``manifest.json``. When falsy, the
            directory of this file is used (resolved via ``_root``).

    Returns:
        The decoded JSON document. The other helpers in this module read its
        ``"checkpoints"`` list.
    """
    manifest_path = _root(root) / "manifest.json"
    # JSON text is UTF-8 (RFC 8259). Don't rely on the platform's locale
    # encoding, which Path.read_text() uses when no encoding is given.
    return json.loads(manifest_path.read_text(encoding="utf-8"))
def _find(spec, m):
    """Return the first manifest entry matching *spec*.

    An entry matches when its ``relative_path`` equals *spec*, or otherwise
    when its ``tag`` equals *spec*. ``tag`` is only read if ``relative_path``
    did not match.

    Args:
        spec: A checkpoint ``relative_path`` or ``tag``.
        m: Parsed manifest with a ``"checkpoints"`` list of entry dicts.

    Raises:
        KeyError: With *spec* as its argument when no entry matches.
    """
    not_found = object()
    match = next(
        (
            entry
            for entry in m["checkpoints"]
            if entry["relative_path"] == spec or entry["tag"] == spec
        ),
        not_found,
    )
    if match is not_found:
        raise KeyError(spec)
    return match
def list_checkpoints(root=None, group=None, sae_type=None):
    """List manifest checkpoint entries, optionally filtered.

    Args:
        root: Checkpoint root directory (see ``load_manifest``).
        group: If truthy, keep only entries whose ``"group"`` equals it.
        sae_type: If truthy, keep only entries whose ``parsed["sae_type"]``
            equals it. Entries without that key never match.

    Returns:
        The matching manifest entries, in manifest order.
    """
    entries = load_manifest(root)["checkpoints"]

    # Each active filter is applied as its own pass, group first.
    predicates = []
    if group:
        predicates.append(lambda entry: entry["group"] == group)
    if sae_type:
        predicates.append(
            lambda entry: entry["parsed"].get("sae_type") == sae_type
        )

    for keep in predicates:
        entries = [entry for entry in entries if keep(entry)]
    return entries
def load_checkpoint(spec, device="cpu", root=None):
    """Load a locally stored checkpoint listed in the manifest and build its model.

    Args:
        spec: A manifest entry's ``relative_path`` or ``tag``.
        device: Device the returned model is moved to. The checkpoint file
            itself is always deserialized onto CPU first.
        root: Checkpoint root directory, expected to hold ``manifest.json``,
            ``sae.py`` and the checkpoint folders. Defaults to the directory
            of this file.

    Returns:
        ``(model, cfg, entry, ckpt)``: the model in eval mode on *device*,
        the parsed ``config.json``, the manifest entry, and the raw object
        returned by ``torch.load``.

    Raises:
        KeyError: If *spec* matches no manifest entry.
    """
    import sys

    r = _root(root)
    e = _find(spec, load_manifest(r))

    # Make the root importable so ``sae`` can be imported from it.
    # NOTE(review): if a module named ``sae`` is already in sys.modules
    # (e.g. from an earlier call with a different root), Python reuses that
    # cached module instead of r/sae.py — confirm roots are never mixed.
    if str(r) not in sys.path:
        sys.path.insert(0, str(r))
    from sae import build_sae_from_config

    d = r / e["relative_path"]
    # JSON is UTF-8 by spec. Don't depend on the platform's locale encoding.
    cfg = json.loads((d / "config.json").read_text(encoding="utf-8"))
    # weights_only=False unpickles arbitrary objects, which can execute code:
    # only load checkpoints from a trusted source.
    ckpt = torch.load(d / "best.pt", map_location="cpu", weights_only=False)
    state_dict = ckpt["model_state_dict"]
    model = build_sae_from_config(cfg, state_dict=state_dict)
    # Weights are loaded explicitly as well. Presumably build_sae_from_config
    # uses state_dict only to infer shapes — TODO confirm.
    model.load_state_dict(state_dict)
    return model.to(device).eval(), cfg, e, ckpt
def load_from_hub(spec, repo_id=DEFAULT_REPO_ID, device="cpu", revision=None, cache_dir=None):
    """Download a checkpoint from the Hugging Face Hub and build its model.

    Fetches ``manifest.json``, the matching entry's ``config.json`` and
    ``best.pt``, and the repo's ``sae.py`` (which provides
    ``build_sae_from_config``) with ``hf_hub_download``.

    Args:
        spec: A manifest entry's ``relative_path`` or ``tag``.
        repo_id: Hub repository to download from.
        device: Device the returned model is moved to. The checkpoint itself
            is always deserialized onto CPU first.
        revision: Optional Hub revision, passed through to ``hf_hub_download``.
        cache_dir: Optional download cache directory (converted to ``str``).

    Returns:
        ``(model, cfg, entry, ckpt)``: the model in eval mode on *device*,
        the parsed ``config.json``, the manifest entry, and the raw object
        returned by ``torch.load``.

    Raises:
        KeyError: If *spec* matches no manifest entry.
        ImportError: If no loadable module spec can be built for ``sae.py``.
    """
    import importlib.util
    import sys

    from huggingface_hub import hf_hub_download

    def dl(f):
        # Download one repo file and return its local path.
        return hf_hub_download(
            repo_id=repo_id,
            filename=f,
            revision=revision,
            cache_dir=str(cache_dir) if cache_dir else None,
        )

    # JSON is UTF-8 by spec. Don't depend on the platform's locale encoding.
    m = json.loads(Path(dl("manifest.json")).read_text(encoding="utf-8"))
    e = _find(spec, m)
    rel = e["relative_path"]
    cfg_p = dl(f"{rel}/config.json")
    ckpt_p = dl(f"{rel}/best.pt")
    sae_p = dl("sae.py")

    # Put sae.py's download directory on sys.path. Presumably this lets
    # imports inside sae.py resolve — TODO confirm.
    sae_dir = str(Path(sae_p).parent)
    if sae_dir not in sys.path:
        sys.path.insert(0, sae_dir)
    s = importlib.util.spec_from_file_location("sae", sae_p)
    # Explicit check rather than ``assert``: asserts are stripped under -O.
    if s is None or s.loader is None:
        raise ImportError(f"cannot build a module spec for {sae_p}")
    mod = importlib.util.module_from_spec(s)
    # NOTE(review): the module is not registered in sys.modules["sae"], so
    # code that looks it up by name (pickle, dataclasses with string
    # annotations) will not see this instance — confirm sae.py doesn't need that.
    s.loader.exec_module(mod)

    cfg = json.loads(Path(cfg_p).read_text(encoding="utf-8"))
    # weights_only=False unpickles arbitrary objects, which can execute code:
    # only use repositories you trust.
    ckpt = torch.load(ckpt_p, map_location="cpu", weights_only=False)
    state_dict = ckpt["model_state_dict"]
    model = mod.build_sae_from_config(cfg, state_dict=state_dict)
    model.load_state_dict(state_dict)
    return model.to(device).eval(), cfg, e, ckpt