geolip-svae-transformer / codebook_contributions.py
AbstractPhil's picture
Update codebook_contributions.py
1f250c1 verified
"""
codebook_contributions.py — contribution signals for the omega-phase classifier.
The shipped ``_classify_omega_phase`` (geolip_svae.inference.train_codebook)
consumes ONLY H0 connectivity. Everything else the topology probe measures —
H1 loops, H2 voids, local intrinsic dimension, percolation scale, the
deviation envelope, the antipodal/sign structure — is computed and discarded.
This module turns each discarded quantity into a named, independently-toggleable
CONTRIBUTION SIGNAL so they can be ablated across training runs ("run N trains,
test each contribution as a whole"). Every signal is mathematically aligned to
the system's omega/aleph rules, which were read out of the utilizers:
* PROJECTIVE metric. Axes are sign-canonicalized (`canon`); the distance is
d(a,b) = arccos(|<a,b>|) in [0, π/2] on ℝP^(D-1) — NOT the raw S^(D-1)
angle the stock probe uses. Loops/voids are recomputed in this metric.
* UNIFORM baseline. uniform_projective_angle(D) is the rigid packing
reference; structure is deviation FROM it.
* dev_critical(D) = 0.02·√D is the envelope half-width (rigidity_barrier).
Deviation signals are reported in dev_critical units so |x|>1 == out of
envelope, identical to the architectural constraint.
* ALEPH address. The ± antipodal bit (canon / -0.9 collapse) is the sign
half of the address; its realization is a first-class signal.
Persistence (H1/H2) needs `ripser` (pip install ripser). Without it those
signals report NaN and are flagged ripser_required so ablation can exclude them
cleanly rather than silently corrupt a run.
Torch-free (numpy + scipy + optional ripser): auditable, runs anywhere.
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
try:
from ripser import ripser as _ripser
HAVE_RIPSER = True
except ImportError:
HAVE_RIPSER = False
_ripser = None
# Largest possible projective angle (orthogonal lines). Used to normalize
# persistences and as the ripser distance threshold.
HALF_PI = 0.5 * math.pi
# ── system math, reproduced in numpy (matches geolip_svae utilizers) ──
def canon_np(v: np.ndarray) -> np.ndarray:
    """Return a sign-canonicalized float64 copy of the row vectors in ``v``.

    Each row is flipped, if necessary, so that its first coordinate with
    magnitude above 1e-6 is positive. This picks one representative per
    projective point. Rows with no such coordinate are left unchanged.
    The input is never modified. Mirrors model_transformer.canon.
    """
    src = np.asarray(v, dtype=np.float64)
    result = src.copy()
    for row_idx, row in enumerate(src):
        leading = np.flatnonzero(np.abs(row) > 1e-6)
        if leading.size and row[leading[0]] < 0:
            result[row_idx] = -row
    return result
def _unit(axes: np.ndarray) -> np.ndarray:
axes = np.asarray(axes, dtype=np.float64)
return axes / np.linalg.norm(axes, axis=1, keepdims=True).clip(min=1e-12)
def projective_distance(axes_unit: np.ndarray) -> np.ndarray:
    """Pairwise projective angle arccos(|cos|), in [0, π/2], as an [n, n] matrix.

    This is the metric on ℝP^(D-1). A vector and its negation are the same
    point (distance 0), which matches the aleph convention. Rows are expected
    to be unit-norm already. The cosine is clipped to [-1, 1] to absorb
    round-off, and the diagonal is forced to exactly 0.
    """
    gram = np.clip(axes_unit @ axes_unit.T, -1.0, 1.0)
    dist = np.arccos(np.abs(gram))
    np.fill_diagonal(dist, 0.0)
    return dist
# Memo for uniform_projective_angle, keyed on (D, n, seed) so that estimates
# made with different sample sizes or seeds never alias each other.
_UMEAN: Dict[Tuple[int, int, int], float] = {}


def uniform_projective_angle(D: int, n: int = 4096, seed: int = 0) -> float:
    """Mean pairwise projective angle of uniform directions on ℝP^(D-1).

    Monte-Carlo estimate: draws ``n`` standard-normal points in ℝ^D, normalizes
    them to unit length, and averages arccos(|cos|) over all distinct pairs.
    Reproduces geolip_svae.inference.codebook.uniform_projective_angle.

    Results are memoized per (D, n, seed). The cache was previously keyed on
    D alone, so a non-default ``n`` or ``seed`` silently returned whatever was
    cached first for that D.

    Note: with the default n=4096 this allocates an n×n float64 Gram matrix
    (~128 MB) on the first call per key.
    """
    key = (int(D), int(n), int(seed))
    cached = _UMEAN.get(key)
    if cached is not None:
        return cached
    rng = np.random.default_rng(seed)
    pts = rng.standard_normal((n, D))
    pts /= np.linalg.norm(pts, axis=1, keepdims=True).clip(min=1e-12)
    # No sign canonicalization is needed: negating a row negates its cosines
    # exactly, and |cos| discards the sign. This yields the same value as the
    # canon'd cloud without the per-row Python loop.
    cos = np.clip(pts @ pts.T, -1.0, 1.0)
    iu = np.triu_indices(n, k=1)
    value = float(np.arccos(np.abs(cos[iu])).mean())
    _UMEAN[key] = value
    return value
def dev_critical(D: int, coeff: float = 0.02) -> float:
    """Envelope half-width ``coeff * sqrt(D)`` (0.02·√D by default).

    This is the rigidity_barrier scale. Deviations are reported in units of
    it, so |x| > 1 means out of envelope.
    """
    root_d = math.sqrt(D)
    return coeff * root_d
# ── projective topology probes (numpy/scipy/ripser) ──
def _percolation(d: np.ndarray, theta_grid: Sequence[float],
frac: float = 0.5) -> Tuple[Optional[float], Dict[float, float]]:
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
n = d.shape[0]
largest_at: Dict[float, float] = {}
perc: Optional[float] = None
for th in theta_grid:
adj = (d <= th) & (d > 0)
_, labels = connected_components(csr_matrix(adj.astype(np.int8)), directed=False)
largest = np.bincount(labels).max() / n
largest_at[float(th)] = float(largest)
if perc is None and largest >= frac:
perc = float(th)
return perc, largest_at
def _largest_component_frac(d: np.ndarray, theta: float) -> float:
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
adj = (d <= theta) & (d > 0)
_, labels = connected_components(csr_matrix(adj.astype(np.int8)), directed=False)
return float(np.bincount(labels).max() / d.shape[0])
def _local_pr_dim(axes_unit: np.ndarray, d: np.ndarray, k: int = 10) -> np.ndarray:
"""Per-axis participation-ratio dimension (Σλ)²/Σλ² of the k-NN offset
cloud. β†’ 1 if neighbors lie on a curve, β†’ D if they fill the tangent space."""
n = axes_unit.shape[0]
k = min(k, max(1, n - 1))
nn = np.argsort(d, axis=1)
pr = np.zeros(n)
for i in range(n):
off = axes_unit[nn[i, 1:k + 1]] - axes_unit[i]
off -= off.mean(0)
s = np.linalg.svd(off, full_matrices=False, compute_uv=False)
lam = (s ** 2) / k
s1, s2 = lam.sum(), (lam ** 2).sum()
pr[i] = (s1 ** 2) / s2 if s2 > 0 else 0.0
return pr
def _persistence(d: np.ndarray, maxdim: int = 2,
                 thresh: float = HALF_PI) -> Optional[Dict[str, np.ndarray]]:
    """Run ripser on the projective distance matrix ``d``.

    Returns a dict {'H0', 'H1', ..., f'H{maxdim}'} mapping each homology
    dimension to its FINITE [birth, death] pairs (radians). Pairs with an
    infinite death are dropped, and empty diagrams become (0, 2) arrays.
    Returns None when ripser is not installed.
    """
    if not HAVE_RIPSER:
        return None
    diagrams = _ripser(d, distance_matrix=True, maxdim=maxdim, thresh=float(thresh))['dgms']
    summary: Dict[str, np.ndarray] = {}
    for dim, dgm in enumerate(diagrams):
        if len(dgm):
            summary[f'H{dim}'] = dgm[np.isfinite(dgm[:, 1])]
        else:
            summary[f'H{dim}'] = np.zeros((0, 2))
    return summary
def _persist_summary(finite: np.ndarray) -> Tuple[int, float, float, float]:
"""(betti, total_persistence_frac, max_persistence_frac, persistence_entropy)
for one finite diagram, persistences normalized by HALF_PI."""
if finite is None or len(finite) == 0:
return 0, 0.0, 0.0, 0.0
pers = (finite[:, 1] - finite[:, 0]).clip(min=0.0)
total = float(pers.sum() / HALF_PI)
mx = float(pers.max() / HALF_PI)
p = pers / pers.sum() if pers.sum() > 0 else np.ones_like(pers) / len(pers)
ent = float(-(p * np.log(p.clip(min=1e-12))).sum() / math.log(len(pers))) if len(pers) > 1 else 0.0
return int(len(finite)), total, mx, ent
# ── contribution signal ──
@dataclass
class ContributionSignal:
    """One named contribution signal, with its value and provenance.

    ``compute_contributions`` builds these from a SIGNAL_SPECS entry plus the
    computed value. ``omega_signature`` serializes them with
    ``dataclasses.asdict``. Field order is part of that serialized form.
    """
    name: str  # signal key, as listed in SIGNAL_SPECS
    value: float  # computed value; NaN when unavailable (e.g. ripser missing)
    units: str  # unit tag copied from SIGNAL_SPECS, e.g. "rad", "ratio"
    formula: str  # human-readable definition of `value`
    rule: str # the omega/aleph rule it preserves
    utilization: str # how it's meant to be consumed
    ripser_required: bool = False  # True for the H1/H2 persistence signals
    enabled: bool = True  # compute_contributions always sets this to True
# Registry of every signal the H0-only classifier currently ignores.
# Each entry is (name, units, formula, rule, utilization, ripser_required).
# The descriptive strings are user-facing (printed / serialized via asdict),
# so they are kept as correct Unicode text.
SIGNAL_SPECS: List[Tuple[str, str, str, str, str, bool]] = [
    # geometry / deviation envelope
    ("proj_deviation", "rad",
     "mean acos|cos| over axes − uniform_projective_angle(D)",
     "projective metric vs uniform ℝP^(D-1) baseline",
     "phase feature: signed distance of the frame from rigid packing", False),
    ("deviation_envelopes", "dev_crit",
     "proj_deviation / (0.02·√D)",
     "dev_critical envelope; |x|>1 == out of envelope (rigidity_barrier)",
     "phase gate: in/out of the architectural envelope, scale-free across D", False),
    ("angular_iqr", "rad",
     "p75 − p25 of pairwise projective angles",
     "projective metric",
     "spread tightness of the axis cloud (degeneracy vs dispersion)", False),
    # connectivity (projective)
    ("percolation_ratio", "ratio",
     "θ_percolation(proj) / uniform_projective_angle(D)",
     "connection scale measured against the uniform baseline",
     "how tight before a giant component forms; <1 == clusters below uniform", False),
    ("giant_frac_at_uniform", "frac",
     "largest connected component / n at θ = uniform_projective_angle(D)",
     "graph threshold pinned to the uniform baseline (not arbitrary degrees)",
     "coalescence at the natural scale; complements H0 finite/infinite", False),
    # local geometry
    ("local_dim_ratio", "ratio",
     "median participation-ratio dim of k-NN offsets / D",
     "PCA on neighbor offsets in the projective tangent",
     "how fully axes locally span ℝP^(D-1); →0 curve-like, →1 space-filling", False),
    # loops H1 (projective)
    ("betti1", "per_axis",
     "(# finite H1 features at θ=π/2) / n_axes",
     "persistent homology on the projective distance; intensive (per-axis)",
     "loop density — cyclic structure per axis, comparable across codebook sizes", True),
    ("h1_total_persistence", "frac/axis",
     "Σ(death−birth) over finite H1 / (π/2) / n_axes",
     "persistence in projective angular units; intensive (per-axis)",
     "loop-mass density", True),
    ("h1_max_persistence", "frac",
     "max(death−birth) over finite H1 / (π/2)",
     "persistence in projective angular units",
     "strength of the single dominant loop", True),
    ("h1_persistence_entropy", "nat/log",
     "normalized Shannon entropy of H1 persistences",
     "standard persistence-entropy summary",
     "regular (high) vs single-dominant (low) loop spectrum", True),
    # voids H2 (projective)
    ("betti2", "per_axis",
     "(# finite H2 features at θ=π/2) / n_axes",
     "persistent homology on the projective distance; intensive (per-axis)",
     "void density — cavities per axis; high == noise-like, low == structured", True),
    ("h2_total_persistence", "frac/axis",
     "Σ(death−birth) over finite H2 / (π/2) / n_axes",
     "persistence in projective angular units; intensive (per-axis)",
     "void-mass density — the unused signal, now scale-free", True),
    ("h2_max_persistence", "frac",
     "max(death−birth) over finite H2 / (π/2)",
     "persistence in projective angular units",
     "strength of the single dominant void", True),
    ("h2_persistence_entropy", "nat/log",
     "normalized Shannon entropy of H2 persistences",
     "standard persistence-entropy summary",
     "regular vs single-dominant void spectrum", True),
    # aleph / sign structure
    ("pairing_fraction", "frac",
     "n_pairs / (n_pairs + n_unpaired)",
     "antipodal collapse at cos<−0.9 == realization of the aleph ± bit",
     "how strongly the frame realizes sign-addressable structure", False),
]
def compute_contributions(
    axes: np.ndarray,
    D: Optional[int] = None,
    *,
    n_pairs: Optional[int] = None,
    n_unpaired: Optional[int] = None,
    enabled: Optional[Sequence[str]] = None,
    knn_k: int = 10,
    percolation_grid_deg: Sequence[float] = (0.5, 1, 2, 4, 6, 8, 10, 14, 20, 30, 45, 60, 90),
) -> Dict[str, ContributionSignal]:
    """Compute every contribution signal from a codebook's axes.

    Args:
        axes: [n, Dax] codebook axes. They are sign-canonicalized and
            unit-normalized here.
        D: ambient dimension used for the uniform baseline, dev_critical and
            local_dim_ratio. Defaults to the axes' column count.
        n_pairs, n_unpaired: optional pair metadata for ``pairing_fraction``.
        enabled: restrict to a subset of SIGNAL_SPECS names (ablation).
            None means all. Unknown names are ignored.
        knn_k: neighbour count for the local participation-ratio dimension.
        percolation_grid_deg: threshold grid (degrees) for percolation_ratio.

    Returns:
        name -> ContributionSignal, in SIGNAL_SPECS order, for the enabled
        names. The value is NaN when the signal is unavailable (ripser
        missing, no pair metadata, too few axes).
    """
    persist_names = (
        "betti1", "h1_total_persistence", "h1_max_persistence", "h1_persistence_entropy",
        "betti2", "h2_total_persistence", "h2_max_persistence", "h2_persistence_entropy",
    )
    nan = float('nan')
    axes_unit = _unit(canon_np(axes))
    n, Dax = axes_unit.shape
    D = int(D or Dax)
    want = set(enabled) if enabled is not None else {s[0] for s in SIGNAL_SPECS}
    d = projective_distance(axes_unit)
    off = d[np.triu_indices(n, k=1)]
    uniform = uniform_projective_angle(D)
    crit = dev_critical(D)
    vals: Dict[str, float] = {}

    # geometry (cheap; always computed)
    mean_proj = float(off.mean()) if off.size else nan
    dev = mean_proj - uniform
    vals["proj_deviation"] = dev
    vals["deviation_envelopes"] = dev / crit if crit > 0 else nan
    vals["angular_iqr"] = float(np.percentile(off, 75) - np.percentile(off, 25)) if off.size else nan

    # connectivity
    if "percolation_ratio" in want:
        perc, _ = _percolation(d, [math.radians(t) for t in percolation_grid_deg])
        # `is not None`: a legitimate percolation at θ=0 (duplicate axes) is a
        # real value of 0, not "never percolated".
        vals["percolation_ratio"] = (perc / uniform) if (perc is not None and uniform > 0) else nan
    if "giant_frac_at_uniform" in want:
        vals["giant_frac_at_uniform"] = _largest_component_frac(d, uniform)

    # local geometry
    if "local_dim_ratio" in want:
        pr = _local_pr_dim(axes_unit, d, k=knn_k)
        vals["local_dim_ratio"] = float(np.median(pr) / D) if D > 0 else nan

    # persistence (H1 loops, H2 voids)
    if want.intersection(persist_names):
        pers = _persistence(d, maxdim=2, thresh=HALF_PI)
        if pers is None:
            vals.update(dict.fromkeys(persist_names, nan))
        else:
            b1, h1t, h1m, h1e = _persist_summary(pers.get("H1"))
            b2, h2t, h2m, h2e = _persist_summary(pers.get("H2"))
            inv_n = 1.0 / max(1, n)  # intensive: per-axis density, comparable across codebook sizes
            vals.update(betti1=float(b1) * inv_n, h1_total_persistence=h1t * inv_n,
                        h1_max_persistence=h1m, h1_persistence_entropy=h1e,
                        betti2=float(b2) * inv_n, h2_total_persistence=h2t * inv_n,
                        h2_max_persistence=h2m, h2_persistence_entropy=h2e)

    # aleph / sign structure
    if "pairing_fraction" in want:
        if n_pairs is not None and n_unpaired is not None and (n_pairs + n_unpaired) > 0:
            vals["pairing_fraction"] = float(n_pairs / (n_pairs + n_unpaired))
        else:
            vals["pairing_fraction"] = nan

    out: Dict[str, ContributionSignal] = {}
    for name, units, formula, rule, util, rip in SIGNAL_SPECS:
        if name not in want:
            continue
        out[name] = ContributionSignal(
            name=name, value=float(vals.get(name, nan)), units=units,
            formula=formula, rule=rule, utilization=util, ripser_required=rip,
            enabled=True)
    return out
# ── omega signature: base H0 phase + the new contributions + flags ──
def omega_signature(
    axes: np.ndarray, D: Optional[int] = None, *,
    n_pairs: Optional[int] = None, n_unpaired: Optional[int] = None,
    enabled: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Full signature: contribution values plus boolean flags derived from the system rules.

    The flags are the testable hypotheses each contribution encodes. Toggle
    ``enabled`` to ablate which signals feed the phase.

    A flag is emitted only when every signal it reads was computed AND is
    finite. Disabled or NaN inputs (ripser missing, no pair metadata) mean
    "unknown", so the flag is omitted rather than reported as a false
    ``False``.

    Returns:
        {"n_axes", "D", "contributions": name -> asdict(ContributionSignal),
         "flags": name -> bool, "ripser_available"}.
    """
    c = compute_contributions(axes, D, n_pairs=n_pairs, n_unpaired=n_unpaired, enabled=enabled)

    def v(name: str) -> float:
        return c[name].value if name in c else float('nan')

    def known(*names: str) -> bool:
        # Missing names read as NaN, so this covers "disabled" too.
        return all(math.isfinite(v(nm)) for nm in names)

    flags = {
        # geometry: out of the rigidity envelope (|dev| > 1 dev_critical)
        "out_of_envelope": (abs(v("deviation_envelopes")) > 1.0)
        if known("deviation_envelopes") else None,
        # loops present and dominant (a loop spanning >25% of the projective range)
        "has_persistent_loops": (v("betti1") > 0 and v("h1_max_persistence") > 0.25)
        if known("betti1", "h1_max_persistence") else None,
        # voids present and dominant — the headline unused signal
        "has_persistent_voids": (v("betti2") > 0 and v("h2_max_persistence") > 0.25)
        if known("betti2", "h2_max_persistence") else None,
        # space-filling locally vs curve/pair-like
        "space_filling": (v("local_dim_ratio") > 0.5) if known("local_dim_ratio") else None,
        # sign-addressable frame (most rows collapsed to antipodal pairs)
        "sign_addressable": (v("pairing_fraction") > 0.5) if known("pairing_fraction") else None,
    }
    arr = np.asarray(axes)
    return {
        "n_axes": int(arr.shape[0]),
        "D": int(D or arr.shape[1]),
        "contributions": {k: asdict(s) for k, s in c.items()},
        "flags": {k: val for k, val in flags.items() if val is not None},
        "ripser_available": HAVE_RIPSER,
    }
