File size: 7,405 Bytes
22741d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
Gradient flow probe for PostSemClawModel.

READ-ONLY diagnostic. Does NOT modify any source, does NOT train, does NOT
step an optimizer. Runs one forward + backward and reports, per-parameter:

  name, shape, dtype, requires_grad, grad-is-None?, |grad|.mean, |grad|.norm

Severity classification at the bottom:
  BLOCKER  — requires_grad=True but p.grad is None (disconnected from graph)
  WARNING  — grad present but literally zero (ops cancel, wd_init, etc.)
  WARNING  — requires_grad=False (parameter is frozen)
  WARNING  — requires_grad=True but param missing from every optimizer group
  OK       — everything else

Usage:
    .venv/bin/python -u scripts/grad_probe.py
"""

from __future__ import annotations

import os
import sys
from pathlib import Path

# Put the repository root at the front of sys.path so that `train`,
# `subsystems` and `prepare` import correctly regardless of the directory the
# probe is launched from. The probe itself stays a thin wrapper.
HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, str(ROOT))

# Model-size knobs consumed by main(). Anything already exported in the
# environment wins; otherwise fall back to the default sizes (K=4 MTP,
# d_model=256, n_layer=4), which still exercise every component.
for _env_key, _env_default in (
    ("HYDRA_D_MODEL", "256"),
    ("HYDRA_N_LAYER", "4"),
    ("HYDRA_MTP_K", "4"),
):
    os.environ.setdefault(_env_key, _env_default)
del _env_key, _env_default

import torch  # noqa: E402

from train import PostSemClawModel, PostSemClawConfig  # noqa: E402


def _dtype_name(t: torch.Tensor) -> str:
    """Return ``t``'s dtype without the ``torch.`` prefix (e.g. ``"float32"``)."""
    return str(t.dtype).replace("torch.", "")


def _print_category(label: str, names: list[str]) -> None:
    """Print one summary line (label padded to a fixed width, then the count),
    followed by one indented bullet per parameter name."""
    print(f"  {label:<43}{len(names)}")
    for n in names:
        print(f"    - {n}")


