AbstractPhil commited on
Commit
108d2be
Β·
verified Β·
1 Parent(s): 258553d

Create codebook_contributions.py

Browse files
Files changed (1) hide show
  1. codebook_contributions.py +215 -0
codebook_contributions.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ battery_ablation.py — test contribution signals across batteries.
3
+
4
+ For each battery: load it frozen, extract its projective codebook, compute the
5
+ contribution signals (codebook_contributions), and pull its recon MSE as the
6
+ target. Then rank every signal by:
7
+ * std across batteries — does it vary at all, or is it a dead signal?
8
+ * |corr| with recon MSE — does it track downstream quality?
9
+
10
+ This is the "run N trains, test each contribution as a whole" pass: each
11
+ battery is one data point; the ablation table says which contributions earn a
12
+ slot in the omega-phase classifier before we hardwire any of them.
13
+
14
+ Cell workflow: paste codebook_contributions cell first, then this. Edit
15
+ BATTERIES to your set (≥3 needed for correlation). `pip install ripser` for the
16
+ H1/H2 void signals; without it they self-exclude as NaN.
17
+ """
18
+ from __future__ import annotations
19
+
20
+ from typing import Any, Dict, List, Optional
21
+
22
+ import numpy as np
23
+
24
+ # cell-tolerant: from the codebook_contributions cell (or installed)
25
+ try:
26
+ from codebook_contributions import (
27
+ collect_signatures, ablation_table, SIGNAL_SPECS, HAVE_RIPSER,
28
+ )
29
+ except ModuleNotFoundError:
30
+ pass
31
+
32
+
33
+ # ── edit this to your battery set ───────────────────────────────────
34
# Hugging Face repo id that the battery folders are listed from / downloaded from.
REPO_ID = "AbstractPhil/geolip-SVAE"

# Battery folder names to ablate over. When this holds only the single default
# entry, the script entry point falls back to discover_batteries() instead.
BATTERIES: List[str] = [
    "h2_linear_tiny_imagenet_64",
    # add your other battery folder names here, e.g.:
    #   "h2_linear_imagenet_128",
    #   "byte_trigram_proto_64_patch_2_v1",
    #   "v40_freckles_noise", "v50_fresnel_64", ...
]
42
+
43
+
44
def discover_batteries(repo_id: str = REPO_ID) -> List[str]:
    """Return the battery folders in `repo_id` that contain a checkpoints/best.pt.

    Lists every file in the Hugging Face repo, keeps the paths ending in
    "/checkpoints/best.pt", and returns the sorted, de-duplicated first path
    component of each. Intended so BATTERIES need not be maintained by hand:
    `run_ablation(discover_batteries())` ablates over everything found.

    NOTE(review): the original author states that mixed classes/D are fine
    because signals are D-normalized; that is not verifiable from this file.
    """
    from huggingface_hub import HfApi

    suffix = "/checkpoints/best.pt"
    found = set()
    for repo_file in HfApi().list_repo_files(repo_id):
        if repo_file.endswith(suffix):
            # Only the top-level folder name is used as the battery id.
            found.add(repo_file.split("/")[0])
    vers = sorted(found)
    print(f" discovered {len(vers)} batteries in {repo_id}")
    return vers
53
+
54
+
55
def _load_model_safe(ver: str, device: str, repo_id: str):
    """Load a battery via load_model, recovering from torch.compile checkpoints.

    First tries `load_model(hf_version=ver, ...)`. If that raises a RuntimeError
    whose message mentions the '_orig_mod.' prefix (state-dict keys saved from a
    torch.compile-wrapped model), the checkpoint is re-downloaded, the prefix is
    stripped from every state-dict key, missing config keys are backfilled from
    the battery's final_report.json (best-effort), the result is saved to a temp
    file, and load_model is re-entered through `checkpoint_path` so its own
    construction logic is reused.

    Args:
        ver: battery folder name inside the repo (used as hf_version).
        device: device string forwarded to load_model.
        repo_id: Hugging Face repo id holding the battery.

    Returns:
        Whatever load_model returns; the recovery path unpacks it as
        (model, cfg).

    Raises:
        RuntimeError: re-raised unchanged when the failure is not the
            '_orig_mod.' prefix problem.
    """
    from geolip_svae.inference.loading import load_model
    try:
        return load_model(hf_version=ver, device=device, repo_id=repo_id)
    except RuntimeError as e:
        if "_orig_mod." not in str(e):
            raise

    import json
    import os
    import tempfile

    import torch
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(repo_id=repo_id, filename=f"{ver}/checkpoints/best.pt",
                           repo_type="model")
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # downloaded file; only use this with repos you trust.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

    pref = "_orig_mod."
    ckpt["model_state_dict"] = {
        (k[len(pref):] if k.startswith(pref) else k): v
        for k, v in ckpt["model_state_dict"].items()
    }

    # Mirror load_model's final_report backfill into the temp config, since the
    # checkpoint_path entry point does not do the hf_version backfill itself
    # (per the original author's note).
    cfg0 = dict(ckpt.get("config", {}))
    backfillable = ("n_heads", "smooth_mid", "linear_readout",
                    "svd_mode", "match_params", "channels")
    if any(k not in cfg0 for k in backfillable):
        try:
            rp = hf_hub_download(repo_id=repo_id, filename=f"{ver}/final_report.json",
                                 repo_type="model")
            # Context manager so the report file handle is always closed.
            with open(rp, encoding="utf-8") as fh:
                rc = json.load(fh).get("config", {})
            for k in backfillable:
                if k not in cfg0 and rc.get(k) is not None:
                    cfg0[k] = rc[k]
            ckpt["config"] = cfg0
        except Exception as backfill_err:
            # Backfill is best-effort: keep going with the checkpoint's own
            # config, but say so instead of failing silently.
            print(f" (config backfill skipped for {ver}: "
                  f"{type(backfill_err).__name__}: {backfill_err})")

    # '/' in ver would otherwise create a nested path under the temp dir.
    tmp = os.path.join(tempfile.gettempdir(), f"{ver.replace('/', '_')}_stripped.pt")
    torch.save(ckpt, tmp)
    model, cfg = load_model(checkpoint_path=tmp, device=device, repo_id=repo_id)
    print(f" (recovered {ver}: stripped _orig_mod. torch.compile prefix)")
    return model, cfg
97
+
98
+
99
def extract_row(ver: str, device: str) -> Dict[str, Any]:
    """Load a frozen battery, extract its codebook, and return an ablation row.

    Args:
        ver: battery folder name in REPO_ID.
        device: device string used for loading the model and running extraction.

    Returns:
        A dict with keys:
          "id"         - the battery name `ver`
          "class"      - result of infer_class_from_cfg(cfg)
          "axes"       - codebook axes as a NumPy array (cb.axes moved to CPU)
          "D"          - cfg["D"] if set, else axes.shape[1]
          "n_pairs"    - from codebook metadata, else len(cb.pairs)
          "n_unpaired" - from codebook metadata, else len(cb.unpaired)
          "target"     - cfg["_test_mse"] (recon MSE; None if absent)
          "n_axes"     - axes.shape[0]
    """
    from geolip_svae.inference.calibration import make_calibration
    from geolip_svae.inference.codebook import extract_codebook
    from geolip_svae.inference.train_codebook import (
        infer_class_from_cfg, DEFAULT_CALIBRATIONS,
    )
    import torch

    model, cfg = _load_model_safe(ver, device, REPO_ID)
    cls = infer_class_from_cfg(cfg)
    # Unrecognised classes fall back to the "unknown" calibration entry.
    cal = DEFAULT_CALIBRATIONS.get(cls, DEFAULT_CALIBRATIONS["unknown"])
    size = cfg.get("img_size") or cal["size"]

    calib = make_calibration(cal["name"], n=cal["n"], size=size)
    if not isinstance(calib, torch.Tensor):
        calib = torch.as_tensor(calib)

    # Match the model's input channel count. `or 3` (rather than a .get default)
    # also covers a config that stores channels as an explicit None.
    # assumes calib is (N, C, H, W) - the 4-argument repeat below requires it.
    ch = int(cfg.get("channels") or 3)
    have = calib.shape[1]
    if have != ch:
        if ch < have:
            calib = calib[:, :ch]
        else:
            # Tile channels enough times to cover ch, then trim the excess.
            reps = (ch + have - 1) // have
            calib = calib.repeat(1, reps, 1, 1)[:, :ch]

    cb = extract_codebook(model, calib.to(device), model_id=ver,
                          model_class=cls, calibration_name=cal["name"])
    axes = cb.axes.detach().cpu().numpy()

    n_pairs = getattr(cb.metadata, "n_pairs", None)
    n_unpaired = getattr(cb.metadata, "n_unpaired", None)
    if n_pairs is None:
        # No pair count in metadata: derive both counts from the codebook itself.
        n_pairs, n_unpaired = len(cb.pairs), len(cb.unpaired)
    elif n_unpaired is None:
        # Metadata had n_pairs only; without this int(None) below would raise.
        n_unpaired = len(cb.unpaired)

    return {
        "id": ver,
        "class": cls,
        "axes": axes,
        "D": int(cfg.get("D") or axes.shape[1]),
        "n_pairs": int(n_pairs),
        "n_unpaired": int(n_unpaired),
        "target": cfg.get("_test_mse"),  # recon MSE (None if absent)
        "n_axes": int(axes.shape[0]),
    }
143
+
144
+
145
def run_ablation(batteries: Optional[List[str]] = None, device: Optional[str] = None,
                 enabled=None) -> Dict[str, Any]:
    """Extract every battery's codebook, compute signatures, rank contributions.

    Args:
        batteries: battery folder names to process; None or empty falls back to
            the module-level BATTERIES.
        device: device string; defaults to "cuda" when available, else "cpu".
        enabled: forwarded to collect_signatures and used to filter which
            SIGNAL_SPECS names are printed; None means all of them.

    Returns:
        {} when no battery could be loaded, otherwise
        {"rows": <collect_signatures output>, "table": <ablation_table output>}.

    Side effects:
        Prints a per-battery status line, the per-battery signal values, the
        ranked informativeness table, and per-class means for the top signals.
        Batteries whose extraction raises are reported as SKIP and left out.
    """
    import torch
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    batteries = batteries or BATTERIES
    print(f"[battery_ablation] {len(batteries)} batteries on {device} | ripser={HAVE_RIPSER}")

    cb_rows: List[Dict[str, Any]] = []
    for ver in batteries:
        # Only the extraction is guarded: a row is kept if and only if it was
        # extracted, and a SKIP line always means the row was really dropped.
        try:
            row = extract_row(ver, device)
        except Exception as e:
            print(f" SKIP {ver:42s} {type(e).__name__}: {e}")
            continue
        cb_rows.append(row)
        cls_s = str(row["class"])  # a None class cannot take the ':12s' format
        print(f" ok {ver:42s} class={cls_s:12s} "
              f"n_axes={row['n_axes']:3d} target_mse={row['target']}")

    if not cb_rows:
        print(" no batteries loaded — check BATTERIES / network")
        return {}

    rows = collect_signatures(cb_rows, enabled=enabled)

    # per-battery signature table
    names = [s[0] for s in SIGNAL_SPECS if (enabled is None or s[0] in enabled)]
    print("\n── per-battery contribution values ──")
    header = "battery".ljust(42) + "".join(f"{n[:11]:>13s}" for n in names)
    print(header)
    for r in rows:
        line = r["id"][:40].ljust(42)
        for n in names:
            v = r["values"].get(n, float("nan"))
            line += f"{v:>13.4f}"
        print(line)

    # ablation ranking
    table = ablation_table(rows)
    n_target = max((s["n_target"] for s in table.values()), default=0)
    classes_present = sorted({r.get("class") for r in rows if r.get("class") is not None})
    print("\n── contribution informativeness ──")
    print(f" cv = scale-free spread | |rho| = |Spearman| w/ recon MSE (n={n_target}, detects BROKEN)")
    print(f" eta2 = variance explained by class (detects CLASS SEPARATION) | classes: {classes_present}")

    def _rank_key(item):
        # Sort by eta2 descending, then |rho| descending. `x == x` is False only
        # for NaN, so NaN entries are treated as -1 and sink to the bottom.
        e = item[1]["eta2_by_class"]
        rho = item[1]["abs_spearman_with_target"]
        return (-(e if e == e else -1), -(rho if rho == rho else -1))

    ranked = sorted(table.items(), key=_rank_key)  # sorted once, reused below
    for name, stats in ranked:
        rho = stats["abs_spearman_with_target"]
        rho_s = f"{rho:.3f}" if rho == rho else " -- "
        eta = stats["eta2_by_class"]
        eta_s = f"{eta:.3f}" if eta == eta else " -- "
        cv = stats["cv"]
        cv_s = f"{cv:6.2f}" if cv == cv else " -- "
        print(f" {name:26s} eta2={eta_s} |rho|={rho_s} cv={cv_s} n={stats['n_valid']}")

    # per-class means for the strongest class separators
    top = ranked[:4]
    print(f"\n── per-class means (top {len(top)} class-separating signals) ──")
    hdr = "class".ljust(16) + "".join(f"{n[:11]:>13s}" for n, _ in top)
    print(hdr)
    for c in classes_present:
        line = str(c).ljust(16)
        for _, stats in top:
            mv = stats["class_means"].get(str(c))
            line += (f"{mv:>13.3f}" if mv is not None else f"{'--':>13s}")
        print(line)
    return {"rows": rows, "table": table}
210
+
211
+
212
if __name__ == "__main__":
    # With only the lone default entry in BATTERIES, ablate everything that
    # discover_batteries() finds in the repo; otherwise use the edited list.
    if len(BATTERIES) > 1:
        bats = BATTERIES
    else:
        bats = discover_batteries()
    run_ablation(bats)