# ── omega_phase_v2: two orthogonal axes the 39-battery ablation established ──
#
# The cross-dimension η² ranking (local_dim/giant_frac/angular_iqr on top) and
# the within-D=4 ranking (those collapse; the VOIDS rise) showed the signal is
# not one taxonomy but TWO independent axes, plus a dispersion axis that holds in both views:
#   regime — cross-dimension geometry (≈ dimension × training-health).
#       Dominates across D; collapses within fixed D.
#   void_character — within-dimension SUBSTRATE signal carried by the voids.
#       Symbolic vocabularies are void-structured; continuous/image
#       void-sparse; near-random clouds void-saturated. The geometry
#       cannot see this (it's flat within D); the voids can.
#   dispersion — deviation vs the dev_critical envelope (survives both: s-class
# under-dispersed, image over-dispersed).
#
# Thresholds are EMPIRICAL from the zoo (2026-05) and tunable; this is descriptive
# telemetry, not a loss and not a proof. Refit as the zoo grows.
# Empirical cut points for label_phase (fit on the 2026-05 zoo; tunable).
OMEGA_V2_THRESHOLDS: Dict[str, float] = dict(
    iqr_collapsed=0.12,          # angular_iqr below + fragmented -> angularly collapsed
    giant_fragmented=0.50,       # giant_frac below -> doesn't percolate at uniform
    localdim_concentrated=0.30,  # local_dim_ratio below -> low intrinsic dim (high-D)
    localdim_spacefilling=0.40,  # local_dim_ratio above -> space-filling
    dev_under=-0.60,             # deviation_envelopes below -> under-dispersed (s-class)
    dev_over=1.00,               # above -> over-dispersed
    void_sparse=0.18,            # betti2/axis below -> void-sparse (continuous/image)
    void_saturated=1.50,         # betti2/axis above -> void-saturated (noise-like)
    void_entropy=0.50,           # h2 entropy above (with mid density) -> structured voids
)

