Create codebook_contributions.py
Browse files- codebook_contributions.py +215 -0
codebook_contributions.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
battery_ablation.py — test contribution signals across batteries.
|
| 3 |
+
|
| 4 |
+
For each battery: load it frozen, extract its projective codebook, compute the
|
| 5 |
+
contribution signals (codebook_contributions), and pull its recon MSE as the
|
| 6 |
+
target. Then rank every signal by:
|
| 7 |
+
* std across batteries — does it vary at all, or is it a dead signal?
|
| 8 |
+
* |corr| with recon MSE — does it track downstream quality?
|
| 9 |
+
|
| 10 |
+
This is the "run N trains, test each contribution as a whole" pass: each
|
| 11 |
+
battery is one data point; the ablation table says which contributions earn a
|
| 12 |
+
slot in the omega-phase classifier before we hardwire any of them.
|
| 13 |
+
|
| 14 |
+
Cell workflow: paste codebook_contributions cell first, then this. Edit
|
| 15 |
+
BATTERIES to your set (≥3 needed for correlation). `pip install ripser` for the
|
| 16 |
+
H1/H2 void signals; without it they self-exclude as NaN.
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from typing import Any, Dict, List, Optional
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
# cell-tolerant: from the codebook_contributions cell (or installed)
|
| 25 |
+
try:
|
| 26 |
+
from codebook_contributions import (
|
| 27 |
+
collect_signatures, ablation_table, SIGNAL_SPECS, HAVE_RIPSER,
|
| 28 |
+
)
|
| 29 |
+
except ModuleNotFoundError:
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ── edit this to your battery set ───────────────────────────────────
# Battery folder names (inside REPO_ID) that run_ablation() processes when it
# is called without an explicit list. Other folder names can be appended, e.g.:
#   "h2_linear_imagenet_128",
#   "byte_trigram_proto_64_patch_2_v1",
#   "v40_freckles_noise", "v50_fresnel_64", ...
BATTERIES: List[str] = ["h2_linear_tiny_imagenet_64"]

