File size: 7,405 Bytes
22741d9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 | """
Gradient flow probe for PostSemClawModel.
READ-ONLY diagnostic. Does NOT modify any source, does NOT train, does NOT
step an optimizer. Runs one forward + backward and reports, per-parameter:
name, shape, dtype, requires_grad, grad-is-None?, |grad|.mean, |grad|.norm
Severity classification at the bottom:
BLOCKER — requires_grad=True but p.grad is None (disconnected from graph)
WARNING — grad present but literally zero (ops cancel, wd_init, etc.)
WARNING — requires_grad=True but param missing from every optimizer group
OK — everything else
Usage:
.venv/bin/python -u scripts/grad_probe.py
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
# Make the project root importable so that `train`, `subsystems` and `prepare`
# resolve no matter which working directory the probe is launched from.
# The probe itself is intentionally a thin wrapper around the project code.
HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, os.fspath(ROOT))

# Small model config to keep the probe fast (still exercises every component):
# d_model=256, n_layer=4, K=4 MTP -- all noted as the defaults. `setdefault`
# means values already present in the environment take precedence.
for _env_name, _env_default in (
    ("HYDRA_D_MODEL", "256"),
    ("HYDRA_N_LAYER", "4"),
    ("HYDRA_MTP_K", "4"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
import torch # noqa: E402
from train import PostSemClawModel, PostSemClawConfig # noqa: E402
def _optimizer_param_ids(opt) -> set[int]:
    """Return the ``id()`` of every parameter held by any param group of *opt*.

    *opt* only needs to expose ``param_groups``: an iterable of dicts that
    each carry a ``"params"`` list.
    """
    return {id(p) for group in opt.param_groups for p in group["params"]}


def _classify_params(named_params, grouped_ids: set[int]):
    """Bucket parameters by gradient state after a backward pass.

    Args:
        named_params: iterable of ``(name, parameter)`` pairs.
        grouped_ids: ``id()`` of every parameter owned by an optimizer group.

    Returns:
        ``(rows, blockers, zero_grads, unexpected_frozen, not_in_opt)`` where
        ``rows`` holds one ``(name, shape, dtype, requires_grad, grad_is_none,
        |grad|.mean, |grad|.norm)`` tuple per parameter (in input order) and
        the remaining four items are lists of parameter names.
    """
    nan = float("nan")
    rows: list[tuple[str, tuple, str, bool, bool, float, float]] = []
    blockers: list[str] = []
    zero_grads: list[str] = []
    unexpected_frozen: list[str] = []
    not_in_opt: list[str] = []

    for name, p in named_params:
        shape = tuple(p.shape)
        dtype = str(p.dtype).replace("torch.", "")

        if not p.requires_grad:
            # Frozen params are reported as-is; their grads are not inspected.
            unexpected_frozen.append(name)
            rows.append((name, shape, dtype, False, True, nan, nan))
            continue

        # Check optimizer coverage for EVERY trainable param, before the
        # grad-is-None early exit, so a disconnected param that is also
        # missing from the optimizer shows up in both lists.
        if id(p) not in grouped_ids:
            not_in_opt.append(name)

        if p.grad is None:
            blockers.append(name)
            rows.append((name, shape, dtype, True, True, nan, nan))
            continue

        # Upcast so the statistics are computed in fp32 regardless of the
        # parameter's own dtype.
        g = p.grad.detach().float()
        abs_mean = float(g.abs().mean().item())
        norm = float(g.norm().item())
        if abs_mean == 0.0 and norm == 0.0:
            zero_grads.append(name)
        rows.append((name, shape, dtype, True, False, abs_mean, norm))

    return rows, blockers, zero_grads, unexpected_frozen, not_in_opt


def _print_grad_table(rows) -> None:
    """Print the per-parameter gradient table built by ``_classify_params``."""
    print("\n[probe] per-parameter grad table:")
    print(f" {'name':<56} {'shape':<22} {'dtype':<8} rg none {'|g|.mean':>10} {'|g|.norm':>10}")
    # Same width as the numeric columns so rows without a grad stay aligned.
    nan_s = f"{'nan':>10}"
    for name, shape, dtype, rg, none, mean, norm in rows:
        shape_s = "x".join(str(s) for s in shape)
        rg_s = "Y" if rg else "N"
        none_s = "Y" if none else "N"
        mean_s = nan_s if none else f"{mean:>10.3e}"
        norm_s = nan_s if none else f"{norm:>10.3e}"
        print(f" {name:<56} {shape_s:<22} {dtype:<8} {rg_s} {none_s} {mean_s} {norm_s}")


def _print_name_list(label: str, names: list[str]) -> None:
    """Print one summary section: a ``label: count`` line, then each name."""
    print(f" {label}: {len(names)}")
    for n in names:
        print(f" - {n}")


def main() -> int:
    """Run one forward + backward pass and report gradient flow per parameter.

    Builds a small ``PostSemClawModel``, constructs its optimizer only to
    inspect param-group coverage (no ``step``, no ``zero_grad``), runs a single
    forward/backward under bf16 autocast on random tokens, then prints a
    per-parameter table, a few identity/engram/SDR sanity checks and a
    severity summary.

    Returns:
        ``2`` if CUDA is unavailable (nothing else is executed), otherwise
        ``0`` -- findings are reported on stdout, not via the exit code.
    """
    if not torch.cuda.is_available():
        print("ERROR: CUDA required (model has mamba-ssm + bf16 autocast path).")
        return 2
    device = "cuda"

    cfg = PostSemClawConfig(
        sequence_len=64,
        vocab_size=8192,
        n_layer=int(os.environ["HYDRA_N_LAYER"]),
        d_model=int(os.environ["HYDRA_D_MODEL"]),
        d_state=64,
        headdim=32,
        n_heads=8,
        expand=2,
        engram_n_columns=1024,
        engram_key_dim=64,
        engram_layer_idx=1,
        sdr_n_bits=16384,
        sdr_target_active=327,
        sdr_delta_rank=32,
        sdr_som_warmup=500,
        sdr_som_interval=100,
        htm_n_columns=2048,
        htm_cells_per_column=32,
        mtp_k=int(os.environ["HYDRA_MTP_K"]),
        mtp_weight_decay=0.5,
    )
    print(f"[probe] config: d_model={cfg.d_model} n_layer={cfg.n_layer} "
          f"mtp_k={cfg.mtp_k} vocab={cfg.vocab_size}")

    # Fixed seed so repeated probe runs are comparable.
    torch.manual_seed(0)
    model = PostSemClawModel(cfg).to(device)
    model.init_weights()
    model.train()

    # ---- Enumerate params & optimizer group assignment ----
    all_params = list(model.named_parameters())
    print(f"[probe] total named parameters: {len(all_params)}")

    # Build optimizer to check group coverage (no step, no zero_grad).
    opt = model.setup_optimizer()
    grouped_ids = _optimizer_param_ids(opt)
    unique_param_ids = {id(p) for _, p in all_params}
    missing_from_opt = unique_param_ids - grouped_ids
    print(f"[probe] params in opt groups: {len(grouped_ids)} / unique: {len(unique_param_ids)}")
    if missing_from_opt:
        print(f"[probe] WARNING: {len(missing_from_opt)} unique params missing from opt groups")

    # Tied weight check: compares underlying storage pointers, not objects.
    tied = model.wte.weight.data_ptr() == model.lm_head.weight.data_ptr()
    print(f"[probe] tied lm_head<->wte (data_ptr match): {tied}")

    # ---- One forward + backward under bf16 autocast ----
    B, T = 1, 64
    idx = torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long, device=device)
    # Targets are the inputs rolled left by one position along the time axis.
    tgt = torch.roll(idx, -1, dims=1)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, targets=tgt)
    print(f"[probe] fwd loss = {float(loss.detach()):.4f}")
    loss.backward()
    # Make sure queued GPU work has finished before grads are inspected.
    torch.cuda.synchronize()

    # ---- Report ----
    rows, blockers, zero_grads, unexpected_frozen, not_in_opt = _classify_params(
        all_params, grouped_ids
    )
    _print_grad_table(rows)

    # Identity checks
    print("\n[probe] identity checks:")
    print(f" id(wte.weight) = {id(model.wte.weight)}")
    print(f" id(lm_head.weight) = {id(model.lm_head.weight)}")
    print(f" same Python object = {model.wte.weight is model.lm_head.weight}")
    print(f" same storage ptr = {tied}")

    # Engram memory inspection
    engram_memory = model.engram.memory
    print(f"\n[probe] engram.memory is nn.Parameter: "
          f"{isinstance(engram_memory, torch.nn.Parameter)}")
    print(f" engram.memory.requires_grad = {engram_memory.requires_grad}")
    if engram_memory.grad is None:
        print(" engram.memory.grad = None (Hebbian-only path; no autograd through detach())")
    else:
        g = engram_memory.grad.detach().float()
        print(f" engram.memory.grad |.mean| = {float(g.abs().mean()):.3e}")

    # Stash flag sanity: _last_sdr should be uint8, no graph
    last = getattr(model, "_last_sdr", None)
    if last is not None:
        print(f"\n[probe] model._last_sdr dtype={last.dtype}, requires_grad={last.requires_grad}")
    else:
        print("\n[probe] model._last_sdr is None (fwd didn't stash — ok if path changed)")

    # Summary
    print("\n[probe] ============ SUMMARY ============")
    _print_name_list("BLOCKERS (requires_grad but grad is None)", blockers)
    _print_name_list("WARNINGS (grad is literally zero)", zero_grads)
    _print_name_list("WARNINGS (requires_grad=False)", unexpected_frozen)
    _print_name_list("WARNINGS (missing from every opt group)", not_in_opt)
    return 0
# Script entry point: hand main()'s status code to the interpreter as the
# process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|