feather-a10g-large-runtime / overlay /scripts /experiment_ablation.py
icarus112's picture
Update Feather a10g-large training runtime image
e5cf7c3 verified
#!/usr/bin/env python3
"""Ablation study: Engram vs SSM vs SDR sparsity contributions.
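
Computes per-component spectral statistics (entropy-based effective rank,
90%-energy rank and singular-value condition numbers) for the encoder, Engram
memory, SDR projection and SSM in_proj weights of a saved checkpoint, and
writes them to docs/results_ablation.json beside the scripts/ folder.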
"""
import json, os
from pathlib import Path
import torch
import numpy as np
OUT_DIR = Path(__file__).resolve().parents[1] / "docs"  # report destination
CKPT_PATH = Path.home().joinpath(".cache", "autoresearch", "latest.pt")
print("[ABLATION] Loading checkpoint...")
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects (i.e. execute code). Only point CKPT_PATH at trusted checkpoints.
ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
md = ckpt["model_state_dict"]
cfg = ckpt.get("config", {})
# Fall back to these defaults when the checkpoint's config lacks the keys.
N_LAYER, D_MODEL = cfg.get("n_layer", 20), cfg.get("d_model", 160)
def eff_rank(w: torch.Tensor) -> float:
    """Return the entropy-based effective rank of a 2-D matrix.

    The value is ``exp(H(p))`` where ``p`` are the singular values of *w*
    normalised to sum to 1. It is 1 when one direction dominates and
    ``min(w.shape)`` when all singular values are equal.

    Only the singular values are needed, so ``torch.linalg.svdvals`` is used
    instead of a full SVD. This skips building U and Vh, which for tall
    matrices costs far more time and memory than the values themselves.

    Args:
        w: Matrix to analyse; detached, cast to float32 and moved to CPU.

    Returns:
        The effective rank as a Python float. An all-zero matrix yields 1.0,
        because its normalised spectrum is all zeros and has zero entropy.
    """
    s = torch.linalg.svdvals(w.detach().float().cpu()).numpy()
    p = s / (s.sum() + 1e-30)
    # The +1e-30 keeps log() finite for zero singular values; those terms
    # contribute 0 * log(1e-30) == 0 to the entropy.
    entropy = -np.sum(p * np.log(p + 1e-30))
    return float(np.exp(entropy))
def rank_90(w: torch.Tensor, threshold: float = 0.90) -> int:
    """Return how many leading singular directions hold *threshold* of the energy.

    Energy is the sum of squared singular values (the squared Frobenius norm).
    The result is the smallest ``k`` such that the top ``k`` squared singular
    values account for at least *threshold* of the total.

    Args:
        w: Matrix to analyse; detached, cast to float32 and moved to CPU.
        threshold: Energy fraction to capture; defaults to 0.90, which matches
            the function name.

    Returns:
        The number of components, between 1 and ``min(w.shape)``. Returns 0
        for an all-zero matrix, which has no energy to capture.
    """
    # svdvals skips computing U/Vh. The ratios are computed in float64 because
    # in float32 exact boundaries such as 9/10 round to just below 0.90, which
    # over-counts by one component.
    s = torch.linalg.svdvals(w.detach().float().cpu()).double().numpy()
    energy = s ** 2
    total = energy.sum()
    if total <= 0.0:
        return 0
    cumvar = np.cumsum(energy) / total
    k = int(np.searchsorted(cumvar, threshold)) + 1
    # Round-off can leave cumvar[-1] a hair below a threshold of 1.0; never
    # report more components than there are singular values.
    return min(k, int(s.size))
# ── 1. Baseline: all encoder layers ────────────────────────
# Stack every block's in_proj weight (float32) and score each layer with both
# spectral metrics.
print(f"[ABLATION] Computing {N_LAYER} encoder layers...")
enc_weights = torch.stack(
    [md[f"blocks.{layer}.in_proj.weight"].float() for layer in range(N_LAYER)]
)
baseline_ranks, baseline_r90 = [], []
for layer_weight in enc_weights:
    baseline_ranks.append(eff_rank(layer_weight))
    baseline_r90.append(rank_90(layer_weight))
# ── 2. Engram memory ────────────────────────────────────────
# Spectral summary of the Engram memory table plus its gate parameters, all
# cast to float32. Expected memory shape per the author: (16384, 160) — not
# checked here.
engram_mem = md["engram.memory"].float()
engram_er, engram_r90 = eff_rank(engram_mem), rank_90(engram_mem)
engram_gate_w, engram_gate_b = (
    md["engram.gate.weight"].float(),
    md["engram.gate.bias"].float(),
)
# ── 3. SDR projection: delta_u @ delta_v ────────────────────
# The projection is the low-rank product delta_u @ delta_v. Materialising it
# densely (65536 x 16384 per the shape comments below, ~4 GiB in float32) and
# taking its SVD is prohibitively expensive, and its rank is bounded by the
# inner dimension anyway. With R-factors from QR,
#     delta_u = Qu @ Ru      and      delta_v.T = Qv @ Rv,
# the product is Qu @ (Ru @ Rv.T) @ Qv.T. Qu and Qv have orthonormal columns,
# so the product's non-zero singular values are exactly those of the small
# core Ru @ Rv.T. The remaining singular values are zero and add nothing to
# eff_rank, so eff_rank(core) equals eff_rank(delta_u @ delta_v).
sdr_u = md["sdr_semantic.delta_u"].float()  # (65536, 32)
sdr_v = md["sdr_semantic.delta_v"].float()  # (32, 16384)
# mode="r" returns only R (Q is left empty), avoiding the tall Q matrices;
# float64 keeps the factorisation accurate before eff_rank casts back.
_sdr_ru = torch.linalg.qr(sdr_u.double(), mode="r").R
_sdr_rv = torch.linalg.qr(sdr_v.double().T, mode="r").R
sdr_proj_core = _sdr_ru @ _sdr_rv.T  # small (k x k) core, k <= inner dim
sdr_proj_er = eff_rank(sdr_proj_core)
sdr_u_er = eff_rank(sdr_u)
sdr_v_er = eff_rank(sdr_v)
# ── 4. SSM conditioning (in_proj singular value ratio) ──────
def _in_proj_condition(weight: torch.Tensor) -> float:
    """Return sigma_max / (sigma_min + 1e-10) for *weight*; the 1e-10 avoids division by zero."""
    singular = torch.linalg.svd(weight, full_matrices=False).S.numpy()
    return float(singular.max() / (singular.min() + 1e-10))


ssm_cn = [
    _in_proj_condition(md[f"blocks.{layer}.in_proj.weight"].float())
    for layer in range(N_LAYER)
]
# ── 5. SDR retina sparsity ─────────────────────────────────
# "_retina_indices" is unpacked below as (n_tokens, n_active): per token, the
# indices of its active SDR bits (presumably — confirm). Sparsity has to be
# measured against the total SDR width, not against n_active itself. The
# former `n_active / retina.shape[1]` was always exactly 100%.
retina = md.get("_retina_indices", None)
retina_info = {}
if retina is not None:
    n_tok, n_active = retina.shape
    # NOTE(review): assumes the SDR width equals the column count of
    # sdr_semantic.delta_v (16384 per the section-3 shape comment) — confirm.
    n_sdr_bits = md["sdr_semantic.delta_v"].shape[1]
    retina_info = {
        "n_tokens": int(n_tok),
        "n_active_per_token": int(n_active),
        "n_sdr_bits": int(n_sdr_bits),
        "sparsity_pct": float(n_active / n_sdr_bits * 100),
    }
# Assemble the JSON report. Values are converted to built-in Python types so
# json.dumps can serialise them directly.
results = {
    "baseline_encoder": {
        "mean_effective_rank": float(np.mean(baseline_ranks)),
        "median_effective_rank": float(np.median(baseline_ranks)),
        "min_effective_rank": float(np.min(baseline_ranks)),
        "max_effective_rank": float(np.max(baseline_ranks)),
        "std_effective_rank": float(np.std(baseline_ranks)),
        "mean_rank_90pct": float(np.mean(baseline_r90)),
        "layer_ranks": baseline_ranks,
        "layer_ranks_90": baseline_r90,
        "d_model": D_MODEL,
        "intrinsic_dim_vs_model_pct": float(np.median(baseline_ranks) / D_MODEL * 100),
    },
    "engram": {
        "shape": list(engram_mem.shape),
        "effective_rank": engram_er,
        "rank_90pct": engram_r90,
        "memory_utilization_pct": float(engram_er / min(engram_mem.shape) * 100),
        "gate_weight_mean": float(engram_gate_w.mean().item()),
        # `.item()` only works on one-element tensors. Keep the scalar form in
        # that case and fall back to a list so a wider gate bias cannot crash
        # the report.
        "gate_bias": (
            float(engram_gate_b.item())
            if engram_gate_b.numel() == 1
            else engram_gate_b.tolist()
        ),
    },
    "sdr": {
        "projection_shape": [sdr_u.shape[0], sdr_v.shape[1]],
        "projection_effective_rank": sdr_proj_er,
        "delta_u_effective_rank": sdr_u_er,
        "delta_v_effective_rank": sdr_v_er,
        "projection_utilization_pct": float(sdr_proj_er / min(sdr_u.shape[0], sdr_v.shape[1]) * 100),
        # rank(delta_u @ delta_v) can never exceed the inner (factor)
        # dimension, so utilisation against that bound is the informative
        # figure. The key above is kept unchanged for existing readers.
        "projection_max_rank": int(min(sdr_u.shape[0], sdr_u.shape[1], sdr_v.shape[1])),
        "projection_max_rank_utilization_pct": float(
            sdr_proj_er / min(sdr_u.shape[0], sdr_u.shape[1], sdr_v.shape[1]) * 100
        ),
        **retina_info,
    },
    "ssm": {
        "condition_numbers": ssm_cn,
        "mean_condition_number": float(np.mean(ssm_cn)),
        "median_condition_number": float(np.median(ssm_cn)),
        "max_condition_number": float(np.max(ssm_cn)),
    },
    # Human-readable notes. The arrow and dash below were previously stored
    # as mis-decoded byte sequences ("β†’", "β€”").
    "interpretation": {
        "engram_memory": "Engram learns ~N_mem compressed patterns. Low eff_rank = few distinct attractor states.",
        "sdr_projection": "Projects 65K vocab → 16K SDR bits. eff_rank measures how many independent concept directions survive.",
        "ssm_conditioning": "In-proj singular ratio. High = dynamics input-sensitive; low = dynamics input-suppressed.",
        "intrinsic_dim": f"If median eff_rank << {D_MODEL}, the model actively uses far fewer dimensions than available — strong manifold compression.",
    }
}
# ── Save the report and print a short summary ─────────────
results_path = OUT_DIR / "results_ablation.json"
# The docs/ directory is not guaranteed to exist. Without this, write_text
# raises FileNotFoundError only after all of the analysis above has run.
results_path.parent.mkdir(parents=True, exist_ok=True)
# default=str is a safety net that stringifies any value json cannot encode.
results_path.write_text(json.dumps(results, indent=2, default=str), encoding="utf-8")
print(f"[ABLATION] Saved {results_path}")
print(f"[ABLATION] Mean eff_rank: {np.mean(baseline_ranks):.2f} / d_model={D_MODEL}")
print(f"[ABLATION] Engram eff_rank: {engram_er:.2f} / min({engram_mem.shape[0]},{engram_mem.shape[1]})")
print(f"[ABLATION] SDR proj eff_rank: {sdr_proj_er:.2f} / min({sdr_u.shape[0]},{sdr_v.shape[1]})")
print(f"[ABLATION] Mean SSM condition number: {np.mean(ssm_cn):.1f}")