"""
Gradient flow probe for PostSemClawModel.
READ-ONLY diagnostic. Does NOT modify any source, does NOT train, does NOT
step an optimizer. Runs one forward + backward and reports, per-parameter:
name, shape, dtype, requires_grad, grad-is-None?, |grad|.mean, |grad|.norm
Severity classification at the bottom:
BLOCKER β€” requires_grad=True but p.grad is None (disconnected from graph)
WARNING β€” grad present but literally zero (ops cancel, wd_init, etc.)
WARNING β€” requires_grad=True but param missing from every optimizer group
OK β€” everything else
Usage:
.venv/bin/python -u scripts/grad_probe.py
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
# Ensure the project root is on sys.path (so `train`, `subsystems`, `prepare`
# resolve when we run from any cwd). Probe is intentionally a thin wrapper.
HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, str(ROOT))
# Small model config to keep the probe fast (still exercises every component).
# K=4 MTP (default), d_model=256 (default), n_layer=4 (default).
os.environ.setdefault("HYDRA_D_MODEL", "256")
os.environ.setdefault("HYDRA_N_LAYER", "4")
os.environ.setdefault("HYDRA_MTP_K", "4")
import torch # noqa: E402
from train import PostSemClawModel, PostSemClawConfig # noqa: E402


def main() -> int:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("ERROR: CUDA required (model has mamba-ssm + bf16 autocast path).")
        return 2

    cfg = PostSemClawConfig(
        sequence_len=64,
        vocab_size=8192,
        n_layer=int(os.environ["HYDRA_N_LAYER"]),
        d_model=int(os.environ["HYDRA_D_MODEL"]),
        d_state=64,
        headdim=32,
        n_heads=8,
        expand=2,
        engram_n_columns=1024,
        engram_key_dim=64,
        engram_layer_idx=1,
        sdr_n_bits=16384,
        sdr_target_active=327,
        sdr_delta_rank=32,
        sdr_som_warmup=500,
        sdr_som_interval=100,
        htm_n_columns=2048,
        htm_cells_per_column=32,
        mtp_k=int(os.environ["HYDRA_MTP_K"]),
        mtp_weight_decay=0.5,
    )
    print(f"[probe] config: d_model={cfg.d_model} n_layer={cfg.n_layer} "
          f"mtp_k={cfg.mtp_k} vocab={cfg.vocab_size}")

    torch.manual_seed(0)
    model = PostSemClawModel(cfg).to(device)
    model.init_weights()
    model.train()

    # ---- Enumerate params & optimizer group assignment ----
    all_params = list(model.named_parameters())
    print(f"[probe] total named parameters: {len(all_params)}")

    # Build the optimizer only to check group coverage (no step, no zero_grad).
    opt = model.setup_optimizer()
    grouped_ids: set[int] = set()
    for group in opt.param_groups:
        for p in group["params"]:
            grouped_ids.add(id(p))
    unique_param_ids = {id(p) for _, p in all_params}
    missing_from_opt = unique_param_ids - grouped_ids
    print(f"[probe] params in opt groups: {len(grouped_ids)} / unique: {len(unique_param_ids)}")
    if missing_from_opt:
        print(f"[probe] WARNING: {len(missing_from_opt)} unique params missing from opt groups")

    # Tied weight check.
    tied = model.wte.weight.data_ptr() == model.lm_head.weight.data_ptr()
    print(f"[probe] tied lm_head<->wte (data_ptr match): {tied}")

    # ---- One forward + backward under bf16 autocast ----
    B, T = 1, 64
    idx = torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long, device=device)
    tgt = torch.roll(idx, -1, dims=1)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, targets=tgt)
    print(f"[probe] fwd loss = {float(loss.detach()):.4f}")
    loss.backward()
    torch.cuda.synchronize()
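
    # Global gradient norm as one scalar (a minimal read-only sketch; computed
    # by hand in fp32 rather than via clip_grad_norm_, which can rescale
    # gradients in place).
    with torch.no_grad():
        per_param = [p.grad.detach().float().norm() for _, p in all_params if p.grad is not None]
        if per_param:
            print(f"[probe] global grad norm = {torch.norm(torch.stack(per_param)).item():.3e}")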

    # ---- Report ----
    blockers: list[str] = []
    zero_grads: list[str] = []
    unexpected_frozen: list[str] = []
    not_in_opt: list[str] = []
    rows: list[tuple[str, tuple, str, bool, bool, float, float]] = []
    for name, p in all_params:
        grad_is_none = p.grad is None
        if p.requires_grad and grad_is_none:
            blockers.append(name)
            rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                         p.requires_grad, True, float("nan"), float("nan")))
            continue
        if not p.requires_grad:
            unexpected_frozen.append(name)
            rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                         False, True, float("nan"), float("nan")))
            continue
        g = p.grad.detach().float()
        abs_mean = float(g.abs().mean().item())
        norm = float(g.norm().item())
        if abs_mean == 0.0 and norm == 0.0:
            zero_grads.append(name)
        if id(p) not in grouped_ids:
            not_in_opt.append(name)
        rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                     p.requires_grad, False, abs_mean, norm))

    # Pretty table
    print("\n[probe] per-parameter grad table:")
    print(f"  {'name':<56} {'shape':<22} {'dtype':<8} rg none {'|g|.mean':>10} {'|g|.norm':>10}")
    for name, shape, dtype, rg, none, mean, norm in rows:
        shape_s = "x".join(str(s) for s in shape)
        rg_s = "Y" if rg else "N"
        none_s = "Y" if none else "N"
        if none:
            mean_s = norm_s = f"{'nan':>10}"
        else:
            mean_s = f"{mean:>10.3e}"
            norm_s = f"{norm:>10.3e}"
        print(f"  {name:<56} {shape_s:<22} {dtype:<8} {rg_s:<2} {none_s:<4} {mean_s} {norm_s}")

    # Identity checks
    print("\n[probe] identity checks:")
    print(f"  id(wte.weight)     = {id(model.wte.weight)}")
    print(f"  id(lm_head.weight) = {id(model.lm_head.weight)}")
    print(f"  same Python object = {model.wte.weight is model.lm_head.weight}")
    print(f"  same storage ptr   = {tied}")

    # Engram memory inspection
    print(f"\n[probe] engram.memory is nn.Parameter: "
          f"{isinstance(model.engram.memory, torch.nn.Parameter)}")
    print(f"  engram.memory.requires_grad = {model.engram.memory.requires_grad}")
    if model.engram.memory.grad is None:
        print("  engram.memory.grad = None (Hebbian-only path; no autograd through detach())")
    else:
        g = model.engram.memory.grad.detach().float()
        print(f"  engram.memory.grad |.mean| = {float(g.abs().mean()):.3e}")

    # Stash flag sanity: _last_sdr should be uint8, no graph.
    last = getattr(model, "_last_sdr", None)
    if last is not None:
        print(f"\n[probe] model._last_sdr dtype={last.dtype}, requires_grad={last.requires_grad}")
    else:
        print("\n[probe] model._last_sdr is None (fwd didn't stash; ok if path changed)")

    # Summary
    print("\n[probe] ============ SUMMARY ============")
    print(f"  BLOCKERS (requires_grad but grad is None): {len(blockers)}")
    for n in blockers:
        print(f"    - {n}")
    print(f"  WARNINGS (grad is literally zero): {len(zero_grads)}")
    for n in zero_grads:
        print(f"    - {n}")
    print(f"  WARNINGS (requires_grad=False): {len(unexpected_frozen)}")
    for n in unexpected_frozen:
        print(f"    - {n}")
    print(f"  WARNINGS (missing from every opt group): {len(not_in_opt)}")
    for n in not_in_opt:
        print(f"    - {n}")
    return 0


if __name__ == "__main__":
    sys.exit(main())