# Every label label_phase can emit, per axis.
OMEGA_V2_LABELS: Dict[str, Tuple[str, ...]] = dict(
    regime=("collapsed_fragmented", "concentrated", "space_filling", "transitional"),
    dispersion=("under_dispersed", "in_envelope", "over_dispersed", "unknown"),
    void_character=("void_sparse", "void_structured", "void_saturated", "void_mixed", "unknown"),
)
def _ok(x: Any) -> bool:
return x is not None and x == x # not None, not NaN
def label_phase(values: Dict[str, float],
                thresholds: Optional[Dict[str, float]] = None) -> Dict[str, str]:
    """Map a contribution value-dict onto the three omega-v2 labels.

    This is pure labeling logic (no geometry, no ripser), so it can be tested
    on hand-built dicts. ``thresholds`` overrides individual
    OMEGA_V2_THRESHOLDS entries. Missing or NaN inputs fall through to
    "transitional" (regime) or "unknown" (dispersion, void_character).

    Returns:
        {"regime": ..., "dispersion": ..., "void_character": ...}
    """
    th = dict(OMEGA_V2_THRESHOLDS)
    th.update(thresholds or {})

    iqr = values.get("angular_iqr")
    giant = values.get("giant_frac_at_uniform")
    ldim = values.get("local_dim_ratio")
    dev = values.get("deviation_envelopes")
    b2d = values.get("betti2")
    h2e = values.get("h2_persistence_entropy")

    # regime — cross-dimension geometry; first matching rule wins
    collapsed = _ok(iqr) and _ok(giant) and iqr < th["iqr_collapsed"] and giant < th["giant_fragmented"]
    if collapsed:
        regime = "collapsed_fragmented"
    elif _ok(ldim) and ldim < th["localdim_concentrated"]:
        regime = "concentrated"
    elif _ok(ldim) and ldim >= th["localdim_spacefilling"]:
        regime = "space_filling"
    else:
        regime = "transitional"

    # dispersion — position relative to the dev_critical envelope
    if _ok(dev):
        dispersion = ("under_dispersed" if dev < th["dev_under"]
                      else "over_dispersed" if dev > th["dev_over"]
                      else "in_envelope")
    else:
        dispersion = "unknown"

    # void character — within-dimension substrate signal (the headline)
    if not _ok(b2d):
        void_character = "unknown"
    elif b2d > th["void_saturated"]:
        void_character = "void_saturated"
    elif b2d < th["void_sparse"]:
        void_character = "void_sparse"
    elif b2d >= th["void_sparse"] and _ok(h2e) and h2e > th["void_entropy"]:
        void_character = "void_structured"
    else:
        void_character = "void_mixed"

    return {"regime": regime, "dispersion": dispersion, "void_character": void_character}