# Hugging Face repo the batteries are listed in and loaded from.
REPO_ID = "AbstractPhil/geolip-SVAE"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def discover_batteries(repo_id: str = REPO_ID) -> List[str]:
    """Return the sorted battery folders in *repo_id* that ship a best checkpoint.

    A battery is any top-level folder of the repo for which a file path ending
    in ``/checkpoints/best.pt`` exists. This avoids maintaining BATTERIES by
    hand — ``run_ablation(discover_batteries())`` ablates over the whole zoo
    (per the original author's note, mixed classes/D are fine because the
    signals are D-normalized).
    """
    from huggingface_hub import HfApi

    marker = "/checkpoints/best.pt"
    found = set()
    for repo_path in HfApi().list_repo_files(repo_id):
        if repo_path.endswith(marker):
            # The first path component is the battery folder name.
            found.add(repo_path.split("/")[0])
    vers = sorted(found)
    print(f" discovered {len(vers)} batteries in {repo_id}")
    return vers
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _load_model_safe(ver: str, device: str, repo_id: str):
    """Load a battery with ``load_model``, recovering torch.compile checkpoints.

    The normal path is ``load_model(hf_version=ver, ...)``. If that raises a
    ``RuntimeError`` whose message mentions ``_orig_mod.`` (the state-dict key
    prefix left behind by torch.compile), the checkpoint is downloaded again,
    the prefix is stripped from every state-dict key, missing config entries
    are backfilled from ``final_report.json`` (per the original note, loads via
    ``checkpoint_path`` skip the ``hf_version`` backfill), and the patched
    checkpoint is re-entered through ``load_model(checkpoint_path=...)`` so all
    of its construction logic is reused.

    Args:
        ver: Battery folder name inside the repo.
        device: Device string forwarded to ``load_model``.
        repo_id: Hugging Face repo that holds the battery.

    Returns:
        Whatever ``load_model`` returns; on the recovery path this is the
        ``(model, cfg)`` pair.

    Raises:
        RuntimeError: Re-raised unchanged when the failure is not the
            ``_orig_mod.`` prefix mismatch.
    """
    from geolip_svae.inference.loading import load_model

    try:
        return load_model(hf_version=ver, device=device, repo_id=repo_id)
    except RuntimeError as e:
        # Only the compiled-prefix mismatch is recoverable here.
        if "_orig_mod." not in str(e):
            raise

    import json
    import os
    import tempfile

    import torch
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(repo_id=repo_id, filename=f"{ver}/checkpoints/best.pt",
                           repo_type="model")
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects from the
    # downloaded file; only point this at repos you trust.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    pref = "_orig_mod."
    ckpt["model_state_dict"] = {
        (k[len(pref):] if k.startswith(pref) else k): v
        for k, v in ckpt["model_state_dict"].items()
    }

    # Mirror load_model's final_report backfill into the patched config.
    cfg0 = dict(ckpt.get("config") or {})
    backfillable = ("n_heads", "smooth_mid", "linear_readout",
                    "svd_mode", "match_params", "channels")
    if any(k not in cfg0 for k in backfillable):
        try:
            rp = hf_hub_download(repo_id=repo_id, filename=f"{ver}/final_report.json",
                                 repo_type="model")
            with open(rp, encoding="utf-8") as fh:
                rc = json.load(fh).get("config", {})
            for k in backfillable:
                if k not in cfg0 and rc.get(k) is not None:
                    cfg0[k] = rc[k]
            ckpt["config"] = cfg0
        except Exception as err:
            # Best-effort: the report is optional, so keep going with the
            # checkpoint's own config, but say why the backfill was skipped.
            print(f" (no final_report backfill for {ver}: {type(err).__name__}: {err})")

    # A private temp directory keeps concurrent runs from clobbering each
    # other's file and removes the checkpoint copy once it has been loaded.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp = os.path.join(tmp_dir, f"{ver.replace('/', '_')}_stripped.pt")
        torch.save(ckpt, tmp)
        model, cfg = load_model(checkpoint_path=tmp, device=device, repo_id=repo_id)
    print(f" (recovered {ver}: stripped _orig_mod. torch.compile prefix)")
    return model, cfg
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def extract_row(ver: str, device: str, repo_id: Optional[str] = None) -> Dict[str, Any]:
    """Load a frozen battery, extract its codebook, and build one ablation row.

    Args:
        ver: Battery folder name inside the repo.
        device: Device string used for loading and for the calibration batch.
        repo_id: Repo to load from; falls back to the module-level REPO_ID
            when omitted.

    Returns:
        A dict with keys ``id``, ``class``, ``axes`` (NumPy array), ``D``,
        ``n_pairs``, ``n_unpaired``, ``target`` (the config's ``_test_mse``,
        i.e. recon MSE, or None if absent) and ``n_axes``.
    """
    from geolip_svae.inference.calibration import make_calibration
    from geolip_svae.inference.codebook import extract_codebook
    from geolip_svae.inference.train_codebook import (
        infer_class_from_cfg, DEFAULT_CALIBRATIONS,
    )
    import torch

    model, cfg = _load_model_safe(ver, device, repo_id or REPO_ID)
    cls = infer_class_from_cfg(cfg)
    cal = DEFAULT_CALIBRATIONS.get(cls, DEFAULT_CALIBRATIONS["unknown"])
    size = cfg.get("img_size") or cal["size"]

    calib = make_calibration(cal["name"], n=cal["n"], size=size)
    if not isinstance(calib, torch.Tensor):
        calib = torch.as_tensor(calib)

    # Match the calibration batch to the model's input channel count: slice
    # when it has too many channels, tile-then-slice when it has too few.
    ch = int(cfg.get("channels", 3))
    have = calib.shape[1]
    if have != ch:
        if ch < have:
            calib = calib[:, :ch]
        else:
            reps = (ch + have - 1) // have  # ceil(ch / have)
            calib = calib.repeat(1, reps, 1, 1)[:, :ch]

    cb = extract_codebook(model, calib.to(device), model_id=ver,
                          model_class=cls, calibration_name=cal["name"])
    axes = cb.axes.detach().cpu().numpy()

    # Prefer the counts recorded in the metadata; fall back independently so
    # a metadata object carrying only one of the two cannot yield int(None).
    n_pairs = getattr(cb.metadata, "n_pairs", None)
    if n_pairs is None:
        n_pairs = len(cb.pairs)
    n_unpaired = getattr(cb.metadata, "n_unpaired", None)
    if n_unpaired is None:
        n_unpaired = len(cb.unpaired)

    return {
        "id": ver,
        "class": cls,
        "axes": axes,
        "D": int(cfg.get("D") or axes.shape[1]),
        "n_pairs": int(n_pairs),
        "n_unpaired": int(n_unpaired),
        "target": cfg.get("_test_mse"),  # recon MSE (None if absent)
        "n_axes": int(axes.shape[0]),
    }
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def run_ablation(batteries: Optional[List[str]] = None, device: Optional[str] = None,
                 enabled=None) -> Dict[str, Any]:
    """Extract every battery's codebook, compute signatures, rank contributions.

    Args:
        batteries: Battery folder names; falls back to BATTERIES when empty
            or None.
        device: Device string; defaults to "cuda" when available, else "cpu".
        enabled: Optional container of signal names. When given, it is passed
            to ``collect_signatures`` and restricts the printed columns.

    Returns:
        ``{"rows": rows, "table": table}``, or ``{}`` when no battery could be
        loaded. Batteries whose extraction raises are reported and skipped.
    """
    import torch

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    batteries = batteries or BATTERIES
    print(f"[battery_ablation] {len(batteries)} batteries on {device} | ripser={HAVE_RIPSER}")

    cb_rows: List[Dict[str, Any]] = []
    for ver in batteries:
        # Keep the try body to the extraction alone, so a failure while
        # printing can never report SKIP for a row that was already kept.
        try:
            row = extract_row(ver, device)
        except Exception as e:
            print(f" SKIP {ver:42s} {type(e).__name__}: {e}")
            continue
        cb_rows.append(row)
        print(f" ok {ver:42s} class={str(row['class']):12s} "
              f"n_axes={row['n_axes']:3d} target_mse={row['target']}")

    if not cb_rows:
        print(" no batteries loaded — check BATTERIES / network")
        return {}

    rows = collect_signatures(cb_rows, enabled=enabled)

    # per-battery signature table
    names = [s[0] for s in SIGNAL_SPECS if (enabled is None or s[0] in enabled)]
    print("\n── per-battery contribution values ──")
    header = "battery".ljust(42) + "".join(f"{n[:11]:>13s}" for n in names)
    print(header)
    nan = float("nan")
    for r in rows:
        line = r["id"][:40].ljust(42)
        for n in names:
            v = r["values"].get(n, nan)
            if v is None:  # a missing value prints as nan instead of crashing
                v = nan
            line += f"{v:>13.4f}"
        print(line)

    # ablation ranking
    table = ablation_table(rows)
    n_target = max((s["n_target"] for s in table.values()), default=0)
    classes_present = sorted({r.get("class") for r in rows if r.get("class") is not None})
    print("\n── contribution informativeness ──")
    print(f" cv = scale-free spread | |rho| = |Spearman| w/ recon MSE (n={n_target}, detects BROKEN)")
    print(f" eta2 = variance explained by class (detects CLASS SEPARATION) | classes: {classes_present}")

    def _key(it):
        # Sort by eta2, then |rho|, both descending; NaN (x != x) sorts last.
        e = it[1]["eta2_by_class"]
        rho = it[1]["abs_spearman_with_target"]
        return (-(e if e == e else -1), -(rho if rho == rho else -1))

    ranked = sorted(table.items(), key=_key)  # computed once, reused below
    for name, stats in ranked:
        rho = stats["abs_spearman_with_target"]
        rho_s = f"{rho:.3f}" if rho == rho else " -- "
        eta = stats["eta2_by_class"]
        eta_s = f"{eta:.3f}" if eta == eta else " -- "
        cv = stats["cv"]
        cv_s = f"{cv:6.2f}" if cv == cv else " -- "
        print(f" {name:26s} eta2={eta_s} |rho|={rho_s} cv={cv_s} n={stats['n_valid']}")

    # per-class means for the strongest class separators
    top = ranked[:4]
    print(f"\n── per-class means (top {len(top)} class-separating signals) ──")
    hdr = "class".ljust(16) + "".join(f"{n[:11]:>13s}" for n, _ in top)
    print(hdr)
    for c in classes_present:
        line = str(c).ljust(16)
        for _, stats in top:
            mv = stats["class_means"].get(str(c))
            line += (f"{mv:>13.3f}" if mv is not None else f"{'--':>13s}")
        print(line)
    return {"rows": rows, "table": table}
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
if __name__ == "__main__":
    # A BATTERIES list with at most one entry is treated as the untouched
    # default: in that case the whole discovered zoo is ablated instead.
    if len(BATTERIES) > 1:
        bats = BATTERIES
    else:
        bats = discover_batteries()
    run_ablation(bats)
|