def main() -> int:
    """Run one forward + backward pass and report per-parameter gradient health.

    Builds a ``PostSemClawModel`` on CUDA. It also builds the model's optimizer,
    but only to read param-group membership (no step, no zero_grad). It then
    runs a single bf16-autocast forward + backward on random tokens and prints:
    a per-parameter grad table, tied-weight identity checks, engram/_last_sdr
    inspection, and a severity summary.

    Returns:
        2 if CUDA is unavailable (nothing else is run), otherwise 0. The return
        value does not reflect whether blockers or warnings were found.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("ERROR: CUDA required (model has mamba-ssm + bf16 autocast path).")
        return 2

    cfg = PostSemClawConfig(
        sequence_len=64,
        vocab_size=8192,
        n_layer=int(os.environ["HYDRA_N_LAYER"]),
        d_model=int(os.environ["HYDRA_D_MODEL"]),
        d_state=64,
        headdim=32,
        n_heads=8,
        expand=2,
        engram_n_columns=1024,
        engram_key_dim=64,
        engram_layer_idx=1,
        sdr_n_bits=16384,
        sdr_target_active=327,
        sdr_delta_rank=32,
        sdr_som_warmup=500,
        sdr_som_interval=100,
        htm_n_columns=2048,
        htm_cells_per_column=32,
        mtp_k=int(os.environ["HYDRA_MTP_K"]),
        mtp_weight_decay=0.5,
    )

    print(f"[probe] config: d_model={cfg.d_model} n_layer={cfg.n_layer} "
          f"mtp_k={cfg.mtp_k} vocab={cfg.vocab_size}")

    torch.manual_seed(0)
    model = PostSemClawModel(cfg).to(device)
    model.init_weights()
    model.train()

    # ---- Enumerate params & optimizer group assignment ----
    all_params = list(model.named_parameters())
    print(f"[probe] total named parameters: {len(all_params)}")

    # Build optimizer to check group coverage (no step, no zero_grad).
    opt = model.setup_optimizer()
    grouped_ids: set[int] = {id(p) for group in opt.param_groups for p in group["params"]}
    unique_param_ids = {id(p) for _, p in all_params}
    missing_from_opt = unique_param_ids - grouped_ids
    print(f"[probe] params in opt groups: {len(grouped_ids)} / unique: {len(unique_param_ids)}")
    if missing_from_opt:
        print(f"[probe] WARNING: {len(missing_from_opt)} unique params missing from opt groups")

    # Tied weight check.
    tied = model.wte.weight.data_ptr() == model.lm_head.weight.data_ptr()
    print(f"[probe] tied lm_head<->wte (data_ptr match): {tied}")

    # ---- One forward + backward under bf16 autocast ----
    B, T = 1, 64
    idx = torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long, device=device)
    tgt = torch.roll(idx, -1, dims=1)

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, targets=tgt)
    print(f"[probe] fwd loss = {float(loss.detach()):.4f}")
    loss.backward()
    torch.cuda.synchronize()

    # ---- Report ----
    nan = float("nan")
    blockers: list[str] = []
    zero_grads: list[str] = []
    nonfinite_grads: list[str] = []
    unexpected_frozen: list[str] = []
    not_in_opt: list[str] = []
    rows: list[tuple[str, tuple[int, ...], str, bool, bool, float, float]] = []

    for name, p in all_params:
        shape = tuple(p.shape)
        dtype = _dtype_name(p)
        if not p.requires_grad:
            unexpected_frozen.append(name)
            rows.append((name, shape, dtype, False, True, nan, nan))
            continue
        # Check optimizer coverage for EVERY trainable param, including ones
        # whose grad is None, so a disconnected param that is also absent
        # from the optimizer is reported in both categories.
        if id(p) not in grouped_ids:
            not_in_opt.append(name)
        if p.grad is None:
            blockers.append(name)
            rows.append((name, shape, dtype, True, True, nan, nan))
            continue
        g = p.grad.detach().float()
        abs_mean = float(g.abs().mean().item())
        norm = float(g.norm().item())
        if not bool(torch.isfinite(g).all()):
            # NaN/Inf never compare equal to 0, so these were previously only
            # visible as "nan"/"inf" in the table, never in the summary.
            nonfinite_grads.append(name)
        elif abs_mean == 0.0 and norm == 0.0:
            zero_grads.append(name)
        rows.append((name, shape, dtype, True, False, abs_mean, norm))

    # Pretty table. Row fields use the same widths as the header so columns line up.
    print("\n[probe] per-parameter grad table:")
    print(f"  {'name':<56}  {'shape':<22}  {'dtype':<8}  {'rg':<2}  {'none':<4}  "
          f"{'|g|.mean':>10}  {'|g|.norm':>10}")
    for name, shape, dtype, rg, none, mean, norm in rows:
        shape_s = "x".join(str(s) for s in shape)
        rg_s = "Y" if rg else "N"
        none_s = "Y" if none else "N"
        if none:
            mean_s = norm_s = f"{'nan':>10}"
        else:
            mean_s = f"{mean:>10.3e}"
            norm_s = f"{norm:>10.3e}"
        print(f"  {name:<56}  {shape_s:<22}  {dtype:<8}  {rg_s:<2}  {none_s:<4}  "
              f"{mean_s}  {norm_s}")

    # Identity checks
    print("\n[probe] identity checks:")
    print(f"  id(wte.weight)       = {id(model.wte.weight)}")
    print(f"  id(lm_head.weight)   = {id(model.lm_head.weight)}")
    print(f"  same Python object   = {model.wte.weight is model.lm_head.weight}")
    print(f"  same storage ptr     = {tied}")

    # Engram memory inspection
    print(f"\n[probe] engram.memory is nn.Parameter: "
          f"{isinstance(model.engram.memory, torch.nn.Parameter)}")
    print(f"  engram.memory.requires_grad = {model.engram.memory.requires_grad}")
    if model.engram.memory.grad is None:
        print("  engram.memory.grad = None (Hebbian-only path; no autograd through detach())")
    else:
        g = model.engram.memory.grad.detach().float()
        print(f"  engram.memory.grad |.mean| = {float(g.abs().mean()):.3e}")

    # Stash flag sanity: _last_sdr should be uint8, no graph
    last = getattr(model, "_last_sdr", None)
    if last is not None:
        print(f"\n[probe] model._last_sdr dtype={last.dtype}, requires_grad={last.requires_grad}")
    else:
        print("\n[probe] model._last_sdr is None (fwd didn't stash — ok if path changed)")

    # Summary
    print("\n[probe] ============ SUMMARY ============")
    _print_category("BLOCKERS (requires_grad but grad is None):", blockers)
    _print_category("WARNINGS (grad is literally zero):", zero_grads)
    _print_category("WARNINGS (grad contains NaN/Inf):", nonfinite_grads)
    _print_category("WARNINGS (requires_grad=False):", unexpected_frozen)
    _print_category("WARNINGS (missing from every opt group):", not_in_opt)

    return 0


if __name__ == "__main__":
    # Propagate main()'s integer status as the process exit code.
    raise SystemExit(main())