def omega_phase_v2(axes: np.ndarray, D: Optional[int] = None, *,
                   n_pairs: Optional[int] = None, n_unpaired: Optional[int] = None,
                   thresholds: Optional[Dict[str, float]] = None) -> Dict[str, Any]:
    """Composite codebook phase from the ablation-surviving contributions.

    Three axes:
    - regime: cross-D geometry.
    - dispersion: the dev envelope.
    - void_character: within-D substrate signal (symbolic vs continuous vs noise).

    Returns the labels plus the driver values behind them, together with
    n_axes, D and ripser availability. The void axis needs ripser; without
    it, void_character == 'unknown'.
    """
    c = compute_contributions(axes, D, n_pairs=n_pairs, n_unpaired=n_unpaired)
    values = {k: s.value for k, s in c.items()}
    labels = label_phase(values, thresholds)
    drivers = {k: values.get(k) for k in (
        "angular_iqr", "giant_frac_at_uniform", "local_dim_ratio",
        "deviation_envelopes", "betti2", "h2_persistence_entropy", "h2_max_persistence")}
    # Shape only: no need to canonicalize/normalize (an O(n) Python loop) just
    # to count rows.
    arr = np.asarray(axes)
    return {**labels,
            "drivers": drivers,
            "n_axes": int(arr.shape[0]),
            "D": int(D or arr.shape[1]),
            "ripser_available": HAVE_RIPSER}
