""" Gradient flow probe for PostSemClawModel. READ-ONLY diagnostic. Does NOT modify any source, does NOT train, does NOT step an optimizer. Runs one forward + backward and reports, per-parameter: name, shape, dtype, requires_grad, grad-is-None?, |grad|.mean, |grad|.norm Severity classification at the bottom: BLOCKER — requires_grad=True but p.grad is None (disconnected from graph) WARNING — grad present but literally zero (ops cancel, wd_init, etc.) WARNING — requires_grad=True but param missing from every optimizer group OK — everything else Usage: .venv/bin/python -u scripts/grad_probe.py """ from __future__ import annotations import os import sys from pathlib import Path # Ensure the project root is on sys.path (so `train`, `subsystems`, `prepare` # resolve when we run from any cwd). Probe is intentionally a thin wrapper. HERE = Path(__file__).resolve().parent ROOT = HERE.parent sys.path.insert(0, str(ROOT)) # Small model config to keep the probe fast (still exercises every component). # K=4 MTP (default), d_model=256 (default), n_layer=4 (default). os.environ.setdefault("HYDRA_D_MODEL", "256") os.environ.setdefault("HYDRA_N_LAYER", "4") os.environ.setdefault("HYDRA_MTP_K", "4") import torch # noqa: E402 from train import PostSemClawModel, PostSemClawConfig # noqa: E402 def main() -> int: device = "cuda" if torch.cuda.is_available() else "cpu" if device != "cuda": print("ERROR: CUDA required (model has mamba-ssm + bf16 autocast path).") return 2 cfg = PostSemClawConfig( sequence_len=64, vocab_size=8192, n_layer=int(os.environ["HYDRA_N_LAYER"]), d_model=int(os.environ["HYDRA_D_MODEL"]), d_state=64, headdim=32, n_heads=8, expand=2, engram_n_columns=1024, engram_key_dim=64, engram_layer_idx=1, sdr_n_bits=16384, sdr_target_active=327, sdr_delta_rank=32, sdr_som_warmup=500, sdr_som_interval=100, htm_n_columns=2048, htm_cells_per_column=32, mtp_k=int(os.environ["HYDRA_MTP_K"]), mtp_weight_decay=0.5, ) print(f"[probe] config: d_model={cfg.d_model} n_layer={cfg.n_layer} " f"mtp_k={cfg.mtp_k} vocab={cfg.vocab_size}") torch.manual_seed(0) model = PostSemClawModel(cfg).to(device) model.init_weights() model.train() # ---- Enumerate params & optimizer group assignment ---- all_params = list(model.named_parameters()) print(f"[probe] total named parameters: {len(all_params)}") # Build optimizer to check group coverage (no step, no zero_grad). opt = model.setup_optimizer() grouped_ids: set[int] = set() for group in opt.param_groups: for p in group["params"]: grouped_ids.add(id(p)) unique_param_ids = {id(p) for _, p in all_params} missing_from_opt = unique_param_ids - grouped_ids print(f"[probe] params in opt groups: {len(grouped_ids)} / unique: {len(unique_param_ids)}") if missing_from_opt: print(f"[probe] WARNING: {len(missing_from_opt)} unique params missing from opt groups") # Tied weight check. tied = model.wte.weight.data_ptr() == model.lm_head.weight.data_ptr() print(f"[probe] tied lm_head<->wte (data_ptr match): {tied}") # ---- One forward + backward under bf16 autocast ---- B, T = 1, 64 idx = torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long, device=device) tgt = torch.roll(idx, -1, dims=1) with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(idx, targets=tgt) print(f"[probe] fwd loss = {float(loss.detach()):.4f}") loss.backward() torch.cuda.synchronize() # ---- Report ---- blockers: list[str] = [] zero_grads: list[str] = [] unexpected_frozen: list[str] = [] not_in_opt: list[str] = [] rows: list[tuple[str, tuple, str, bool, bool, float, float]] = [] for name, p in all_params: grad_is_none = p.grad is None if p.requires_grad and grad_is_none: blockers.append(name) rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""), p.requires_grad, True, float("nan"), float("nan"))) continue if not p.requires_grad: unexpected_frozen.append(name) rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""), False, True, float("nan"), float("nan"))) continue g = p.grad.detach().float() abs_mean = float(g.abs().mean().item()) norm = float(g.norm().item()) if abs_mean == 0.0 and norm == 0.0: zero_grads.append(name) if id(p) not in grouped_ids: not_in_opt.append(name) rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""), p.requires_grad, False, abs_mean, norm)) # Pretty table print("\n[probe] per-parameter grad table:") print(f" {'name':<56} {'shape':<22} {'dtype':<8} rg none {'|g|.mean':>10} {'|g|.norm':>10}") for name, shape, dtype, rg, none, mean, norm in rows: shape_s = "x".join(str(s) for s in shape) rg_s = "Y" if rg else "N" none_s = "Y" if none else "N" if none: mean_s, norm_s = " nan ", " nan " else: mean_s = f"{mean:>10.3e}" norm_s = f"{norm:>10.3e}" print(f" {name:<56} {shape_s:<22} {dtype:<8} {rg_s} {none_s} {mean_s} {norm_s}") # Identity checks print("\n[probe] identity checks:") print(f" id(wte.weight) = {id(model.wte.weight)}") print(f" id(lm_head.weight) = {id(model.lm_head.weight)}") print(f" same Python object = {model.wte.weight is model.lm_head.weight}") print(f" same storage ptr = {tied}") # Engram memory inspection print(f"\n[probe] engram.memory is nn.Parameter: " f"{isinstance(model.engram.memory, torch.nn.Parameter)}") print(f" engram.memory.requires_grad = {model.engram.memory.requires_grad}") if model.engram.memory.grad is None: print(f" engram.memory.grad = None (Hebbian-only path; no autograd through detach())") else: g = model.engram.memory.grad.detach().float() print(f" engram.memory.grad |.mean| = {float(g.abs().mean()):.3e}") # Stash flag sanity: _last_sdr should be uint8, no graph last = getattr(model, "_last_sdr", None) if last is not None: print(f"\n[probe] model._last_sdr dtype={last.dtype}, requires_grad={last.requires_grad}") else: print("\n[probe] model._last_sdr is None (fwd didn't stash — ok if path changed)") # Summary print("\n[probe] ============ SUMMARY ============") print(f" BLOCKERS (requires_grad but grad is None): {len(blockers)}") for n in blockers: print(f" - {n}") print(f" WARNINGS (grad is literally zero): {len(zero_grads)}") for n in zero_grads: print(f" - {n}") print(f" WARNINGS (requires_grad=False): {len(unexpected_frozen)}") for n in unexpected_frozen: print(f" - {n}") print(f" WARNINGS (missing from every opt group): {len(not_in_opt)}") for n in not_in_opt: print(f" - {n}") return 0 if __name__ == "__main__": sys.exit(main())