| """ | |
| Gradient flow probe for PostSemClawModel. | |
| READ-ONLY diagnostic. Does NOT modify any source, does NOT train, does NOT | |
| step an optimizer. Runs one forward + backward and reports, per-parameter: | |
| name, shape, dtype, requires_grad, grad-is-None?, |grad|.mean, |grad|.norm | |
| Severity classification at the bottom: | |
| BLOCKER — requires_grad=True but p.grad is None (disconnected from graph) | |
| WARNING — grad present but literally zero (ops cancel, wd_init, etc.) | |
| WARNING — requires_grad=True but param missing from every optimizer group | |
| OK — everything else | |
| Usage: | |
| .venv/bin/python -u scripts/grad_probe.py | |
| """ | |
from __future__ import annotations

import os
import sys
from pathlib import Path

# Ensure the project root is on sys.path (so `train`, `subsystems`, `prepare`
# resolve when we run from any cwd). Probe is intentionally a thin wrapper.
HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, str(ROOT))

# Small model config to keep the probe fast (still exercises every component).
# K=4 MTP (default), d_model=256 (default), n_layer=4 (default).
os.environ.setdefault("HYDRA_D_MODEL", "256")
os.environ.setdefault("HYDRA_N_LAYER", "4")
os.environ.setdefault("HYDRA_MTP_K", "4")

import torch  # noqa: E402

from train import PostSemClawModel, PostSemClawConfig  # noqa: E402


def main() -> int:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("ERROR: CUDA required (model has mamba-ssm + bf16 autocast path).")
        return 2
    cfg = PostSemClawConfig(
        sequence_len=64,
        vocab_size=8192,
        n_layer=int(os.environ["HYDRA_N_LAYER"]),
        d_model=int(os.environ["HYDRA_D_MODEL"]),
        d_state=64,
        headdim=32,
        n_heads=8,
        expand=2,
        engram_n_columns=1024,
        engram_key_dim=64,
        engram_layer_idx=1,
        sdr_n_bits=16384,
        sdr_target_active=327,
        sdr_delta_rank=32,
        sdr_som_warmup=500,
        sdr_som_interval=100,
        htm_n_columns=2048,
        htm_cells_per_column=32,
        mtp_k=int(os.environ["HYDRA_MTP_K"]),
        mtp_weight_decay=0.5,
    )
| print(f"[probe] config: d_model={cfg.d_model} n_layer={cfg.n_layer} " | |
| f"mtp_k={cfg.mtp_k} vocab={cfg.vocab_size}") | |
| torch.manual_seed(0) | |
| model = PostSemClawModel(cfg).to(device) | |
| model.init_weights() | |
| model.train() | |
    # ---- Enumerate params & optimizer group assignment ----
    all_params = list(model.named_parameters())
    print(f"[probe] total named parameters: {len(all_params)}")

    # Build optimizer to check group coverage (no step, no zero_grad).
    opt = model.setup_optimizer()
    grouped_ids: set[int] = set()
    for group in opt.param_groups:
        for p in group["params"]:
            grouped_ids.add(id(p))
    unique_param_ids = {id(p) for _, p in all_params}
    missing_from_opt = unique_param_ids - grouped_ids
    print(f"[probe] params in opt groups: {len(grouped_ids)} / unique: {len(unique_param_ids)}")
    if missing_from_opt:
        print(f"[probe] WARNING: {len(missing_from_opt)} unique params missing from opt groups")
    # Tied weight check.
    tied = model.wte.weight.data_ptr() == model.lm_head.weight.data_ptr()
    print(f"[probe] tied lm_head<->wte (data_ptr match): {tied}")
    # ---- One forward + backward under bf16 autocast ----
    B, T = 1, 64
    idx = torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long, device=device)
    tgt = torch.roll(idx, -1, dims=1)
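    # NOTE: torch.roll wraps the first token around into the last target slot.
    # That is fine for a one-batch gradient probe; a real next-token loader
    # would use contiguous shifted slices instead, e.g.
    #   idx, tgt = tokens[:, :-1], tokens[:, 1:]
    # (`tokens` being a hypothetical (B, T+1) batch, not a name from this repo.)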
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, targets=tgt)
    print(f"[probe] fwd loss = {float(loss.detach()):.4f}")
    loss.backward()
    torch.cuda.synchronize()
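    # CUDA launches are asynchronous; the explicit synchronize forces any
    # kernel error from the backward pass to surface here rather than at some
    # later .grad read below.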
    # ---- Report ----
    blockers: list[str] = []
    zero_grads: list[str] = []
    unexpected_frozen: list[str] = []
    not_in_opt: list[str] = []
    rows: list[tuple[str, tuple, str, bool, bool, float, float]] = []
    for name, p in all_params:
        grad_is_none = p.grad is None
        if p.requires_grad and grad_is_none:
            blockers.append(name)
            rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                         p.requires_grad, True, float("nan"), float("nan")))
            continue
        if not p.requires_grad:
            unexpected_frozen.append(name)
            # Report the actual grad-is-None state rather than hardcoding True,
            # so a frozen param that somehow holds a grad is not misreported.
            rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                         False, grad_is_none, float("nan"), float("nan")))
            continue
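        # Cast to fp32 before reducing: computing means/norms directly in bf16
        # loses precision and can underflow for very small gradients.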
        g = p.grad.detach().float()
        abs_mean = float(g.abs().mean().item())
        norm = float(g.norm().item())
        if abs_mean == 0.0 and norm == 0.0:
            zero_grads.append(name)
        if id(p) not in grouped_ids:
            not_in_opt.append(name)
        rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                     p.requires_grad, False, abs_mean, norm))
    # Pretty table
    print("\n[probe] per-parameter grad table:")
    print(f"  {'name':<56} {'shape':<22} {'dtype':<8} {'rg':<2} {'none':<4} {'|g|.mean':>10} {'|g|.norm':>10}")
    for name, shape, dtype, rg, none, mean, norm in rows:
        shape_s = "x".join(str(s) for s in shape)
        rg_s = "Y" if rg else "N"
        none_s = "Y" if none else "N"
        if none:
            mean_s = norm_s = f"{'nan':>10}"
        else:
            mean_s = f"{mean:>10.3e}"
            norm_s = f"{norm:>10.3e}"
        print(f"  {name:<56} {shape_s:<22} {dtype:<8} {rg_s:<2} {none_s:<4} {mean_s} {norm_s}")
    # Identity checks
    print("\n[probe] identity checks:")
    print(f"  id(wte.weight)      = {id(model.wte.weight)}")
    print(f"  id(lm_head.weight)  = {id(model.lm_head.weight)}")
    print(f"  same Python object  = {model.wte.weight is model.lm_head.weight}")
    print(f"  same storage ptr    = {tied}")
    # Engram memory inspection
    print(f"\n[probe] engram.memory is nn.Parameter: "
          f"{isinstance(model.engram.memory, torch.nn.Parameter)}")
    print(f"  engram.memory.requires_grad = {model.engram.memory.requires_grad}")
    if model.engram.memory.grad is None:
        print("  engram.memory.grad = None (Hebbian-only path; no autograd through detach())")
    else:
        g = model.engram.memory.grad.detach().float()
        print(f"  engram.memory.grad |g|.mean = {float(g.abs().mean()):.3e}")
    # Stash flag sanity: _last_sdr should be uint8, no graph
    last = getattr(model, "_last_sdr", None)
    if last is not None:
        print(f"\n[probe] model._last_sdr dtype={last.dtype}, requires_grad={last.requires_grad}")
    else:
        print("\n[probe] model._last_sdr is None (fwd didn't stash — ok if path changed)")
    # Summary
    print("\n[probe] ============ SUMMARY ============")
    print(f"  BLOCKERS (requires_grad but grad is None): {len(blockers)}")
    for n in blockers:
        print(f"    - {n}")
    print(f"  WARNINGS (grad is literally zero): {len(zero_grads)}")
    for n in zero_grads:
        print(f"    - {n}")
    print(f"  WARNINGS (requires_grad=False): {len(unexpected_frozen)}")
    for n in unexpected_frozen:
        print(f"    - {n}")
    print(f"  WARNINGS (missing from every opt group): {len(not_in_opt)}")
    for n in not_in_opt:
        print(f"    - {n}")
    return 0


if __name__ == "__main__":
    sys.exit(main())