# ── ablation harness: test each contribution across multiple trains ──
def collect_signatures(codebooks: Sequence[Dict[str, Any]],
                       enabled: Optional[Sequence[str]] = None) -> List[Dict[str, Any]]:
    """Build one ablation row per codebook.

    Each codebook is a dict with 'axes' and optional 'id', 'D', 'n_pairs',
    'n_unpaired', 'class', and 'target' (a scalar downstream metric such as
    recon MSE). The id defaults to f"cb{index}".

    Each returned row has keys id, target, class, n_axes, values
    (name -> contribution value) and flags, as produced by omega_signature.
    """
    rows: List[Dict[str, Any]] = []
    for index, cb in enumerate(codebooks):
        signature = omega_signature(cb["axes"], cb.get("D"), n_pairs=cb.get("n_pairs"),
                                    n_unpaired=cb.get("n_unpaired"), enabled=enabled)
        contribution_values = {name: entry["value"]
                               for name, entry in signature["contributions"].items()}
        row = {
            "id": cb.get("id", f"cb{index}"),
            "target": cb.get("target"),
            "class": cb.get("class"),
            "n_axes": signature["n_axes"],
            "values": contribution_values,
            "flags": signature["flags"],
        }
        rows.append(row)
    return rows
def _eta_squared(col: np.ndarray, classes: List[Any]) -> Tuple[float, Dict[str, float]]:
"""One-way ANOVA Ξ·Β² = SS_between/SS_total: fraction of a signal's variance
explained by model class. The right tool for 'does this separate substrates'
(recon-corr can't, since class is nominal). Also returns per-class means.
NOTE: biased upward when groups are tiny β€” trust the well-populated classes."""
groups: Dict[str, List[float]] = {}
for x, c in zip(col, classes):
if np.isfinite(x) and c is not None:
groups.setdefault(str(c), []).append(float(x))
if len(groups) < 2:
return float('nan'), {}
allv = np.concatenate([np.array(v) for v in groups.values()])
if len(allv) < 3:
return float('nan'), {}
m = allv.mean()
ss_tot = float(((allv - m) ** 2).sum())
ss_btw = float(sum(len(v) * (np.mean(v) - m) ** 2 for v in groups.values()))
eta2 = ss_btw / ss_tot if ss_tot > 1e-12 else float('nan')
means = {c: float(np.mean(v)) for c, v in groups.items()}
return eta2, means
def ablation_table(rows: List[Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
    """Per-contribution informativeness across the collected trains.

    For every contribution name present in any row's "values":
    * std: raw spread across runs (NaNs ignored).
    * cv: std/|mean|, a scale-free spread (use it to rank when there is no target).
    * abs_spearman_with_target: |Spearman| with the target (recon MSE).
      Detects BROKEN codebooks.
    * eta2_by_class: variance explained by model class. Detects CLASS
      SEPARATION, the structural question; it is uncorrelated with recon by
      design.

    Also reports class_means, n_valid and n_target. A signal earns a
    classifier slot if it separates classes (eta2) and/or tracks the target
    (|rho|). Everything is computed over the available subset, not
    all-or-nothing.

    Targets that are None or non-numeric (e.g. a class label) are treated as
    missing. |rho| is NaN when scipy is unavailable, fewer than 3 paired
    values exist, or either side is constant.
    """
    nan = float('nan')

    def _as_float(x: Any) -> float:
        # Non-numeric targets carry no rank information; treat as missing
        # instead of crashing the float array construction.
        try:
            return float(x)
        except (TypeError, ValueError):
            return nan

    names = sorted({k for r in rows for k in r["values"]})
    targets = np.array([_as_float(r.get("target")) for r in rows], dtype=np.float64)
    classes = [r.get("class") for r in rows]
    try:
        from scipy.stats import spearmanr
    except ImportError:
        spearmanr = None
    table: Dict[str, Dict[str, float]] = {}
    for nm in names:
        col = np.array([r["values"].get(nm, np.nan) for r in rows], dtype=np.float64)
        valid = np.isfinite(col)
        std = float(np.nanstd(col)) if valid.any() else nan
        mean = float(np.nanmean(col)) if valid.any() else nan
        cv = float(std / abs(mean)) if (mean == mean and abs(mean) > 1e-12) else nan
        rho = nan
        mask = valid & np.isfinite(targets)
        if (spearmanr is not None and mask.sum() >= 3
                and np.std(col[mask]) > 1e-12 and np.std(targets[mask]) > 1e-12):
            # Element 0 is the correlation on every scipy version (plain tuple
            # before 1.9, tuple-like SignificanceResult after).
            r_ = float(spearmanr(col[mask], targets[mask])[0])
            rho = abs(r_) if r_ == r_ else nan
        eta2, class_means = _eta_squared(col, classes)
        table[nm] = {"std": std, "cv": cv, "abs_spearman_with_target": rho,
                     "eta2_by_class": eta2, "class_means": class_means,
                     "n_valid": int(valid.sum()), "n_target": int(mask.sum())}
    return table
# Names exported by `from codebook_contributions import *`. `_eta_squared` is
# listed explicitly: a leading-underscore name is otherwise skipped by a star
# import.
__all__ = [
    "HAVE_RIPSER", "canon_np", "projective_distance", "uniform_projective_angle",
    "dev_critical", "ContributionSignal", "SIGNAL_SPECS", "compute_contributions",
    "omega_signature", "collect_signatures", "ablation_table", "_eta_squared",
    "omega_phase_v2", "label_phase", "OMEGA_V2_THRESHOLDS", "OMEGA_V2_LABELS",
]
if __name__ == "__main__":
    # Smoke + sanity check on synthetic ℝP^(D-1) clouds. The signals must
    # respond to known structure before we trust them on real codebooks.
    rng = np.random.default_rng(0)
    D = 4

    def report(tag, axes, **kw):
        """Print every contribution value and the derived flags for one cloud."""
        signature = omega_signature(axes, D, **kw)
        print(f"\n[{tag}] n_axes={signature['n_axes']} ripser={signature['ripser_available']}")
        for key, entry in signature["contributions"].items():
            print(f" {key:24s} = {entry['value']:+.4f} [{entry['units']}]")
        print(" flags:", signature["flags"])

    # 1) uniform packing
    uniform_cloud = canon_np(rng.standard_normal((64, D)))
    report("uniform", uniform_cloud)

    # 2) tight cluster around one random direction (degenerate)
    centre = rng.standard_normal((1, D))
    cluster = canon_np(centre + 0.05 * rng.standard_normal((64, D)))
    report("tight_cluster", cluster)

    # 3) a loop in a 2-plane (should register H1)
    theta = np.linspace(0, 2 * math.pi, 64, endpoint=False)
    ring = np.zeros((64, D))
    ring[:, 0] = np.cos(theta)
    ring[:, 1] = np.sin(theta)
    report("ring_H1", canon_np(ring), n_pairs=0, n_unpaired=64)
"""
battery_ablation.py — test contribution signals across batteries.
For each battery: load it frozen, extract its projective codebook, compute the
contribution signals (codebook_contributions), and pull its recon MSE as the
target. Then rank every signal by:
* spread across batteries (std / cv) — does it vary at all, or is it a dead signal?
* |Spearman| with recon MSE — does it track downstream quality?
* η² by model class — does it separate substrates?
This is the "run N trains, test each contribution as a whole" pass: each
battery is one data point; the ablation table says which contributions earn a
slot in the omega-phase classifier before we hardwire any of them.
Cell workflow: paste codebook_contributions cell first, then this. Edit
BATTERIES to your set (≥3 needed for correlation). `pip install ripser` for the
H1/H2 loop/void signals; without it they self-exclude as NaN.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import numpy as np
# cell-tolerant: from the codebook_contributions cell (or installed)
try:
from codebook_contributions import (
collect_signatures, ablation_table, SIGNAL_SPECS, HAVE_RIPSER,
)
except ModuleNotFoundError:
pass
# ── edit this to your battery set ───────────────────────────────────
# Default ablation set: battery folder names inside REPO_ID. Extend it with
# more folders, e.g. "h2_linear_imagenet_128",
# "byte_trigram_proto_64_patch_2_v1", "v40_freckles_noise",
# "v50_fresnel_64", ...
BATTERIES: List[str] = ["h2_linear_tiny_imagenet_64"]

# Hugging Face model repo that hosts the batteries.
REPO_ID = "AbstractPhil/geolip-SVAE"
def discover_batteries(repo_id: str = REPO_ID) -> List[str]:
    """List every battery folder in the repo that has a checkpoints/best.pt.

    Saves you maintaining BATTERIES by hand: `run_ablation(discover_batteries())`
    ablates over the whole zoo (mixed classes/D are fine; signals are
    D-normalized).

    A battery id is the FULL path prefix before "/checkpoints/best.pt", so
    nested folders ("group/ver") come back intact. The loader builds
    f"{ver}/checkpoints/best.pt" from it. The previous split("/")[0] kept only
    the top folder, which does not load. Results are sorted and de-duplicated.
    """
    from huggingface_hub import HfApi
    suffix = "/checkpoints/best.pt"
    files = HfApi().list_repo_files(repo_id)
    vers = sorted({f[:-len(suffix)] for f in files if f.endswith(suffix)})
    print(f" discovered {len(vers)} batteries in {repo_id}")
    return vers
def _load_model_safe(ver: str, device: str, repo_id: str):
    """load_model, with a fallback for torch.compile checkpoints.

    Such checkpoints have state-dict keys carrying an '_orig_mod.' prefix. On
    that specific RuntimeError the fallback:
    1. re-downloads the checkpoint and strips the prefix;
    2. backfills missing config keys from final_report.json, the way
       load_model would (checkpoint_path loads skip the hf_version backfill);
    3. writes a temporary checkpoint and re-enters load_model via
       checkpoint_path, so all of load_model's construction logic is reused.

    The temporary file is deleted once load_model returns or raises. Any
    other error is re-raised unchanged.

    Returns:
        (model, cfg) as returned by load_model.
    """
    from geolip_svae.inference.loading import load_model
    try:
        return load_model(hf_version=ver, device=device, repo_id=repo_id)
    except RuntimeError as e:
        if "_orig_mod." not in str(e):
            raise
        import json
        import os
        import tempfile
        import torch
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id=repo_id, filename=f"{ver}/checkpoints/best.pt",
                               repo_type="model")
        # NOTE(security): weights_only=False unpickles arbitrary objects from a
        # downloaded file. Only point this at repos you trust.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        pref = "_orig_mod."
        ckpt["model_state_dict"] = {
            (k[len(pref):] if k.startswith(pref) else k): v
            for k, v in ckpt["model_state_dict"].items()
        }
        # mirror load_model's final_report backfill into the temp config
        cfg0 = dict(ckpt.get("config", {}))
        backfillable = ("n_heads", "smooth_mid", "linear_readout",
                        "svd_mode", "match_params", "channels")
        if any(k not in cfg0 for k in backfillable):
            try:
                rp = hf_hub_download(repo_id=repo_id, filename=f"{ver}/final_report.json",
                                     repo_type="model")
                with open(rp, encoding="utf-8") as fh:
                    rc = json.load(fh).get("config", {})
                for k in backfillable:
                    if k not in cfg0 and rc.get(k) is not None:
                        cfg0[k] = rc[k]
                ckpt["config"] = cfg0
            except Exception:
                # Best-effort: a missing or unreadable final_report.json just
                # leaves the checkpoint's own config in place.
                pass
        tmp = os.path.join(tempfile.gettempdir(), f"{ver.replace('/', '_')}_stripped.pt")
        torch.save(ckpt, tmp)
        try:
            model, cfg = load_model(checkpoint_path=tmp, device=device, repo_id=repo_id)
        finally:
            # The stripped copy is as large as the checkpoint; don't leak it.
            try:
                os.remove(tmp)
            except OSError:
                pass
        print(f" (recovered {ver}: stripped _orig_mod. torch.compile prefix)")
        return model, cfg
def extract_row(ver: str, device: str) -> Dict[str, Any]:
    """Load a frozen battery, extract its codebook, and return an ablation row.

    The calibration set is chosen per inferred model class (falling back to
    the "unknown" entry). Its dimension 1 is trimmed or tiled to match
    cfg["channels"] (default 3).

    Returns:
        {id, class, axes, D, n_pairs, n_unpaired, target=recon_mse (None if
        absent), n_axes}. Pair counts come from the codebook metadata when
        both are present; otherwise they are counted from cb.pairs /
        cb.unpaired.
    """
    from geolip_svae.inference.calibration import make_calibration
    from geolip_svae.inference.codebook import extract_codebook
    from geolip_svae.inference.train_codebook import (
        infer_class_from_cfg, DEFAULT_CALIBRATIONS,
    )
    import torch
    model, cfg = _load_model_safe(ver, device, REPO_ID)
    cls = infer_class_from_cfg(cfg)
    cal = DEFAULT_CALIBRATIONS.get(cls, DEFAULT_CALIBRATIONS["unknown"])
    size = cfg.get("img_size") or cal["size"]
    calib = make_calibration(cal["name"], n=cal["n"], size=size)
    if not isinstance(calib, torch.Tensor):
        calib = torch.as_tensor(calib)
    # Match model input channels. `or 3` also covers an explicit None in cfg.
    ch = int(cfg.get("channels") or 3)
    if calib.dim() >= 2 and calib.shape[1] != ch:
        if ch < calib.shape[1]:
            calib = calib[:, :ch]
        else:
            reps = (ch + calib.shape[1] - 1) // calib.shape[1]
            # Tile along dim 1 only. Works for any rank >= 2 (was hard-coded 4-D).
            calib = calib.repeat(1, reps, *([1] * (calib.dim() - 2)))[:, :ch]
    cb = extract_codebook(model, calib.to(device), model_id=ver,
                          model_class=cls, calibration_name=cal["name"])
    axes = cb.axes.detach().cpu().numpy()
    meta = getattr(cb, "metadata", None)
    n_pairs = getattr(meta, "n_pairs", None)
    n_unpaired = getattr(meta, "n_unpaired", None)
    if n_pairs is None or n_unpaired is None:
        # Count both from the lists together so they stay consistent (and
        # int(None) below can't happen when only one was recorded).
        n_pairs, n_unpaired = len(cb.pairs), len(cb.unpaired)
    return {
        "id": ver,
        "class": cls,
        "axes": axes,
        "D": int(cfg.get("D") or axes.shape[1]),
        "n_pairs": int(n_pairs),
        "n_unpaired": int(n_unpaired),
        "target": cfg.get("_test_mse"),  # recon MSE (None if absent)
        "n_axes": int(axes.shape[0]),
    }
def run_ablation(batteries: Optional[List[str]] = None, device: Optional[str] = None,
                 enabled=None) -> Dict[str, Any]:
    """Extract every battery's codebook, compute signatures, and rank contributions.

    Args:
        batteries: battery folder names in REPO_ID. Falls back to BATTERIES
            when None or empty.
        device: torch device; defaults to "cuda" when available, else "cpu".
        enabled: optional subset of contribution names (None = all).

    Returns:
        {"rows": signature rows, "table": ablation_table(rows)}, or {} when no
        battery loaded. Batteries that fail to load are printed as SKIP and
        excluded. Results are also printed as three tables: per-battery
        values, informativeness ranking, and per-class means for the top
        separators.
    """
    import torch
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    batteries = batteries or BATTERIES
    print(f"[battery_ablation] {len(batteries)} batteries on {device} | ripser={HAVE_RIPSER}")
    cb_rows: List[Dict[str, Any]] = []
    for ver in batteries:
        try:
            row = extract_row(ver, device)
        except Exception as e:  # best-effort per battery: one bad checkpoint must not abort the sweep
            print(f" SKIP {ver:42s} {type(e).__name__}: {e}")
            continue
        cb_rows.append(row)
        # str(): a None / non-str class must not crash the log line. It used to
        # raise inside the try AFTER the append, printing SKIP for a kept row.
        print(f" ok {ver:42s} class={str(row['class']):12s} "
              f"n_axes={row['n_axes']:3d} target_mse={row['target']}")
    if not cb_rows:
        print(" no batteries loaded — check BATTERIES / network")
        return {}
    rows = collect_signatures(cb_rows, enabled=enabled)

    # per-battery signature table
    names = [s[0] for s in SIGNAL_SPECS if (enabled is None or s[0] in enabled)]
    print("\n── per-battery contribution values ──")
    print("battery".ljust(42) + "".join(f"{n[:11]:>13s}" for n in names))
    for r in rows:
        cells = "".join(f"{r['values'].get(n, float('nan')):>13.4f}" for n in names)
        print(r["id"][:40].ljust(42) + cells)

    # ablation ranking
    table = ablation_table(rows)
    n_target = max((s["n_target"] for s in table.values()), default=0)
    # key=str: same order for string classes, and no TypeError on mixed types
    classes_present = sorted({r.get("class") for r in rows if r.get("class") is not None}, key=str)
    print("\n── contribution informativeness ──")
    print(f" cv = scale-free spread | |rho| = |Spearman| w/ recon MSE (n={n_target}, detects BROKEN)")
    print(f" eta2 = variance explained by class (detects CLASS SEPARATION) | classes: {classes_present}")

    def _key(it):
        # Sort by eta2 desc, then |rho| desc. NaN ranks below every real value.
        e = it[1]["eta2_by_class"]
        rho = it[1]["abs_spearman_with_target"]
        return (-(e if e == e else -1), -(rho if rho == rho else -1))

    ranked = sorted(table.items(), key=_key)
    for name, stats in ranked:
        rho = stats["abs_spearman_with_target"]; rho_s = f"{rho:.3f}" if rho == rho else " -- "
        eta = stats["eta2_by_class"]; eta_s = f"{eta:.3f}" if eta == eta else " -- "
        cv = stats["cv"]; cv_s = f"{cv:6.2f}" if cv == cv else " -- "
        print(f" {name:26s} eta2={eta_s} |rho|={rho_s} cv={cv_s} n={stats['n_valid']}")

    # per-class means for the strongest class separators
    top = ranked[:4]
    print(f"\n── per-class means (top {len(top)} class-separating signals) ──")
    print("class".ljust(16) + "".join(f"{n[:11]:>13s}" for n, _ in top))
    for c in classes_present:
        line = str(c).ljust(16)
        for _, stats in top:
            mv = stats["class_means"].get(str(c))
            line += (f"{mv:>13.3f}" if mv is not None else f"{'--':>13s}")
        print(line)
    return {"rows": rows, "table": table}
if __name__ == "__main__":
    # Ablate the whole discovered zoo when BATTERIES is empty or still the
    # shipped single-entry default. Any user-edited list (even one with a
    # single battery) is used as-is. The old `len(BATTERIES) > 1` test
    # discarded a deliberate one-battery list.
    _SHIPPED_DEFAULT = ["h2_linear_tiny_imagenet_64"]
    if not BATTERIES or BATTERIES == _SHIPPED_DEFAULT:
        bats = discover_batteries()
    else:
        bats = BATTERIES
    run_ablation(bats)