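"""Load SAE checkpoints listed in ``manifest.json``.

Checkpoints are read either from a local directory (``load_checkpoint``) or
from the Hugging Face Hub (``load_from_hub``). A checkpoint is selected by its
manifest ``tag`` or ``relative_path``. That directory holds ``config.json`` and
``best.pt``, and the model is rebuilt with ``build_sae_from_config`` from
``sae.py``.
"""
from __future__ import annotations

import json
from pathlib import Path

import torch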

DEFAULT_REPO_ID = "jackyoung27/matrix-sae"


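def _root(root: str | Path | None = None) -> Path:
    """Return the checkpoint root: ``root`` if given, else this file's directory."""
    return Path(root).resolve() if root else Path(__file__).resolve().parent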


def load_manifest(root: str | Path | None = None) -> dict:
    """Parse ``manifest.json`` from the checkpoint root."""
    return json.loads((_root(root) / "manifest.json").read_text())


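def _find(spec: str, m: dict) -> dict:
    """Return the manifest entry whose ``relative_path`` or ``tag`` equals ``spec``."""
    for e in m["checkpoints"]:
        if e["relative_path"] == spec or e["tag"] == spec:
            return e
    raise KeyError(f"no checkpoint with tag or relative_path {spec!r} in manifest")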


def list_checkpoints(root=None, group=None, sae_type=None) -> list[dict]:
    """Return manifest entries, optionally filtered by ``group`` and ``parsed["sae_type"]``."""
    es = load_manifest(root)["checkpoints"]
    if group:
        es = [e for e in es if e["group"] == group]
    if sae_type:
        es = [e for e in es if e["parsed"].get("sae_type") == sae_type]
    return es


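def load_checkpoint(spec, device="cpu", root=None):
    """Load a checkpoint from a local root by tag or relative path.

    Returns ``(model, cfg, entry, ckpt)``:
    - ``model``: the SAE in eval mode on ``device``.
    - ``cfg``: the dict parsed from its ``config.json``.
    - ``entry``: its manifest entry.
    - ``ckpt``: the raw dict loaded from ``best.pt``.
    """
    import sys

    r = _root(root)
    e = _find(spec, load_manifest(r))
    # Put the root on sys.path so that its `sae` module (sae.py) is importable.
    if str(r) not in sys.path:
        sys.path.insert(0, str(r))
    from sae import build_sae_from_config

    d = r / e["relative_path"]
    cfg = json.loads((d / "config.json").read_text())
    # weights_only=False uses full unpickling, which can run arbitrary code:
    # only load checkpoints from trusted sources.
    ckpt = torch.load(d / "best.pt", map_location="cpu", weights_only=False)
    model = build_sae_from_config(cfg, state_dict=ckpt["model_state_dict"])
    model.load_state_dict(ckpt["model_state_dict"])
    return model.to(device).eval(), cfg, e, ckpt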


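def load_from_hub(
    spec, repo_id=DEFAULT_REPO_ID, device="cpu", revision=None, cache_dir=None
):
    """Download a checkpoint (plus ``manifest.json`` and ``sae.py``) from the Hub and load it.

    Returns the same ``(model, cfg, entry, ckpt)`` tuple as ``load_checkpoint``.
    """
    import importlib.util
    import sys

    from huggingface_hub import hf_hub_download

    def dl(f):
        # Fetch one file from the repo (or reuse the cached copy) and return its local path.
        return hf_hub_download(
            repo_id=repo_id,
            filename=f,
            revision=revision,
            cache_dir=str(cache_dir) if cache_dir else None,
        )

    m = json.loads(Path(dl("manifest.json")).read_text())
    e = _find(spec, m)
    rel = e["relative_path"]
    cfg_p = dl(f"{rel}/config.json")
    ckpt_p = dl(f"{rel}/best.pt")
    sae_p = dl("sae.py")

    # Put the download directory on sys.path, as load_checkpoint does for the
    # local root. Then import the downloaded sae.py as module "sae".
    sae_dir = str(Path(sae_p).parent)
    if sae_dir not in sys.path:
        sys.path.insert(0, sae_dir)
    s = importlib.util.spec_from_file_location("sae", sae_p)
    if s is None or s.loader is None:
        raise ImportError(f"cannot load sae.py from {sae_p}")
    mod = importlib.util.module_from_spec(s)
    s.loader.exec_module(mod)

    cfg = json.loads(Path(cfg_p).read_text())
    # weights_only=False uses full unpickling: only load trusted checkpoints.
    ckpt = torch.load(ckpt_p, map_location="cpu", weights_only=False)
    model = mod.build_sae_from_config(cfg, state_dict=ckpt["model_state_dict"])
    model.load_state_dict(ckpt["model_state_dict"])
    return model.to(device).eval(), cfg, e, ckpt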
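

# Example usage. This is an illustrative sketch, not part of the original file:
# - "<tag>" is a placeholder for any `tag` or `relative_path` value in manifest.json.
# - load_from_hub needs the huggingface_hub package installed.
# - device="cuda" assumes a GPU is available.
#
#   for entry in list_checkpoints():
#       print(entry["tag"], entry["relative_path"])
#
#   model, cfg, entry, ckpt = load_checkpoint("<tag>")               # local copy
#   model, cfg, entry, ckpt = load_from_hub("<tag>", device="cuda")  # from the Hub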