| """ |
| Gradient flow probe for PostSemClawModel. |
| |
| READ-ONLY diagnostic. Does NOT modify any source, does NOT train, does NOT |
| step an optimizer. Runs one forward + backward and reports, per-parameter: |
| |
| name, shape, dtype, requires_grad, grad-is-None?, |grad|.mean, |grad|.norm |
| |
| Severity classification at the bottom: |
| BLOCKER β requires_grad=True but p.grad is None (disconnected from graph) |
| WARNING β grad present but literally zero (ops cancel, wd_init, etc.) |
| WARNING β requires_grad=True but param missing from every optimizer group |
| OK β everything else |
| |
| Usage: |
| .venv/bin/python -u scripts/grad_probe.py |
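
    The probe dimensions can be overridden through the same environment
    variables the script defaults just below (os.environ.setdefault, so values
    already set in the environment win), e.g.:

        HYDRA_D_MODEL=512 HYDRA_N_LAYER=8 .venv/bin/python -u scripts/grad_probe.py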
| """ |

from __future__ import annotations

import os
import sys
from pathlib import Path

HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, str(ROOT))
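# NB: the sys.path insert above makes the repository root importable, so the
# `from train import ...` below resolves when this script is run from scripts/.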

os.environ.setdefault("HYDRA_D_MODEL", "256")
os.environ.setdefault("HYDRA_N_LAYER", "4")
os.environ.setdefault("HYDRA_MTP_K", "4")
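# NB: these defaults are read back below when building PostSemClawConfig;
# setdefault means values already present in the environment take precedence.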

import torch

from train import PostSemClawModel, PostSemClawConfig


def main() -> int:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("ERROR: CUDA required (model has mamba-ssm + bf16 autocast path).")
        return 2

    cfg = PostSemClawConfig(
        sequence_len=64,
        vocab_size=8192,
        n_layer=int(os.environ["HYDRA_N_LAYER"]),
        d_model=int(os.environ["HYDRA_D_MODEL"]),
        d_state=64,
        headdim=32,
        n_heads=8,
        expand=2,
        engram_n_columns=1024,
        engram_key_dim=64,
        engram_layer_idx=1,
        sdr_n_bits=16384,
        sdr_target_active=327,
        sdr_delta_rank=32,
        sdr_som_warmup=500,
        sdr_som_interval=100,
        htm_n_columns=2048,
        htm_cells_per_column=32,
        mtp_k=int(os.environ["HYDRA_MTP_K"]),
        mtp_weight_decay=0.5,
    )

    print(f"[probe] config: d_model={cfg.d_model} n_layer={cfg.n_layer} "
          f"mtp_k={cfg.mtp_k} vocab={cfg.vocab_size}")

    torch.manual_seed(0)
    model = PostSemClawModel(cfg).to(device)
    model.init_weights()
    model.train()

    all_params = list(model.named_parameters())
    print(f"[probe] total named parameters: {len(all_params)}")

    # Which parameters does the optimizer actually own? Compare the ids found in
    # every param group against the ids of every named parameter.
    opt = model.setup_optimizer()
    grouped_ids: set[int] = set()
    for group in opt.param_groups:
        for p in group["params"]:
            grouped_ids.add(id(p))
    unique_param_ids = {id(p) for _, p in all_params}
    missing_from_opt = unique_param_ids - grouped_ids
    print(f"[probe] params in opt groups: {len(grouped_ids)} / unique: {len(unique_param_ids)}")
    if missing_from_opt:
        print(f"[probe] WARNING: {len(missing_from_opt)} unique params missing from opt groups")

    tied = model.wte.weight.data_ptr() == model.lm_head.weight.data_ptr()
    print(f"[probe] tied lm_head<->wte (data_ptr match): {tied}")
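    # NB: comparing data_ptr() catches weight tying at the storage level even if
    # wte.weight and lm_head.weight are not the same Python object; the identity
    # checks printed further down show both views.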

    # Tiny synthetic batch; targets are the inputs shifted left by one token
    # (the wrap-around at the end is harmless for a gradient probe).
    B, T = 1, 64
    idx = torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long, device=device)
    tgt = torch.roll(idx, -1, dims=1)

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, targets=tgt)
    print(f"[probe] fwd loss = {float(loss.detach()):.4f}")
    loss.backward()
    torch.cuda.synchronize()
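    # NB: backward() only enqueues CUDA work; the synchronize above ensures every
    # gradient has actually been written before it is inspected below.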

    # Bucket every parameter by severity (see the module docstring).
    blockers: list[str] = []
    zero_grads: list[str] = []
    unexpected_frozen: list[str] = []
    not_in_opt: list[str] = []
    rows: list[tuple[str, tuple, str, bool, bool, float, float]] = []

    for name, p in all_params:
        grad_is_none = p.grad is None
        if p.requires_grad and grad_is_none:
            blockers.append(name)
            rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                         p.requires_grad, True, float("nan"), float("nan")))
            continue
        if not p.requires_grad:
            unexpected_frozen.append(name)
            rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                         False, True, float("nan"), float("nan")))
            continue
        g = p.grad.detach().float()
        abs_mean = float(g.abs().mean().item())
        norm = float(g.norm().item())
        if abs_mean == 0.0 and norm == 0.0:
            zero_grads.append(name)
        if id(p) not in grouped_ids:
            not_in_opt.append(name)
        rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                     p.requires_grad, False, abs_mean, norm))
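    # The four name lists above (blockers / zero_grads / unexpected_frozen /
    # not_in_opt) feed the SUMMARY at the end; `rows` drives the table below.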

    print("\n[probe] per-parameter grad table:")
    print(f"  {'name':<56} {'shape':<22} {'dtype':<8} {'rg':<2} {'none':<4} "
          f"{'|g|.mean':>10} {'|g|.norm':>10}")
    for name, shape, dtype, rg, none, mean, norm in rows:
        shape_s = "x".join(str(s) for s in shape)
        rg_s = "Y" if rg else "N"
        none_s = "Y" if none else "N"
        if none:
            mean_s = norm_s = f"{'nan':>10}"
        else:
            mean_s = f"{mean:>10.3e}"
            norm_s = f"{norm:>10.3e}"
        print(f"  {name:<56} {shape_s:<22} {dtype:<8} {rg_s:<2} {none_s:<4} {mean_s} {norm_s}")

    print("\n[probe] identity checks:")
    print(f"  id(wte.weight)     = {id(model.wte.weight)}")
    print(f"  id(lm_head.weight) = {id(model.lm_head.weight)}")
    print(f"  same Python object = {model.wte.weight is model.lm_head.weight}")
    print(f"  same storage ptr   = {tied}")

    print(f"\n[probe] engram.memory is nn.Parameter: "
          f"{isinstance(model.engram.memory, torch.nn.Parameter)}")
    print(f"  engram.memory.requires_grad = {model.engram.memory.requires_grad}")
    if model.engram.memory.grad is None:
        print("  engram.memory.grad = None (Hebbian-only path; no autograd through detach())")
    else:
        g = model.engram.memory.grad.detach().float()
        print(f"  engram.memory.grad |.mean| = {float(g.abs().mean()):.3e}")
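    # NB: a missing grad here usually just means engram.memory is updated through
    # the Hebbian path (the forward detaches it) rather than through autograd.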

    last = getattr(model, "_last_sdr", None)
    if last is not None:
        print(f"\n[probe] model._last_sdr dtype={last.dtype}, requires_grad={last.requires_grad}")
    else:
        print("\n[probe] model._last_sdr is None (fwd didn't stash it; ok if that path changed)")

    print("\n[probe] ============ SUMMARY ============")
    print(f"  BLOCKERS (requires_grad but grad is None): {len(blockers)}")
    for n in blockers:
        print(f"    - {n}")
    print(f"  WARNINGS (grad is literally zero): {len(zero_grads)}")
    for n in zero_grads:
        print(f"    - {n}")
    print(f"  WARNINGS (requires_grad=False): {len(unexpected_frozen)}")
    for n in unexpected_frozen:
        print(f"    - {n}")
    print(f"  WARNINGS (missing from every opt group): {len(not_in_opt)}")
    for n in not_in_opt:
        print(f"    - {n}")

    return 0


if __name__ == "__main__":
    sys.exit(main())