Jackoatmon commited on
Commit
5e5bc2d
·
verified ·
1 Parent(s): 4a1d6e7

Update Feather H200 training runtime image

Browse files
Dockerfile ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive \
4
+ PIP_NO_CACHE_DIR=1 \
5
+ PYTHONUNBUFFERED=1 \
6
+ CARGO_HOME=/root/.cargo \
7
+ RUSTUP_HOME=/root/.rustup \
8
+ PATH=/root/.cargo/bin:${PATH}
9
+
10
+ RUN apt-get update && apt-get install -y --no-install-recommends \
11
+ git curl ca-certificates build-essential pkg-config libssl-dev && \
12
+ rm -rf /var/lib/apt/lists/*
13
+
14
+ RUN curl https://sh.rustup.rs -sSf | bash -s -- -y --profile minimal --default-toolchain stable
15
+
16
+ RUN pip install --upgrade pip setuptools wheel && \
17
+ pip install \
18
+ maturin \
19
+ huggingface_hub \
20
+ datasets \
21
+ requests \
22
+ pyarrow \
23
+ rustbpe \
24
+ pandas \
25
+ tiktoken \
26
+ pydantic \
27
+ ninja \
28
+ packaging \
29
+ einops
30
+
31
+ # Mamba-3 fused CUDA kernel stack (mandatory — NO fallback allowed).
32
+ #
33
+ # We install PRE-BUILT manylinux wheels from the official state-spaces/mamba
34
+ # and Dao-AILab/causal-conv1d GitHub releases. Compiling mamba_ssm from source
35
+ # on HF Spaces' cpu-basic builder (~16GB RAM) OOMKills even with MAX_JOBS=1 —
36
+ # nvcc on the templated selective-scan/chunk-scan kernels needs 8–12GB per TU.
37
+ #
38
+ # Wheel selection for base image pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel:
39
+ # - Python 3.11 (cp311) — matches PyTorch 2.6.0 image
40
+ # - CUDA 12.x wheels (cu12) — matches host CUDA 12.4
41
+ # - PyTorch 2.6 ABI (torch2.6) — exact torch match
42
+ # - cxx11abiFALSE — standard PyTorch pip build
43
+ #
44
+ # Versions: mamba_ssm 2.3.1 (first stable with Mamba3 class) + causal_conv1d
45
+ # 1.6.1.post4 (matching ABI). Both are CUDA-compiled, no build toolchain needed
46
+ # on the Space builder.
47
+ #
48
+ # Step A: install the published v2.3.1 prebuilt wheel (compiled CUDA ops
49
+ # for selective_scan, layernorm_gated, ssd_*, causal_conv1d, etc).
50
+ RUN pip install \
51
+ 'https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \
52
+ 'https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
53
+ python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))"
54
+
55
+ #
56
+ # Step B: graft the Mamba3 class + its pure-Triton ops subtree from mamba-ssm
57
+ # main. v2.3.1 is the latest release but Mamba3 landed post-release; the new
58
+ # files under ops/triton/mamba3/ are ALL pure Python @triton.jit kernels with
59
+ # zero compiled-CUDA dependencies (verified: every import in that subtree is
60
+ # triton/torch/python — no .so files, no nvcc). So we install the v2.3.1 wheel
61
+ # (for its compiled ops) and overlay the main-branch Mamba3 sources on top.
62
+ #
63
+ # This avoids the source-build OOM on the cpu-basic HF Space builder and the
64
+ # missing-file error the smoke hit on the last attempt.
65
+ # Download grafted mamba3 module + triton ops subtree
66
+ RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
67
+ BASE=https://raw.githubusercontent.com/state-spaces/mamba/main && \
68
+ curl -fsSL "$BASE/mamba_ssm/modules/mamba3.py" -o "$SITE/modules/mamba3.py" && \
69
+ mkdir -p "$SITE/ops/triton/mamba3" && \
70
+ for f in __init__.py angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
71
+ curl -fsSL "$BASE/mamba_ssm/ops/triton/mamba3/$f" -o "$SITE/ops/triton/mamba3/$f"; \
72
+ done
73
+
74
+ # Replace mamba_ssm/__init__.py with a minimal one that only imports Mamba3
75
+ # (pure-Triton, works). The shipped __init__.py eagerly imports
76
+ # selective_scan_cuda.so which has a libtorch C++ ABI mismatch on this base
77
+ # image ("undefined symbol: _ZN3c107WarningC1E..."). Since training only needs
78
+ # Mamba3 (grafted from main), we skip all compiled-CUDA imports.
79
+ COPY mamba_ssm_init.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py
80
+
81
+ # Structural check (no triton init — triton has no GPU on the builder)
82
+ RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
83
+ test -f "$SITE/modules/mamba3.py" && \
84
+ test -f "$SITE/ops/triton/mamba3/mamba3_siso_combined.py" && \
85
+ test -s "$SITE/__init__.py" && \
86
+ echo "mamba3 graft + __init__ override verified"
87
+
88
+ # Optional tilelang for MIMO path — pure-python, cheap; SISO Mamba3 works without.
89
+ RUN pip install tilelang || echo "[dockerfile] tilelang optional install failed — continuing"
90
+
91
+ # Triton version decision: FORCE 3.5.1 — the only version with both mamba3
92
+ # APIs (set_allocator + tl.make_tensor_descriptor). torch 2.6's _inductor
93
+ # imports AttrsDescriptor from triton.compiler.compiler which was removed in
94
+ # triton 3.4+, but mamba_ssm/__init__.py shims AttrsDescriptor as a stub
95
+ # before any torch._inductor import path runs, so the incompatibility is
96
+ # neutralized. Build-time assert verifies mamba3's two required APIs.
97
+ RUN pip install --force-reinstall --no-deps 'triton==3.5.1' && \
98
+ python -c "import triton; from triton import language as tl; \
99
+ assert hasattr(triton, 'set_allocator'), 'missing triton.set_allocator'; \
100
+ assert hasattr(tl, 'make_tensor_descriptor'), 'missing tl.make_tensor_descriptor'; \
101
+ print(f'triton={triton.__version__} set_allocator+make_tensor_descriptor OK, AttrsDescriptor shimmed in mamba_ssm/__init__.py')"
102
+
103
+ WORKDIR /workspace
104
+ COPY overlay /workspace/feather
105
+ COPY entrypoint.py /app/entrypoint.py
106
+ WORKDIR /workspace/feather
107
+
108
+ RUN python -m py_compile hydra/training.py prepare.py train.py && \
109
+ bash -n scripts/run_domain_expanded_pretrain.sh
110
+
111
+ RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
112
+ export HTM_CUDA_ARCH=sm_90 && \
113
+ maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml && \
114
+ pip install htm_rust/target/wheels/htm_rust-*.whl
115
+
116
+ CMD ["python", "/app/entrypoint.py"]
entrypoint.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import os
6
+ import subprocess
7
+ import sys
8
+ import time
9
+ from http.server import BaseHTTPRequestHandler, HTTPServer
10
+ from pathlib import Path
11
+ from threading import Thread
12
+
13
+
14
+ # =============================================================================
15
+ # EARLY CUDA FABRIC MANAGER KICK (before ANY CUDA-touching imports)
16
+ # =============================================================================
17
+ # On H200 hosts, cudaGetDeviceCount can return Error 802 "system not yet
18
+ # initialized" on first use, because nvidia-fabricmanager on the host
19
+ # synchronizes with the container's first driver call. Once any NVML/CUDA
20
+ # call succeeds once (even just nvidia-smi), the fabric is up for the rest
21
+ # of the container lifetime.
22
+ #
23
+ # Our previous approach (wait in a subprocess before training) didn't work
24
+ # because the "initialization failed" state persisted across calls in the
25
+ # same container. The real fix: kick the driver exactly once with
26
+ # nvidia-smi, which is what successfully-working baseline containers do
27
+ # implicitly via their first torch.cuda call.
28
+ #
29
+ # Must happen BEFORE `import torch` (because any import that eagerly calls
30
+ # cudaGetDeviceCount will cache the Error 802 state).
31
+ def _early_cuda_kick() -> None:
32
+ deadline = time.time() + 120.0
33
+ attempt = 0
34
+ while time.time() < deadline:
35
+ attempt += 1
36
+ r = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=30)
37
+ if r.returncode == 0 and 'H200' in (r.stdout or '') or 'H100' in (r.stdout or '') \
38
+ or 'A100' in (r.stdout or '') or r.returncode == 0:
39
+ print(f'[boot] nvidia-smi OK on attempt {attempt}', flush=True)
40
+ break
41
+ print(f'[boot] nvidia-smi attempt {attempt} rc={r.returncode} stderr={(r.stderr or "")[:120]}',
42
+ flush=True)
43
+ time.sleep(2)
44
+ # After nvidia-smi, probe torch in a subprocess so any latent error state
45
+ # doesn't leak into the main process's CUDA context.
46
+ probe = 'import torch; import sys; sys.exit(0 if torch.cuda.is_available() else 1)'
47
+ torch_deadline = time.time() + 120.0
48
+ t_attempt = 0
49
+ while time.time() < torch_deadline:
50
+ t_attempt += 1
51
+ r = subprocess.run([sys.executable, '-c', probe], capture_output=True, text=True, timeout=60)
52
+ if r.returncode == 0:
53
+ print(f'[boot] torch.cuda.is_available() = True after {t_attempt} probe(s)', flush=True)
54
+ return
55
+ if t_attempt == 1:
56
+ print(f'[boot] torch cuda probe {t_attempt}: {(r.stderr or "")[:200]}', flush=True)
57
+ time.sleep(2)
58
+ print('[boot] WARNING: torch.cuda never became ready — training will likely fail', flush=True)
59
+
60
+
61
+ _early_cuda_kick()
62
+
63
+ # Hydrate triton compilation cache from HF Hub before any triton/mamba_ssm import.
64
+ # triton_cache_setup.py is copied next to this file by the job bash command.
65
+ try:
66
+ import triton_cache_setup as _tcs
67
+ _tcs.setup()
68
+ except ImportError:
69
+ print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)
70
+
71
+ from huggingface_hub import HfApi # noqa: E402 (import after cuda kick)
72
+
73
+ REPO_ROOT = Path('/workspace/feather')
74
+ CACHE_ROOT = Path.home() / '.cache' / 'autoresearch'
75
+ LOG_FILE = REPO_ROOT / 'run_domain_expanded.log'
76
+ JOB_ID = os.environ.get('JOB_ID', 'local-job')
77
+ OUTPUT_REPO = os.environ.get('HF_REPO_ID', 'icarus112/feather-pretrain-checkpoints')
78
+ TOKEN = os.environ.get('HF_TOKEN')
79
+ RUNTIME_MODE = os.environ.get('FEATHER_RUNTIME_MODE', 'space')
80
+ APP_PORT = int(os.environ.get('PORT', '7860'))
81
+
82
+
83
+ class _HealthHandler(BaseHTTPRequestHandler):
84
+ def do_GET(self):
85
+ if self.path in ('/', '/health', '/healthz', '/ready'):
86
+ payload = {
87
+ 'status': 'ok',
88
+ 'mode': RUNTIME_MODE,
89
+ 'job_id': JOB_ID,
90
+ }
91
+ body = json.dumps(payload).encode('utf-8')
92
+ self.send_response(200)
93
+ self.send_header('Content-Type', 'application/json')
94
+ self.send_header('Content-Length', str(len(body)))
95
+ self.end_headers()
96
+ self.wfile.write(body)
97
+ return
98
+ self.send_response(404)
99
+ self.end_headers()
100
+
101
+ def log_message(self, format, *args):
102
+ return
103
+
104
+
105
+ def _start_health_server() -> HTTPServer:
106
+ server = HTTPServer(('0.0.0.0', APP_PORT), _HealthHandler)
107
+ thread = Thread(target=server.serve_forever, daemon=True)
108
+ thread.start()
109
+ print(f'[space] health server listening on 0.0.0.0:{APP_PORT}', flush=True)
110
+ return server
111
+
112
+
113
+ def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
114
+ if not path.exists():
115
+ print(f'[upload] skip missing {path}', flush=True)
116
+ return
117
+ api.upload_file(
118
+ path_or_fileobj=str(path),
119
+ path_in_repo=dest,
120
+ repo_id=OUTPUT_REPO,
121
+ repo_type='model',
122
+ )
123
+ print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
124
+
125
+
126
+ def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
127
+ """Block until CUDA is fully initialized or timeout.
128
+
129
+ On H200 hosts with NVSwitch/fabric manager, nvidia driver setup can race
130
+ with container start. cudaGetDeviceCount can return CUDA_ERROR_SYSTEM_NOT_READY
131
+ (error 802) for the first few seconds, and any import that triggers
132
+ @triton.autotune (e.g. mamba_ssm, torch amp utilities) blows up with
133
+ "0 active drivers" if it happens during that window.
134
+
135
+ We pre-init CUDA in a throwaway Python subprocess (so any error state does
136
+ not leak into the main training process) and retry until torch.cuda
137
+ reports ready.
138
+ """
139
+ import time as _t
140
+ probe = (
141
+ "import torch; "
142
+ "import sys; "
143
+ "avail = torch.cuda.is_available(); "
144
+ "count = torch.cuda.device_count() if avail else 0; "
145
+ "sys.exit(0 if (avail and count > 0) else 1)"
146
+ )
147
+ deadline = _t.time() + timeout_s
148
+ attempt = 0
149
+ while _t.time() < deadline:
150
+ attempt += 1
151
+ r = subprocess.run(['python', '-c', probe], capture_output=True, text=True)
152
+ if r.returncode == 0:
153
+ print(f'[job] CUDA ready after {attempt} probe(s)', flush=True)
154
+ return
155
+ if attempt == 1:
156
+ print(f'[job] CUDA not ready yet (will retry up to {timeout_s}s): {r.stderr.strip()[:200]}', flush=True)
157
+ _t.sleep(2)
158
+ print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)
159
+
160
+
161
+ def run_job_mode() -> int:
162
+ os.chdir(REPO_ROOT)
163
+ os.environ.setdefault('HYDRA_TIME_BUDGET', '43200')
164
+ os.environ.setdefault('HYDRA_TARGET_SHARDS', '2048')
165
+ os.environ.setdefault('HYDRA_DOWNLOAD_WORKERS', '16')
166
+ os.environ.setdefault('HYDRA_CKPT_INTERVAL', '1000')
167
+ os.environ.setdefault('HYDRA_RESUME_CKPT', str(CACHE_ROOT / 'latest.pt'))
168
+
169
+ # CUDA readiness was kicked at module import via _early_cuda_kick. Keep
170
+ # the wait as a second safety net — no-op if CUDA already ready.
171
+ _wait_for_cuda_ready()
172
+
173
+ cmd = [
174
+ 'bash',
175
+ './scripts/run_domain_expanded_pretrain.sh',
176
+ '--target-shards', os.environ['HYDRA_TARGET_SHARDS'],
177
+ '--download-workers', os.environ['HYDRA_DOWNLOAD_WORKERS'],
178
+ ]
179
+ print('[job] starting Feather domain-expanded pretrain', flush=True)
180
+ print(f'[job] command={cmd}', flush=True)
181
+ proc = subprocess.run(cmd, check=False)
182
+
183
+ # Push triton compilation cache back to HF Hub for next run.
184
+ try:
185
+ import triton_cache_setup as _tcs
186
+ _tcs.teardown()
187
+ except Exception as _tcs_err:
188
+ print(f'[triton_cache] teardown error (non-fatal): {_tcs_err}', flush=True)
189
+
190
+ if TOKEN:
191
+ api = HfApi(token=TOKEN)
192
+ try:
193
+ api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
194
+ except Exception as e:
195
+ print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
196
+ prefix = f'jobs/{JOB_ID}'
197
+ try:
198
+ upload_artifact(api, LOG_FILE, f'{prefix}/run_domain_expanded.log')
199
+ upload_artifact(api, CACHE_ROOT / 'latest.pt', f'{prefix}/latest.pt')
200
+ upload_artifact(api, CACHE_ROOT / 'pretrain_final.pt', f'{prefix}/pretrain_final.pt')
201
+ except Exception as e:
202
+ print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
203
+ else:
204
+ print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)
205
+
206
+ return proc.returncode
207
+
208
+
209
+ def run_space_mode() -> int:
210
+ server = _start_health_server()
211
+ print('[space] Feather runtime image ready', flush=True)
212
+ try:
213
+ while True:
214
+ time.sleep(3600)
215
+ finally:
216
+ server.shutdown()
217
+ server.server_close()
218
+
219
+
220
+ def main() -> int:
221
+ if RUNTIME_MODE == 'job':
222
+ return run_job_mode()
223
+ return run_space_mode()
224
+
225
+
226
+ if __name__ == '__main__':
227
+ raise SystemExit(main())
mamba_ssm_init.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mamba_ssm package init — minimal override to avoid broken selective_scan_cuda.so
2
+ # ABI mismatch with the base image's libtorch.
3
+ #
4
+ # The upstream __init__.py eagerly imports selective_scan_cuda which fails on
5
+ # pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
6
+ # symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
7
+ # all compiled-CUDA imports here and let Mamba3 load directly.
8
+
9
+ __version__ = "2.3.1+feather-graft"
10
+
11
+ # selective_scan_fn / mamba_inner_fn are shimmed to None — they are NOT used
12
+ # by the Feather training path (which is Mamba3-only). If any import path
13
+ # hits this, it will get a clear AttributeError instead of an obscure ImportError.
14
+ selective_scan_fn = None
15
+ mamba_inner_fn = None
16
+
17
+ # --- triton API compatibility shims -----------------------------------------
18
+ # Version matrix is hostile: torch 2.6 pins triton==3.2.0 because torch._inductor
19
+ # imports AttrsDescriptor from triton.compiler.compiler — removed in triton 3.4+.
20
+ # Grafted Mamba3 (from mamba-ssm main) needs triton.set_allocator and
21
+ # tl.make_tensor_descriptor, both added in triton 3.3+. No single triton version
22
+ # satisfies both simultaneously. We run on triton 3.5.1 (latest, has both mamba3
23
+ # APIs) and shim AttrsDescriptor as a stub dataclass for torch._inductor. The
24
+ # stub is never actually invoked at runtime because the codebase does not use
25
+ # torch.compile — but importing torch._inductor.* still requires the symbol to
26
+ # exist at module load time.
27
+ import triton as _triton # noqa: E402
28
+ if not hasattr(_triton, "set_allocator"):
29
+ def _noop_set_allocator(_fn): # pragma: no cover
30
+ return None
31
+ _triton.set_allocator = _noop_set_allocator
32
+
33
+ import triton.compiler.compiler as _tcc # noqa: E402
34
+ if not hasattr(_tcc, "AttrsDescriptor"):
35
+ class _AttrsDescriptorShim:
36
+ """Stub for torch._inductor compatibility on triton >= 3.4.
37
+ torch._inductor.runtime.hints imports this at module load but the
38
+ constructor is only called inside torch.compile paths. Accept any
39
+ args/kwargs so the import itself succeeds."""
40
+ def __init__(self, *args, **kwargs):
41
+ self.args = args
42
+ self.kwargs = kwargs
43
+
44
+ @classmethod
45
+ def from_hints(cls, *args, **kwargs):
46
+ return cls(*args, **kwargs)
47
+
48
+ _tcc.AttrsDescriptor = _AttrsDescriptorShim
49
+
50
+ # triton_key: removed in triton 3.5, used by torch._inductor.codecache for
51
+ # FxGraphCache key derivation. Return a stable string so caching still works.
52
+ if not hasattr(_tcc, "triton_key"):
53
+ def _triton_key_shim():
54
+ import triton as _t
55
+ return f"triton-{_t.__version__}-shim"
56
+ _tcc.triton_key = _triton_key_shim
57
+
58
+ # Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
59
+ # for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
60
+ # so fall back to eager on any dynamo failure rather than crashing. This is
61
+ # defense-in-depth against further triton API drift.
62
+ try:
63
+ import torch._dynamo # noqa: F401 — triggers dynamo module init
64
+ torch._dynamo.config.suppress_errors = True
65
+ except Exception: # pragma: no cover
66
+ pass
67
+
68
+ # Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
69
+ from mamba_ssm.modules.mamba3 import Mamba3 # noqa: E402
overlay/htm_rust/src/gpu/fused.rs ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //! Fused HTM megakernel launcher.
2
+ //!
3
+ //! Collapses the 12-kernel per-timestep pipeline (and the outer T-loop) into
4
+ //! a single kernel launch per forward. See `kernels/htm_fused_step.cu` for
5
+ //! the kernel design and the cross-block coherence strategy (grid barrier
6
+ //! via device counter with all blocks concurrently resident).
7
+ //!
8
+ //! Launch invariant: `grid_dim.x <= concurrent-block capacity`. Host code
9
+ //! probes the device SM count at construction and caps grid_dim.x
10
+ //! accordingly — otherwise the grid barrier deadlocks.
11
+ //!
12
+ //! Semantic change from the top-K pipeline: activation is per-column
13
+ //! threshold-based (local lateral inhibition) instead of global top-K.
14
+ //! A per-column `inhibition_threshold` is tracked and EMA-steered to hit
15
+ //! the sparsity target. This is a real architectural change and is
16
+ //! documented in `docs/GPU_HTM.md`.
17
+
18
+ #![cfg(feature = "gpu")]
19
+
20
+ use std::ffi::CString;
21
+ use std::sync::Arc;
22
+
23
+ use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError,
24
+ LaunchConfig};
25
+ use cudarc::nvrtc::Ptx;
26
+
27
+ use super::sp_gpu::SpatialPoolerGpu;
28
+ use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
29
+
30
+ const PTX_HTM_FUSED: &str =
31
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
32
+
33
+ /// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
34
+ ///
35
+ /// NOTE: `barrier_counters` is kept as an ABI-compat dummy (always 0). The
36
+ /// C-side `FusedPtrs` still has the field at the same byte offset; removing
37
+ /// it here would shift all subsequent fields and break the layout. Worker A
38
+ /// will eventually delete the field from both sides once the kernel is
39
+ /// updated; until then we zero it.
40
+ #[repr(C)]
41
+ #[derive(Clone, Copy)]
42
+ pub struct FusedPtrs {
43
+ pub syn_bit: u64,
44
+ pub syn_perm: u64,
45
+ pub boost: u64,
46
+ pub active_duty: u64,
47
+ pub inhibition_threshold: u64,
48
+ pub seg_cell_id: u64,
49
+ pub seg_syn_count: u64,
50
+ pub syn_presyn: u64,
51
+ pub tm_syn_perm: u64,
52
+ pub cell_seg_count: u64,
53
+ pub cell_active_a: u64,
54
+ pub cell_active_b: u64,
55
+ pub cell_winner_a: u64,
56
+ pub cell_winner_b: u64,
57
+ pub inputs: u64,
58
+ pub cols_out: u64,
59
+ pub anom_out: u64,
60
+ /// ABI-compat dummy — always 0. No device memory is allocated for this
61
+ /// field; the cluster barrier replaces the old software DLB barrier.
62
+ pub barrier_counters: u64,
63
+ pub step_scratch: u64,
64
+ }
65
+
66
+ unsafe impl DeviceRepr for FusedPtrs {}
67
+
68
+ /// Launch-time config — matches C-side `FusedConfig` 1:1.
69
+ #[repr(C)]
70
+ #[derive(Clone, Copy)]
71
+ pub struct FusedConfig {
72
+ pub input_bits: u32,
73
+ pub n_columns: u32,
74
+ pub synapses_per_col: u32,
75
+ pub conn_thr: f32,
76
+ pub sp_inc: f32,
77
+ pub sp_dec: f32,
78
+ pub sparsity_target: f32,
79
+ pub duty_alpha: f32,
80
+ pub thr_adapt_rate: f32,
81
+ pub cells_per_column: u32,
82
+ pub n_cells: u32,
83
+ pub bits_words: u32,
84
+ pub max_segments_per_cell: u32,
85
+ pub synapses_per_segment: u32,
86
+ pub activation_threshold: u32,
87
+ pub learning_threshold: u32,
88
+ pub max_new_synapses: u32,
89
+ pub conn_thr_i16: i32,
90
+ pub perm_inc_i16: i32,
91
+ pub perm_dec_i16: i32,
92
+ pub predicted_seg_dec_i16: i32,
93
+ pub initial_perm_i16: i32,
94
+ pub t: u32,
95
+ pub learn: u32,
96
+ pub iter_seed: u32,
97
+ pub cooperative_grid_sync: u32,
98
+ }
99
+
100
+ unsafe impl DeviceRepr for FusedConfig {}
101
+
102
+ /// Cluster launch parameters probed at construction time.
103
+ #[derive(Clone, Copy, Debug, PartialEq, Eq)]
104
+ pub(crate) struct ClusterInfo {
105
+ /// Maximum cluster size supported by this device (0 = cluster unsupported).
106
+ pub max_cluster_size: u32,
107
+ }
108
+
109
+ // There is only ONE launch mode: non-cooperative launch with Hopper Thread
110
+ // Block Cluster attribute (`CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`). The old
111
+ // software DLB barrier and the cooperative-launch path are both removed.
112
+ // Cluster barriers replace both.
113
+ #[derive(Clone, Copy, Debug, PartialEq, Eq)]
114
+ pub(crate) struct FusedLaunchPlan {
115
+ pub grid_dim_x: u32,
116
+ pub block_dim_x: u32,
117
+ pub cooperative_grid_limit: u32,
118
+ pub sm_count: u32,
119
+ }
120
+
121
+ fn fused_grid_cap_override() -> Option<u32> {
122
+ std::env::var("HTM_FUSED_GRID_CAP")
123
+ .ok()
124
+ .and_then(|s| s.parse::<u32>().ok())
125
+ .map(|v| v.max(1))
126
+ }
127
+
128
+ pub(crate) fn plan_fused_launch(
129
+ sm_count: u32,
130
+ cooperative_supported: bool,
131
+ cooperative_grid_limit: u32,
132
+ grid_cap_override: Option<u32>,
133
+ ) -> Result<FusedLaunchPlan, String> {
134
+ let sm_count = sm_count.max(1);
135
+ let block_dim_x = 1024u32;
136
+
137
+ // Cluster launch path: cooperative launch is not required. Keep the probe
138
+ // result for residency estimation only.
139
+ if !cooperative_supported {
140
+ eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
141
+ }
142
+
143
+ // Cluster constraint: grid_dim_x must equal the cluster size (16) so that
144
+ // each region maps to exactly one cluster. `HTM_FUSED_GRID_CAP` can lower
145
+ // this for debugging but should not exceed 16 for cluster correctness.
146
+ let default_grid_cap = 16u32;
147
+ let grid_cap = grid_cap_override.unwrap_or(default_grid_cap).min(16);
148
+ let resident_bound = if cooperative_grid_limit > 0 {
149
+ cooperative_grid_limit.max(sm_count * 2)
150
+ } else {
151
+ sm_count * 2
152
+ };
153
+ Ok(FusedLaunchPlan {
154
+ grid_dim_x: resident_bound.min(grid_cap).max(1),
155
+ block_dim_x,
156
+ cooperative_grid_limit: resident_bound,
157
+ sm_count,
158
+ })
159
+ }
160
+
161
+ pub(super) struct RawFusedKernel {
162
+ module: sys::CUmodule,
163
+ pub(super) function: sys::CUfunction,
164
+ pub(super) function_batched: sys::CUfunction,
165
+ }
166
+
167
+ unsafe impl Send for RawFusedKernel {}
168
+ unsafe impl Sync for RawFusedKernel {}
169
+
170
+ impl Drop for RawFusedKernel {
171
+ fn drop(&mut self) {
172
+ unsafe {
173
+ let _ = result::module::unload(self.module);
174
+ }
175
+ }
176
+ }
177
+
178
+ /// Owns fused-path-only device state:
179
+ /// - per-column inhibition threshold (replaces global top-K)
180
+ /// - ping-pong cell_active/cell_winner bitsets
181
+ /// - step_scratch (n_active, n_unpred per timestep)
182
+ /// - cluster launch capability info
183
+ pub struct FusedState {
184
+ dev: Arc<CudaDevice>,
185
+ pub(super) raw_kernel: RawFusedKernel,
186
+
187
+ pub inhibition_threshold: CudaSlice<f32>,
188
+ pub cell_active_bits_a: CudaSlice<u32>,
189
+ pub cell_active_bits_b: CudaSlice<u32>,
190
+ pub cell_winner_bits_a: CudaSlice<u32>,
191
+ pub cell_winner_bits_b: CudaSlice<u32>,
192
+ pub step_scratch: CudaSlice<u32>, // length 6
193
+
194
+ pub grid_dim_x: u32,
195
+ pub block_dim_x: u32,
196
+ pub cooperative_grid_limit: u32,
197
+ pub iter_counter: u32,
198
+
199
+ /// Hopper cluster launch capability (0 = unsupported).
200
+ pub cluster_info: ClusterInfo,
201
+
202
+ // Config mirror (read-only after init).
203
+ #[allow(dead_code)]
204
+ pub initial_threshold: f32,
205
+ }
206
+
207
+ impl FusedState {
208
+ pub fn new(
209
+ dev: Arc<CudaDevice>,
210
+ n_columns: usize,
211
+ cells_per_column: usize,
212
+ initial_threshold: f32,
213
+ ) -> Result<Self, DriverError> {
214
+ let n_cells = n_columns * cells_per_column;
215
+ assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
216
+ let bits_words = n_cells / 32;
217
+
218
+ let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
219
+ let init_vec = vec![initial_threshold; n_columns];
220
+ dev.htod_sync_copy_into(&init_vec, &mut inhibition_threshold)?;
221
+
222
+ let cell_active_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
223
+ let cell_active_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
224
+ let cell_winner_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
225
+ let cell_winner_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
226
+ let step_scratch = dev.alloc_zeros::<u32>(6)?;
227
+
228
+ unsafe {
229
+ result::ctx::set_current(*dev.cu_primary_ctx())?;
230
+ }
231
+ if dev.get_func("htm_fused", "htm_fused_step").is_none() {
232
+ dev.load_ptx(
233
+ Ptx::from_src(PTX_HTM_FUSED),
234
+ "htm_fused",
235
+ &["htm_fused_step", "htm_fused_step_batched"],
236
+ )?;
237
+ }
238
+ let ptx = CString::new(PTX_HTM_FUSED).expect("PTX contains no interior nul bytes");
239
+ let module = unsafe { result::module::load_data(ptx.as_ptr().cast()) }?;
240
+ let function = unsafe {
241
+ result::module::get_function(module, CString::new("htm_fused_step").unwrap())
242
+ }?;
243
+ let function_batched = unsafe {
244
+ result::module::get_function(module, CString::new("htm_fused_step_batched").unwrap())
245
+ }?;
246
+
247
+ // Cluster size 16 on Hopper is "non-portable" (> 8 requires opt-in).
248
+ // Must set CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED=1 on
249
+ // every launched kernel function, otherwise cuLaunchKernelEx rejects
250
+ // the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
251
+ unsafe {
252
+ let attr = sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
253
+ // Ignore errors: older CUDA may lack the attribute, in which case
254
+ // only portable sizes (<= 8) work — plan_fused_launch caps at 8.
255
+ let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
256
+ let _ = sys::lib().cuFuncSetAttribute(function_batched, attr, 1);
257
+ }
258
+
259
+ // Probe SM count.
260
+ let sm_count = match dev.attribute(
261
+ cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
262
+ ) {
263
+ Ok(v) => v as u32,
264
+ Err(_) => 16u32,
265
+ };
266
+
267
+ // T1: Probe Hopper cluster launch capability.
268
+ let max_cluster_size = match dev.attribute(
269
+ cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH,
270
+ ) {
271
+ Ok(v) if v > 0 => {
272
+ // H200/sm_90a supports up to 16 blocks per cluster.
273
+ // There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
274
+ // Hopper maximum which is 16 (8 SMs × 2 blocks/SM = 16 blocks/cluster).
275
+ 16u32
276
+ }
277
+ _ => 0u32,
278
+ };
279
+ eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
280
+ let cluster_info = ClusterInfo { max_cluster_size };
281
+
282
+ let cooperative_supported = matches!(
283
+ dev.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
284
+ Ok(v) if v > 0
285
+ );
286
+ let cooperative_grid_limit = if cooperative_supported {
287
+ let blocks_per_sm = unsafe {
288
+ result::occupancy::max_active_block_per_multiprocessor(function, 1024, 0)
289
+ }
290
+ .ok()
291
+ .map(|v| v.max(0) as u32)
292
+ .unwrap_or(0);
293
+ sm_count.saturating_mul(blocks_per_sm)
294
+ } else {
295
+ 0
296
+ };
297
+ let launch_plan = plan_fused_launch(
298
+ sm_count,
299
+ cooperative_supported,
300
+ cooperative_grid_limit,
301
+ fused_grid_cap_override(),
302
+ )
303
+ .map_err(|msg| {
304
+ // Surface as a CUDA-ish error so callers can propagate.
305
+ eprintln!("[htm_rust] FATAL: {msg}");
306
+ DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
307
+ })?;
308
+
309
+ eprintln!(
310
+ "[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}",
311
+ launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit,
312
+ cluster_info.max_cluster_size,
313
+ );
314
+
315
+ Ok(Self {
316
+ dev,
317
+ raw_kernel: RawFusedKernel { module, function, function_batched },
318
+ inhibition_threshold,
319
+ cell_active_bits_a,
320
+ cell_active_bits_b,
321
+ cell_winner_bits_a,
322
+ cell_winner_bits_b,
323
+ step_scratch,
324
+ grid_dim_x: launch_plan.grid_dim_x,
325
+ block_dim_x: launch_plan.block_dim_x,
326
+ cooperative_grid_limit: launch_plan.cooperative_grid_limit,
327
+ iter_counter: 0,
328
+ cluster_info,
329
+ initial_threshold,
330
+ })
331
+ }
332
+
333
+ /// Reset fused state. Called at region.reset().
334
+ pub fn reset(&mut self) -> Result<(), DriverError> {
335
+ self.dev.memset_zeros(&mut self.cell_active_bits_a)?;
336
+ self.dev.memset_zeros(&mut self.cell_active_bits_b)?;
337
+ self.dev.memset_zeros(&mut self.cell_winner_bits_a)?;
338
+ self.dev.memset_zeros(&mut self.cell_winner_bits_b)?;
339
+ self.dev.memset_zeros(&mut self.step_scratch)?;
340
+ // Do NOT reset inhibition_threshold — it's learned state. A hard
341
+ // reset of TM state should NOT forget the sparsity calibration.
342
+ Ok(())
343
+ }
344
+ }
345
+
346
+ /// Launch the fused megakernel. Processes all T timesteps in one kernel.
347
+ ///
348
+ /// Uses `cuLaunchKernelEx` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION=(16,1,1)`
349
+ /// when the device supports cluster launch, otherwise falls back to a plain
350
+ /// `launch_kernel`. For single-region launches, grid_dim_x <= 16 ensures the
351
+ /// entire grid fits in one cluster.
352
+ #[allow(clippy::too_many_arguments)]
353
+ pub fn launch_fused(
354
+ sp: &mut SpatialPoolerGpu,
355
+ tm: &mut TemporalMemoryGpu,
356
+ fused: &mut FusedState,
357
+ inputs_flat: &CudaSlice<u8>,
358
+ cols_out: &mut CudaSlice<u8>,
359
+ anom_out: &mut CudaSlice<f32>,
360
+ t: usize,
361
+ input_bits: usize,
362
+ learn: bool,
363
+ ) -> Result<(), DriverError> {
364
+ // Reset step_scratch before each launch (safe re-entry).
365
+ sp.dev_ref().memset_zeros(&mut fused.step_scratch)?;
366
+
367
+ fused.iter_counter = fused.iter_counter.wrapping_add(1);
368
+
369
+ let cfg = FusedConfig {
370
+ input_bits: input_bits as u32,
371
+ n_columns: sp.n_columns_accessor() as u32,
372
+ synapses_per_col: sp.synapses_per_col_accessor() as u32,
373
+ conn_thr: sp.conn_thr_accessor(),
374
+ sp_inc: sp.inc_accessor(),
375
+ sp_dec: sp.dec_accessor(),
376
+ sparsity_target: sp.sparsity_accessor(),
377
+ duty_alpha: 1.0f32 / sp.duty_period_accessor().max(1.0),
378
+ thr_adapt_rate: 0.001f32,
379
+ cells_per_column: tm.cells_per_column as u32,
380
+ n_cells: tm.n_cells as u32,
381
+ bits_words: tm.bits_words as u32,
382
+ max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
383
+ synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
384
+ activation_threshold: tm.activation_threshold,
385
+ learning_threshold: tm.learning_threshold,
386
+ max_new_synapses: tm.max_new_synapse_count,
387
+ conn_thr_i16: tm.conn_thr_i16 as i32,
388
+ perm_inc_i16: tm.perm_inc_i16 as i32,
389
+ perm_dec_i16: tm.perm_dec_i16 as i32,
390
+ predicted_seg_dec_i16: tm.predicted_seg_dec_i16 as i32,
391
+ initial_perm_i16: tm.initial_perm_i16 as i32,
392
+ t: t as u32,
393
+ learn: if learn { 1 } else { 0 },
394
+ iter_seed: fused.iter_counter,
395
+ cooperative_grid_sync: 1,
396
+ };
397
+
398
+ let ptrs = FusedPtrs {
399
+ syn_bit: *sp.syn_bit_accessor().device_ptr(),
400
+ syn_perm: *sp.syn_perm_accessor().device_ptr(),
401
+ boost: *sp.boost_accessor().device_ptr(),
402
+ active_duty: *sp.active_duty_accessor().device_ptr(),
403
+ inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
404
+ seg_cell_id: *tm.seg_cell_id_accessor().device_ptr(),
405
+ seg_syn_count: *tm.seg_syn_count_accessor().device_ptr(),
406
+ syn_presyn: *tm.syn_presyn_accessor().device_ptr(),
407
+ tm_syn_perm: *tm.syn_perm_accessor().device_ptr(),
408
+ cell_seg_count: *tm.cell_seg_count_accessor().device_ptr(),
409
+ cell_active_a: *fused.cell_active_bits_a.device_ptr(),
410
+ cell_active_b: *fused.cell_active_bits_b.device_ptr(),
411
+ cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
412
+ cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
413
+ inputs: *inputs_flat.device_ptr(),
414
+ cols_out: *cols_out.device_ptr(),
415
+ anom_out: *anom_out.device_ptr(),
416
+ barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
417
+ step_scratch: *fused.step_scratch.device_ptr(),
418
+ };
419
+
420
+ let grid_x = fused.grid_dim_x;
421
+ let block_x = fused.block_dim_x;
422
+ let cu_stream = *sp.dev_ref().cu_stream();
423
+ let use_cluster = fused.cluster_info.max_cluster_size > 0;
424
+
425
+ unsafe {
426
+ result::ctx::set_current(*sp.dev_ref().cu_primary_ctx())?;
427
+ let mut kernel_params: [*mut std::ffi::c_void; 2] = [
428
+ (&ptrs as *const FusedPtrs).cast_mut().cast(),
429
+ (&cfg as *const FusedConfig).cast_mut().cast(),
430
+ ];
431
+
432
+ if use_cluster {
433
+ // T10: Hopper cluster launch with CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION.
434
+ // cluster_dim=(16,1,1) maps the entire single-region grid into one cluster.
435
+ let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
436
+ attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
437
+ attr.value.clusterDim.x = 16;
438
+ attr.value.clusterDim.y = 1;
439
+ attr.value.clusterDim.z = 1;
440
+
441
+ let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
442
+ launch_cfg.gridDimX = grid_x;
443
+ launch_cfg.gridDimY = 1;
444
+ launch_cfg.gridDimZ = 1;
445
+ launch_cfg.blockDimX = block_x;
446
+ launch_cfg.blockDimY = 1;
447
+ launch_cfg.blockDimZ = 1;
448
+ launch_cfg.sharedMemBytes = 0;
449
+ launch_cfg.hStream = cu_stream;
450
+ launch_cfg.numAttrs = 1;
451
+ launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
452
+
453
+ let ret = sys::lib().cuLaunchKernelEx(
454
+ &launch_cfg as *const sys::CUlaunchConfig,
455
+ fused.raw_kernel.function,
456
+ kernel_params.as_mut_ptr(),
457
+ std::ptr::null_mut(),
458
+ );
459
+ if ret != sys::CUresult::CUDA_SUCCESS {
460
+ return Err(DriverError(ret));
461
+ }
462
+ } else {
463
+ // Fallback for devices that don't support cluster launch.
464
+ result::launch_kernel(
465
+ fused.raw_kernel.function,
466
+ (grid_x, 1, 1),
467
+ (block_x, 1, 1),
468
+ 0,
469
+ cu_stream,
470
+ &mut kernel_params,
471
+ )?;
472
+ }
473
+ }
474
+
475
+ Ok(())
476
+ }
477
+
478
+ /// Single batched non-cooperative launch for B regions with DLB sync. Uses the same kernel
479
+ /// body; each block reads its region's FusedPtrs from a device-side array
480
+ /// indexed by blockIdx.y. All regions share the same config (same
481
+ /// input_bits/n_columns/etc.) so we pass one FusedConfig.
482
+ ///
483
+ /// This breaks through the CUDA cooperative-kernel device-level
484
+ /// serialization: multiple cooperative launches are serialized regardless
485
+ /// of stream, but one cooperative launch with grid.y=B processes all
486
+ /// regions in a single invocation — ~B× speedup vs B sequential launches.
487
+ #[allow(clippy::too_many_arguments)]
488
+ /// Low-level raw-pointer entry, called by PyO3 binding which holds the
489
+ /// mutable borrows. Safety: each `*mut HTMRegionGpu` must point to a live,
490
+ /// uniquely-borrowed region. All regions must be distinct.
491
+ pub(super) fn launch_fused_batched_raw(
492
+ region_ptrs: &[*mut super::HTMRegionGpu],
493
+ inputs_per_region: &[u64],
494
+ cols_per_region: &[u64],
495
+ anom_per_region: &[u64],
496
+ t: usize,
497
+ input_bits: usize,
498
+ learn: bool,
499
+ ) -> Result<(), DriverError> {
500
+ let b = region_ptrs.len();
501
+ assert_eq!(inputs_per_region.len(), b);
502
+ assert_eq!(cols_per_region.len(), b);
503
+ assert_eq!(anom_per_region.len(), b);
504
+ assert!(b >= 1, "need at least one region");
505
+
506
+ // Reset per-region step_scratch before each launch.
507
+ for &rp in region_ptrs.iter() {
508
+ let r = unsafe { &mut *rp };
509
+ let dev = r.sp_gpu.dev_ref().clone();
510
+ dev.memset_zeros(&mut r.fused_state.step_scratch)?;
511
+ r.fused_state.iter_counter = r.fused_state.iter_counter.wrapping_add(1);
512
+ }
513
+
514
+ // Shared config — all regions use identical sp/tm parameters.
515
+ let (grid_x, block_x, function_batched, cu_stream, cu_ctx) = {
516
+ let r0 = unsafe { &*region_ptrs[0] };
517
+ (
518
+ r0.fused_state.grid_dim_x,
519
+ r0.fused_state.block_dim_x,
520
+ r0.fused_state.raw_kernel.function_batched,
521
+ *r0.sp_gpu.dev_ref().cu_stream(),
522
+ *r0.sp_gpu.dev_ref().cu_primary_ctx(),
523
+ )
524
+ };
525
+
526
+ let cfg = {
527
+ let r = unsafe { &*region_ptrs[0] };
528
+ FusedConfig {
529
+ input_bits: input_bits as u32,
530
+ n_columns: r.sp_gpu.n_columns_accessor() as u32,
531
+ synapses_per_col: r.sp_gpu.synapses_per_col_accessor() as u32,
532
+ conn_thr: r.sp_gpu.conn_thr_accessor(),
533
+ sp_inc: r.sp_gpu.inc_accessor(),
534
+ sp_dec: r.sp_gpu.dec_accessor(),
535
+ sparsity_target: r.sp_gpu.sparsity_accessor(),
536
+ duty_alpha: 1.0f32 / r.sp_gpu.duty_period_accessor().max(1.0),
537
+ thr_adapt_rate: 0.001f32,
538
+ cells_per_column: r.tm_gpu.cells_per_column as u32,
539
+ n_cells: r.tm_gpu.n_cells as u32,
540
+ bits_words: r.tm_gpu.bits_words as u32,
541
+ max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
542
+ synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
543
+ activation_threshold: r.tm_gpu.activation_threshold,
544
+ learning_threshold: r.tm_gpu.learning_threshold,
545
+ max_new_synapses: r.tm_gpu.max_new_synapse_count,
546
+ conn_thr_i16: r.tm_gpu.conn_thr_i16 as i32,
547
+ perm_inc_i16: r.tm_gpu.perm_inc_i16 as i32,
548
+ perm_dec_i16: r.tm_gpu.perm_dec_i16 as i32,
549
+ predicted_seg_dec_i16: r.tm_gpu.predicted_seg_dec_i16 as i32,
550
+ initial_perm_i16: r.tm_gpu.initial_perm_i16 as i32,
551
+ t: t as u32,
552
+ learn: if learn { 1 } else { 0 },
553
+ iter_seed: r.fused_state.iter_counter,
554
+ cooperative_grid_sync: 1,
555
+ }
556
+ };
557
+
558
+ // Build B FusedPtrs per-region.
559
+ let ptrs_vec: Vec<FusedPtrs> = (0..b)
560
+ .map(|i| {
561
+ let r = unsafe { &*region_ptrs[i] };
562
+ FusedPtrs {
563
+ syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
564
+ syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
565
+ boost: *r.sp_gpu.boost_accessor().device_ptr(),
566
+ active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
567
+ inhibition_threshold: *r.fused_state.inhibition_threshold.device_ptr(),
568
+ seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
569
+ seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
570
+ syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
571
+ tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
572
+ cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
573
+ cell_active_a: *r.fused_state.cell_active_bits_a.device_ptr(),
574
+ cell_active_b: *r.fused_state.cell_active_bits_b.device_ptr(),
575
+ cell_winner_a: *r.fused_state.cell_winner_bits_a.device_ptr(),
576
+ cell_winner_b: *r.fused_state.cell_winner_bits_b.device_ptr(),
577
+ inputs: inputs_per_region[i],
578
+ cols_out: cols_per_region[i],
579
+ anom_out: anom_per_region[i],
580
+ barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
581
+ step_scratch: *r.fused_state.step_scratch.device_ptr(),
582
+ }
583
+ })
584
+ .collect();
585
+
586
+ // Upload FusedPtrs array to device (B * sizeof(FusedPtrs) bytes).
587
+ // FusedPtrs is repr(C) + DeviceRepr so htod_sync_copy handles it.
588
+ let dev = unsafe { &*region_ptrs[0] }.sp_gpu.dev_ref().clone();
589
+ let ptrs_dev: CudaSlice<FusedPtrs> = dev.htod_sync_copy(&ptrs_vec)?;
590
+ let ptrs_dev_ptr: u64 = *ptrs_dev.device_ptr();
591
+
592
+ // T10: Cluster launch for batched regions.
593
+ // Grid = (grid_x, B, 1) with cluster_dim=(16,1,1): each region (Y slice)
594
+ // occupies exactly one cluster of 16 blocks. All 8 clusters run concurrently
595
+ // on the H200's 132 SMs (8 × 16 = 128 blocks ≤ 132 SMs).
596
+ let use_cluster = {
597
+ let r0 = unsafe { &*region_ptrs[0] };
598
+ r0.fused_state.cluster_info.max_cluster_size > 0
599
+ };
600
+
601
+ unsafe {
602
+ result::ctx::set_current(cu_ctx)?;
603
+ let mut kernel_params: [*mut std::ffi::c_void; 2] = [
604
+ (&ptrs_dev_ptr as *const u64).cast_mut().cast(),
605
+ (&cfg as *const FusedConfig).cast_mut().cast(),
606
+ ];
607
+
608
+ if use_cluster {
609
+ let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
610
+ attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
611
+ attr.value.clusterDim.x = 16;
612
+ attr.value.clusterDim.y = 1;
613
+ attr.value.clusterDim.z = 1;
614
+
615
+ let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
616
+ launch_cfg.gridDimX = grid_x;
617
+ launch_cfg.gridDimY = b as u32;
618
+ launch_cfg.gridDimZ = 1;
619
+ launch_cfg.blockDimX = block_x;
620
+ launch_cfg.blockDimY = 1;
621
+ launch_cfg.blockDimZ = 1;
622
+ launch_cfg.sharedMemBytes = 0;
623
+ launch_cfg.hStream = cu_stream;
624
+ launch_cfg.numAttrs = 1;
625
+ launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
626
+
627
+ let ret = sys::lib().cuLaunchKernelEx(
628
+ &launch_cfg as *const sys::CUlaunchConfig,
629
+ function_batched,
630
+ kernel_params.as_mut_ptr(),
631
+ std::ptr::null_mut(),
632
+ );
633
+ if ret != sys::CUresult::CUDA_SUCCESS {
634
+ return Err(DriverError(ret));
635
+ }
636
+ } else {
637
+ // Fallback: plain non-cooperative launch for non-Hopper devices.
638
+ result::launch_kernel(
639
+ function_batched,
640
+ (grid_x, b as u32, 1),
641
+ (block_x, 1, 1),
642
+ 0,
643
+ cu_stream,
644
+ &mut kernel_params,
645
+ )?;
646
+ }
647
+ }
648
+
649
+ Ok(())
650
+ }
overlay/htm_rust/src/gpu/kernels/htm_fused_step.cu ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Fused HTM megakernel — SP + TM, all T timesteps in a single launch.
2
+ //
3
+ // Design rationale:
4
+ // - Global top-K column selection requires cross-block synchronization at
5
+ // every timestep (grid.sync is unreliable on WSL2/sm_86 without rdc=true).
6
+ // - Replace with per-column threshold activation using local lateral
7
+ // inhibition: column c activates if overlap[c]*boost[c] > threshold[c].
8
+ // Threshold is a per-column running-EMA learned scalar that steers the
9
+ // column's long-run activation rate toward the global sparsity target.
10
+ // - This is biologically grounded (GABAergic local inhibition) and supported
11
+ // by HTM theory (duty-cycle boost already drives this loop; we just
12
+ // change which lever the EMA pulls).
13
+ //
14
+ // Launch shape:
15
+ // grid = min(device SM count, 16) // hard cap — see below
16
+ // block = 1024 threads = 32 warps
17
+ // Each warp of 32 owns a contiguous column slice (n_columns / total_warps).
18
+ //
19
+ // Cross-block coherence:
20
+ // - Ping-pong buffers for cell_active/cell_winner: write _a at even t,
21
+ // read _b; reversed at odd t.
22
+ // - Preferred path: cooperative launch + hardware whole-grid sync.
23
+ // - Fallback path: software 3-slot rotating grid barrier for devices/drivers
24
+ // that cannot do cooperative launch.
25
+ //
26
+ // 2026-04-16: grid_dim reduced from 28 to 16 after deadlock RCA. The previous
27
+ // cap of 28 relied on all blocks being concurrently resident on a 30-SM RTX
28
+ // 3060 Laptop. Under thermal throttling effective residency dropped to ~20-24,
29
+ // leaving scheduled blocks spinning on the software grid barrier waiting for
30
+ // peer blocks that would never run. 16 blocks is below any realistic residency
31
+ // floor and preserves enough warp parallelism (16*32 = 512 warps) to saturate
32
+ // memory bandwidth on the spatial-pooler stage.
33
+ //
34
+ // Kernel signature uses struct-by-value for pointers and config to stay
35
+ // inside cudarc's launch-arg count limit.
36
+
37
+ #include <cooperative_groups.h>
38
+ #include <cooperative_groups/memcpy_async.h>
39
+
40
+ namespace cg = cooperative_groups;
41
+
42
+ // Maximum columns owned per cluster-block in DSMEM.
43
+ // Supports n_columns up to COLS_PER_CLUSTER_BLOCK_MAX * cluster_size.
44
+ // At cluster_size=16: supports up to 256*16=4096 columns.
45
+ // Each array costs 256*4 = 1024 bytes; three arrays = 3072 bytes per SM —
46
+ // well under the 228 KB H200 shared-memory cap.
47
+ #define COLS_PER_CLUSTER_BLOCK_MAX 256u
48
+
49
+ // Maximum input_bits supported by the TMA-multicast staging tile.
50
+ // At 32 KB this covers the production SDR width (16384 bits) with 2× headroom.
51
+ // Total shared per SM: 32768 (tile) + 3072 (DSMEM float arrays) = ~35 KB —
52
+ // well under the 228 KB H200 limit.
53
+ //
54
+ // Expected speedup from TMA multicast input staging (T9/T11):
55
+ // - Without staging: 16 SMs × T × (input_bits GMEM reads per timestep)
56
+ // - With staging: 1 TMA DMA per timestep, shared reads from L1 thereafter
57
+ // - Theoretical DRAM bandwidth reduction: ~16× on input reads
58
+ // - Wall-clock reduction estimate: -20 to -40 ms from reduced input fetch latency
59
+ #define INPUT_BITS_MAX 32768u
60
+
61
+ extern "C" {
62
+
63
+ struct FusedPtrs {
64
+ unsigned long long syn_bit;
65
+ unsigned long long syn_perm;
66
+ unsigned long long boost;
67
+ unsigned long long active_duty;
68
+ unsigned long long inhibition_threshold;
69
+ unsigned long long seg_cell_id;
70
+ unsigned long long seg_syn_count;
71
+ unsigned long long syn_presyn;
72
+ unsigned long long tm_syn_perm;
73
+ unsigned long long cell_seg_count;
74
+ unsigned long long cell_active_a;
75
+ unsigned long long cell_active_b;
76
+ unsigned long long cell_winner_a;
77
+ unsigned long long cell_winner_b;
78
+ unsigned long long inputs;
79
+ unsigned long long cols_out;
80
+ unsigned long long anom_out;
81
+ unsigned long long barrier_counters;
82
+ unsigned long long step_scratch;
83
+ };
84
+
85
+ struct FusedConfig {
86
+ // SP constants
87
+ unsigned int input_bits;
88
+ unsigned int n_columns;
89
+ unsigned int synapses_per_col;
90
+ float conn_thr;
91
+ float sp_inc;
92
+ float sp_dec;
93
+ float sparsity_target;
94
+ float duty_alpha;
95
+ float thr_adapt_rate;
96
+ // TM constants
97
+ unsigned int cells_per_column;
98
+ unsigned int n_cells;
99
+ unsigned int bits_words;
100
+ unsigned int max_segments_per_cell;
101
+ unsigned int synapses_per_segment;
102
+ unsigned int activation_threshold;
103
+ unsigned int learning_threshold;
104
+ unsigned int max_new_synapses;
105
+ int conn_thr_i16;
106
+ int perm_inc_i16;
107
+ int perm_dec_i16;
108
+ int predicted_seg_dec_i16;
109
+ int initial_perm_i16;
110
+ // Loop constants
111
+ unsigned int T;
112
+ unsigned int learn;
113
+ unsigned int iter_seed;
114
+ unsigned int cooperative_grid_sync;
115
+ };
116
+
117
+ // Hardware cluster barrier using Hopper sm_90a cooperative_groups::this_cluster().sync().
118
+ // Replaces the former software Decoupled Look-Back (DLB) atomic-spin barrier.
119
+ //
120
+ // cluster::sync() is a single PTX instruction (barrier.cluster) that resolves
121
+ // in ~10-40 ns inside the cluster, with no device-level serialization.
122
+ // Multiple clusters (one per HTM region) run fully concurrently — bounded
123
+ // only by SM count (8 clusters × 16 SMs = 128 ≤ 132 on H200).
124
+ //
125
+ // The flags / expected / phase / cooperative_grid_sync parameters are kept
126
+ // in the signature for call-site compatibility but are unused.
127
+ __device__ static inline void fused_grid_barrier(cg::grid_group /* grid */,
128
+ unsigned int * /* flags — unused */,
129
+ unsigned int /* expected — unused */,
130
+ unsigned int /* phase — unused */,
131
+ unsigned int /* cooperative_grid_sync — unused */) {
132
+ auto cluster = cg::this_cluster();
133
+ cluster.sync();
134
+ }
135
+
136
+ __device__ static inline unsigned int warp_sum_u32(unsigned int v) {
137
+ for (int off = 16; off > 0; off >>= 1) {
138
+ v += __shfl_down_sync(0xffffffffu, v, off);
139
+ }
140
+ return v;
141
+ }
142
+
143
+ // Core kernel body — works for both single-region and batched launches.
144
+ // Single-region: caller passes the one FusedPtrs struct.
145
+ // Batched: each block reads its region's FusedPtrs via blockIdx.y before
146
+ // calling this. State is independent per region (each region owns its own
147
+ // GPU buffers); grid.sync() is the only cross-block primitive and it
148
+ // spans ALL blocks in the grid (harmless over-sync across regions).
149
+ __device__ static inline
150
+ void htm_fused_step_body(const FusedPtrs& P, const FusedConfig& cfg) {
151
+ cg::grid_group grid = cg::this_grid();
152
+ // Cast pointers.
153
+ const unsigned int * __restrict__ syn_bit = (const unsigned int*)P.syn_bit;
154
+ float * __restrict__ syn_perm = (float*)P.syn_perm;
155
+ float * __restrict__ boost = (float*)P.boost;
156
+ float * __restrict__ active_duty = (float*)P.active_duty;
157
+ float * __restrict__ inhibition_threshold = (float*)P.inhibition_threshold;
158
+ unsigned int * __restrict__ seg_cell_id = (unsigned int*)P.seg_cell_id;
159
+ unsigned int * __restrict__ seg_syn_count = (unsigned int*)P.seg_syn_count;
160
+ unsigned int * __restrict__ syn_presyn = (unsigned int*)P.syn_presyn;
161
+ short * __restrict__ tm_syn_perm = (short*)P.tm_syn_perm;
162
+ unsigned int * __restrict__ cell_seg_count = (unsigned int*)P.cell_seg_count;
163
+ unsigned int * __restrict__ cell_active_a = (unsigned int*)P.cell_active_a;
164
+ unsigned int * __restrict__ cell_active_b = (unsigned int*)P.cell_active_b;
165
+ unsigned int * __restrict__ cell_winner_a = (unsigned int*)P.cell_winner_a;
166
+ unsigned int * __restrict__ cell_winner_b = (unsigned int*)P.cell_winner_b;
167
+ const unsigned char * __restrict__ inputs = (const unsigned char*)P.inputs;
168
+ unsigned char * __restrict__ cols_out = (unsigned char*)P.cols_out;
169
+ float * __restrict__ anom_out = (float*)P.anom_out;
170
+ unsigned int * __restrict__ barrier_counters = (unsigned int*)P.barrier_counters;
171
+ unsigned int * __restrict__ step_scratch = (unsigned int*)P.step_scratch;
172
+
173
+ const unsigned int tid = threadIdx.x;
174
+ const unsigned int lane = tid & 31u;
175
+ const unsigned int warp = tid >> 5;
176
+ const unsigned int warps_per_block = blockDim.x >> 5;
177
+ const unsigned int gwarp = blockIdx.x * warps_per_block + warp;
178
+ const unsigned int n_warps = gridDim.x * warps_per_block;
179
+
180
+ const unsigned int n_cols = cfg.n_columns;
181
+ const unsigned int col_lo = (gwarp * n_cols) / n_warps;
182
+ const unsigned int col_hi = ((gwarp + 1) * n_cols) / n_warps;
183
+
184
+ unsigned int phase = 0u;
185
+
186
+ // =========================================================
187
+ // DSMEM: Cluster-distributed shared memory for hot per-column
188
+ // state (inhibition_threshold, boost, active_duty).
189
+ //
190
+ // Each block in the cluster owns a contiguous slice of
191
+ // [my_col_start, my_col_end) columns in its own __shared__
192
+ // arrays. Any block can peer-read another block's slice via
193
+ // cluster.map_shared_rank(ptr, owner_block_rank)[offset].
194
+ //
195
+ // This eliminates 2×n_cols×T GMEM reads per forward call
196
+ // (read + potential re-read of threshold/boost/duty per timestep).
197
+ // =========================================================
198
+ auto cluster = cg::this_cluster();
199
+ const unsigned int cluster_block_rank = cluster.block_rank(); // 0..cluster_size-1
200
+ const unsigned int cluster_sz = cluster.num_blocks(); // == gridDim.x (≤16)
201
+
202
+ // Partition n_cols evenly across cluster blocks.
203
+ // Each block owns cols_per_block columns starting at my_col_start.
204
+ const unsigned int cols_per_block =
205
+ (n_cols + cluster_sz - 1u) / cluster_sz; // ceil div
206
+ const unsigned int my_col_start =
207
+ cluster_block_rank * cols_per_block;
208
+ const unsigned int my_col_end =
209
+ (my_col_start + cols_per_block < n_cols)
210
+ ? (my_col_start + cols_per_block) : n_cols; // clamp
211
+
212
+ // Cluster-distributed shared memory arrays.
213
+ // Each block holds at most COLS_PER_CLUSTER_BLOCK_MAX floats per array.
214
+ // Peer blocks address into each other's smem via map_shared_rank.
215
+ __shared__ float s_inhib_thr [COLS_PER_CLUSTER_BLOCK_MAX];
216
+ __shared__ float s_boost [COLS_PER_CLUSTER_BLOCK_MAX];
217
+ __shared__ float s_active_duty[COLS_PER_CLUSTER_BLOCK_MAX];
218
+
219
+ // TMA multicast input staging tile (T9).
220
+ //
221
+ // On Hopper (sm_90a), cg::memcpy_async with cluster scope issues a single
222
+ // TMA DMA that multicasts the source data to all 16 SMs in the cluster
223
+ // simultaneously — replacing ~16 per-block GMEM reads per timestep with a
224
+ // single hardware DMA. After cg::wait(cluster) every SM's s_input_tile
225
+ // is populated identically without any additional DRAM traffic.
226
+ //
227
+ // Fallback: when cfg.input_bits > INPUT_BITS_MAX the tile is bypassed
228
+ // and each thread reads directly from GMEM (original path).
229
+ //
230
+ // Alignment: 16-byte aligned to satisfy TMA descriptor requirements.
231
+ __shared__ __align__(16) unsigned char s_input_tile[INPUT_BITS_MAX];
232
+
233
+ // Initial GMEM → smem load (reads state from previous forward call).
234
+ // Each block loads only its own slice; tid strides across the slice.
235
+ for (unsigned int c = my_col_start + tid; c < my_col_end; c += blockDim.x) {
236
+ const unsigned int off = c - my_col_start;
237
+ s_inhib_thr [off] = inhibition_threshold[c];
238
+ s_boost [off] = boost[c];
239
+ s_active_duty[off] = active_duty[c];
240
+ }
241
+
242
+ // All blocks in the cluster must finish loading before any block
243
+ // starts reading peer smem inside the T-loop.
244
+ cluster.sync();
245
+
246
+ const unsigned int S = cfg.synapses_per_col;
247
+ const unsigned int cpc = cfg.cells_per_column;
248
+ const unsigned int SPS = cfg.synapses_per_segment;
249
+ const unsigned int MSC = cfg.max_segments_per_cell;
250
+
251
+ // Main timestep loop.
252
+ for (unsigned int t = 0u; t < cfg.T; t++) {
253
+ const unsigned int inp_off = t * cfg.input_bits;
254
+ const unsigned int col_base_out = t * n_cols;
255
+
256
+ unsigned int * curr_active = (t & 1u) ? cell_active_b : cell_active_a;
257
+ unsigned int * prev_active = (t & 1u) ? cell_active_a : cell_active_b;
258
+ unsigned int * curr_winner = (t & 1u) ? cell_winner_b : cell_winner_a;
259
+ unsigned int * prev_winner = (t & 1u) ? cell_winner_a : cell_winner_b;
260
+
261
+ // ---- Phase 0: clear curr bitsets for my cell range ----
262
+ const unsigned int my_cell_lo = col_lo * cpc;
263
+ const unsigned int my_cell_hi = col_hi * cpc;
264
+ if (cpc == 32u) {
265
+ // Fast path: one word per column.
266
+ for (unsigned int c = col_lo + lane; c < col_hi; c += 32u) {
267
+ curr_active[c] = 0u;
268
+ curr_winner[c] = 0u;
269
+ }
270
+ } else {
271
+ for (unsigned int cell = my_cell_lo + lane; cell < my_cell_hi; cell += 32u) {
272
+ unsigned int w = cell >> 5;
273
+ unsigned int m = 1u << (cell & 31u);
274
+ atomicAnd(&curr_active[w], ~m);
275
+ atomicAnd(&curr_winner[w], ~m);
276
+ }
277
+ }
278
+
279
+ // Block 0, lane 0, warp 0 resets step-scratch counters.
280
+ if (blockIdx.x == 0u && tid == 0u) {
281
+ step_scratch[0] = 0u;
282
+ step_scratch[1] = 0u;
283
+ }
284
+
285
+ // ---- BARRIER 1 ----
286
+ // Fence: make the above clear-bitsets + scratch writes globally
287
+ // visible before peer blocks observe "barrier arrived".
288
+ __threadfence();
289
+ fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
290
+
291
+ // =========================================================
292
+ // T9: TMA MULTICAST INPUT STAGING
293
+ //
294
+ // Issue a single cluster-scope async DMA to broadcast this
295
+ // timestep's input slice into s_input_tile across all 16 SMs
296
+ // in the cluster simultaneously. On Hopper sm_90a,
297
+ // cg::memcpy_async with cluster scope maps to the TMA
298
+ // hardware unit (cp.async.bulk.tensor multicast), reducing
299
+ // DRAM input traffic by ~16× vs each block fetching its own
300
+ // copy from GMEM.
301
+ //
302
+ // The staging is gated on cfg.input_bits <= INPUT_BITS_MAX.
303
+ // If the tile is too small (custom large input_bits), we fall
304
+ // back to per-thread GMEM reads in Stage A (identical to the
305
+ // original path; use_input_tile==false).
306
+ //
307
+ // Ordering: BARRIER 1 completes before we issue the DMA.
308
+ // The DMA completes before Stage A reads s_input_tile.
309
+ // =========================================================
310
+ const bool use_input_tile = (cfg.input_bits <= INPUT_BITS_MAX);
311
+ if (use_input_tile) {
312
+ // Thread-block scope async copy: each SM independently loads
313
+ // its own input tile from GMEM into shared memory.
314
+ //
315
+ // NOTE: CUDA 12.1's cooperative_groups::memcpy_async() rejects
316
+ // cluster_group at compile time (static_assert in async.h:171).
317
+ // True TMA multicast (single DMA for all 16 SMs in the cluster)
318
+ // would require raw PTX cp.async.bulk.tensor with multicast mode,
319
+ // which needs cuTensorMap descriptors on the host side (T11).
320
+ //
321
+ // This per-SM path still gives a meaningful win: it converts
322
+ // the original per-synapse scattered GMEM reads (random access
323
+ // pattern hitting multiple cache lines) into one sequential DMA
324
+ // per SM, improving L2 hit rate and hardware prefetcher
325
+ // effectiveness. The cluster.sync() below ensures all SMs in
326
+ // the cluster have finished loading before any SM enters Stage A.
327
+ auto tb = cg::this_thread_block();
328
+ cg::memcpy_async(tb, s_input_tile,
329
+ inputs + inp_off,
330
+ cfg.input_bits);
331
+ cg::wait(tb);
332
+ // Cluster barrier: all 16 SMs must have loaded their tile
333
+ // before any SM begins reading s_input_tile in Stage A.
334
+ cluster.sync();
335
+ }
336
+
337
+ // =========================================================
338
+ // STAGE A: Spatial Pooler
339
+ //
340
+ // Hot per-column state (boost, inhibition_threshold,
341
+ // active_duty) is served from cluster DSMEM rather than
342
+ // GMEM for each of the T timesteps. GMEM is written on
343
+ // update so state persists across forward calls.
344
+ // =========================================================
345
+ for (unsigned int c = col_lo; c < col_hi; c++) {
346
+ unsigned int base = c * S;
347
+ unsigned int local = 0u;
348
+ for (unsigned int s = lane; s < S; s += 32u) {
349
+ unsigned int b = syn_bit[base + s];
350
+ float p = syn_perm[base + s];
351
+ // T9: read from cluster-broadcast tile when available;
352
+ // fall back to direct GMEM when input_bits > INPUT_BITS_MAX.
353
+ unsigned int inp_byte = use_input_tile
354
+ ? (unsigned int)s_input_tile[b]
355
+ : (unsigned int)inputs[inp_off + b];
356
+ unsigned int hit = ((inp_byte != 0u) && (p >= cfg.conn_thr)) ? 1u : 0u;
357
+ local += hit;
358
+ }
359
+ unsigned int overlap = warp_sum_u32(local);
360
+ overlap = __shfl_sync(0xffffffffu, overlap, 0);
361
+
362
+ // Determine which cluster block owns column c and read
363
+ // boost + threshold from that block's shared memory.
364
+ const unsigned int owner_block = c / cols_per_block;
365
+ const unsigned int owner_offset = c - owner_block * cols_per_block;
366
+
367
+ float boost_val = cluster.map_shared_rank(s_boost, owner_block)[owner_offset];
368
+ float thr = cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset];
369
+
370
+ float boosted = (float)overlap * boost_val;
371
+ unsigned int is_active = (boosted > thr) ? 1u : 0u;
372
+
373
+ if (lane == 0) {
374
+ cols_out[col_base_out + c] = (unsigned char)is_active;
375
+ if (is_active) {
376
+ atomicAdd(&step_scratch[0], 1u);
377
+ }
378
+ }
379
+
380
+ // SP learn (Hebbian) on active columns.
381
+ // T9: use tile for input reads here too.
382
+ if (cfg.learn && is_active) {
383
+ for (unsigned int s = lane; s < S; s += 32u) {
384
+ unsigned int b = syn_bit[base + s];
385
+ float p = syn_perm[base + s];
386
+ unsigned int inp_byte = use_input_tile
387
+ ? (unsigned int)s_input_tile[b]
388
+ : (unsigned int)inputs[inp_off + b];
389
+ if (inp_byte != 0u) {
390
+ p += cfg.sp_inc;
391
+ if (p > 1.0f) p = 1.0f;
392
+ } else {
393
+ p -= cfg.sp_dec;
394
+ if (p < 0.0f) p = 0.0f;
395
+ }
396
+ syn_perm[base + s] = p;
397
+ }
398
+ }
399
+
400
+ // active_duty EMA + threshold adaptation.
401
+ // Writes go to both peer DSMEM (hot path for next timestep)
402
+ // and GMEM (persistence across forward calls).
403
+ if (lane == 0) {
404
+ float ad = cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset];
405
+ float sample = is_active ? 1.0f : 0.0f;
406
+ ad = (1.0f - cfg.duty_alpha) * ad + cfg.duty_alpha * sample;
407
+
408
+ // Writeback: peer smem (for next timestep read) + GMEM (persistence).
409
+ cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] = ad;
410
+ active_duty[c] = ad;
411
+
412
+ // Threshold steers toward target sparsity.
413
+ float err = ad - cfg.sparsity_target;
414
+ float new_thr = thr + cfg.thr_adapt_rate * err * 100.0f;
415
+ if (new_thr < 0.1f) new_thr = 0.1f;
416
+ if (new_thr > 1000.0f) new_thr = 1000.0f;
417
+
418
+ // Writeback: peer smem (for next timestep read) + GMEM (persistence).
419
+ cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] = new_thr;
420
+ inhibition_threshold[c] = new_thr;
421
+ }
422
+ }
423
+
424
+ // ---- DSMEM WRITEBACK SYNC: peer-smem writes must be visible cluster-wide ----
425
+ //
426
+ // DATA FLOW PROOF (T-loop iteration invariant):
427
+ //
428
+ // WRITE SITES (lane==0 inside Stage A per-col loop):
429
+ // Line 328: cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] = ad
430
+ // Line 338: cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] = new_thr
431
+ //
432
+ // READ SITES (Stage A of the NEXT timestep t+1):
433
+ // Line 290: cluster.map_shared_rank(s_boost, owner_block)[owner_offset] (read)
434
+ // Line 291: cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] (read)
435
+ // Line 323: cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] (read)
436
+ //
437
+ // PARTITION MISMATCH (root cause of T8 staleness):
438
+ // cols_per_block = ceil(n_cols / cluster_sz) [smem partition]
439
+ // col_lo/col_hi = floor(gwarp*n_cols/n_warps) [gwarp work partition]
440
+ // These are NOT identical — up to 1 column can spill across partition boundaries.
441
+ // Example: n_cols=1000, cluster_sz=16 → cols_per_block=63, block 1 col_lo=62
442
+ // → block 1 processes column 62 but column 62 belongs to block 0's smem slice.
443
+ // → block 1 issues a PEER WRITE to block 0's s_inhib_thr / s_active_duty.
444
+ //
445
+ // RACE WITHOUT SYNC:
446
+ // Blocks run Stage A concurrently. Block 1 writes block 0's smem at column 62.
447
+ // Block 0 may simultaneously READ s_inhib_thr[62] for its own column 62 in
448
+ // Stage A of the same timestep → concurrent peer write + local read → undefined.
449
+ // Additionally, without cluster.sync() after all peer writes complete, block 0's
450
+ // t+1 Stage A reads might observe t-1 values still cached in its smem.
451
+ //
452
+ // FIX: cluster.sync() here, AFTER Stage A's per-column loop, ensures:
453
+ // 1. All peer smem writes from this timestep are globally visible to all blocks.
454
+ // 2. No block can enter Stage B (or start t+1 Stage A) with stale smem values.
455
+ // 3. GMEM writes (lines 329, 339) are already committed to L2; __threadfence()
456
+ // below ensures they are visible to all SMs before the cluster barrier.
457
+ //
458
+ // ORDERING: write → cluster.sync() here → __threadfence() → cluster.sync() in
459
+ // fused_grid_barrier → next-timestep reads. Both visibility guarantees
460
+ // are now satisfied.
461
+ cluster.sync();
462
+
463
+ // ---- BARRIER 2: SP active_mask must be visible before TM reads ----
464
+ // Fence: flush cols_out + active_duty + inhibition_threshold + step_scratch
465
+ // writes to global memory before peers advance past this barrier.
466
+ __threadfence();
467
+ fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
468
+
469
+ // =========================================================
470
+ // STAGE B: Temporal Memory
471
+ // =========================================================
472
+ for (unsigned int c = col_lo; c < col_hi; c++) {
473
+ unsigned int col_active = cols_out[col_base_out + c];
474
+ if (col_active == 0u) continue;
475
+
476
+ unsigned int base_cell = c * cpc;
477
+ unsigned int any_predicted = 0u;
478
+ unsigned int best_seg_id_for_grow = 0xFFFFFFFFu;
479
+ unsigned int best_pot_count = 0u;
480
+
481
+ for (unsigned int k = 0u; k < cpc; k++) {
482
+ unsigned int cell = base_cell + k;
483
+ unsigned int n_segs_here = cell_seg_count[cell];
484
+ if (n_segs_here > MSC) n_segs_here = MSC;
485
+ if (n_segs_here == 0u) continue;
486
+
487
+ unsigned int seg_base_id = cell * MSC;
488
+ unsigned int cell_is_predictive = 0u;
489
+
490
+ for (unsigned int ls = 0u; ls < n_segs_here; ls++) {
491
+ unsigned int seg = seg_base_id + ls;
492
+ unsigned int n_syn = seg_syn_count[seg];
493
+ if (n_syn == 0u) continue;
494
+ unsigned int syn_base = seg * SPS;
495
+
496
+ unsigned int l_conn = 0u;
497
+ unsigned int l_pot = 0u;
498
+ for (unsigned int s = lane; s < n_syn; s += 32u) {
499
+ unsigned int presyn = syn_presyn[syn_base + s];
500
+ unsigned int w = prev_active[presyn >> 5];
501
+ unsigned int bit = (w >> (presyn & 31u)) & 1u;
502
+ if (bit) {
503
+ l_pot += 1u;
504
+ int p = (int)tm_syn_perm[syn_base + s];
505
+ if (p >= cfg.conn_thr_i16) l_conn += 1u;
506
+ }
507
+ }
508
+ unsigned int tot_conn = warp_sum_u32(l_conn);
509
+ unsigned int tot_pot = warp_sum_u32(l_pot);
510
+ tot_conn = __shfl_sync(0xffffffffu, tot_conn, 0);
511
+ tot_pot = __shfl_sync(0xffffffffu, tot_pot, 0);
512
+
513
+ if (tot_conn >= cfg.activation_threshold) cell_is_predictive = 1u;
514
+ if (tot_pot >= cfg.learning_threshold && tot_pot > best_pot_count) {
515
+ best_pot_count = tot_pot;
516
+ best_seg_id_for_grow = seg;
517
+ }
518
+
519
+ // Reinforce predicted-and-correct segment.
520
+ if (cfg.learn && tot_conn >= cfg.activation_threshold) {
521
+ for (unsigned int s = lane; s < n_syn; s += 32u) {
522
+ unsigned int presyn = syn_presyn[syn_base + s];
523
+ unsigned int w = prev_active[presyn >> 5];
524
+ unsigned int bit = (w >> (presyn & 31u)) & 1u;
525
+ int p = (int)tm_syn_perm[syn_base + s];
526
+ if (bit) {
527
+ int np = p + cfg.perm_inc_i16;
528
+ if (np > 32767) np = 32767;
529
+ tm_syn_perm[syn_base + s] = (short)np;
530
+ } else {
531
+ int np = p - cfg.perm_dec_i16;
532
+ if (np < 0) np = 0;
533
+ tm_syn_perm[syn_base + s] = (short)np;
534
+ }
535
+ }
536
+ }
537
+ }
538
+
539
+ if (cell_is_predictive) {
540
+ any_predicted = 1u;
541
+ if (lane == 0) {
542
+ unsigned int w = cell >> 5;
543
+ unsigned int m = 1u << (cell & 31u);
544
+ atomicOr(&curr_active[w], m);
545
+ atomicOr(&curr_winner[w], m);
546
+ }
547
+ }
548
+ }
549
+
550
+ // BURST if no predicted.
551
+ if (!any_predicted) {
552
+ if (lane == 0) {
553
+ for (unsigned int k = 0u; k < cpc; k++) {
554
+ unsigned int cell = base_cell + k;
555
+ unsigned int w = cell >> 5;
556
+ unsigned int m = 1u << (cell & 31u);
557
+ atomicOr(&curr_active[w], m);
558
+ }
559
+ unsigned int win = base_cell;
560
+ unsigned int ww = win >> 5;
561
+ unsigned int wm = 1u << (win & 31u);
562
+ atomicOr(&curr_winner[ww], wm);
563
+ atomicAdd(&step_scratch[1], 1u);
564
+ }
565
+
566
+ if (cfg.learn) {
567
+ unsigned int target_seg;
568
+ unsigned int existing_syn;
569
+ if (best_seg_id_for_grow != 0xFFFFFFFFu) {
570
+ // Reuse best matching segment.
571
+ target_seg = best_seg_id_for_grow;
572
+ existing_syn = seg_syn_count[target_seg];
573
+ target_seg = __shfl_sync(0xffffffffu, target_seg, 0);
574
+ existing_syn = __shfl_sync(0xffffffffu, existing_syn, 0);
575
+
576
+ // Reinforce its existing synapses.
577
+ unsigned int syn_base = target_seg * SPS;
578
+ for (unsigned int s = lane; s < existing_syn; s += 32u) {
579
+ unsigned int presyn = syn_presyn[syn_base + s];
580
+ unsigned int w = prev_active[presyn >> 5];
581
+ unsigned int bit = (w >> (presyn & 31u)) & 1u;
582
+ int p = (int)tm_syn_perm[syn_base + s];
583
+ if (bit) {
584
+ int np = p + cfg.perm_inc_i16;
585
+ if (np > 32767) np = 32767;
586
+ tm_syn_perm[syn_base + s] = (short)np;
587
+ } else {
588
+ int np = p - cfg.perm_dec_i16;
589
+ if (np < 0) np = 0;
590
+ tm_syn_perm[syn_base + s] = (short)np;
591
+ }
592
+ }
593
+ } else {
594
+ // Allocate new segment on winner cell (cell 0 of col).
595
+ unsigned int new_seg = 0u;
596
+ if (lane == 0) {
597
+ unsigned int winner_cell = base_cell;
598
+ unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
599
+ if (slot >= MSC) slot = slot % MSC;
600
+ new_seg = winner_cell * MSC + slot;
601
+ seg_cell_id[new_seg] = winner_cell;
602
+ seg_syn_count[new_seg] = 0u;
603
+ }
604
+ target_seg = __shfl_sync(0xffffffffu, new_seg, 0);
605
+ existing_syn = 0u;
606
+ }
607
+
608
+ // Grow synapses to prev_winner cells — lane 0 serialized.
609
+ unsigned int room = (SPS > existing_syn) ? (SPS - existing_syn) : 0u;
610
+ unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;
611
+ if (lane == 0 && max_grow > 0u) {
612
+ unsigned int syn_base = target_seg * SPS;
613
+ unsigned int grown = 0u;
614
+ unsigned int start_off = (c * 2654435761u + cfg.iter_seed + t) % cfg.bits_words;
615
+ for (unsigned int w_off = 0u;
616
+ w_off < cfg.bits_words && grown < max_grow;
617
+ w_off++) {
618
+ unsigned int widx = (start_off + w_off) % cfg.bits_words;
619
+ unsigned int word = prev_winner[widx];
620
+ while (word != 0u && grown < max_grow) {
621
+ unsigned int bit_pos = __ffs(word) - 1u;
622
+ word &= ~(1u << bit_pos);
623
+ unsigned int cell_id = widx * 32u + bit_pos;
624
+ if (cell_id >= cfg.n_cells) continue;
625
+ bool exists = false;
626
+ for (unsigned int es = 0u; es < existing_syn + grown; es++) {
627
+ if (syn_presyn[syn_base + es] == cell_id) { exists = true; break; }
628
+ }
629
+ if (exists) continue;
630
+ unsigned int write_idx = existing_syn + grown;
631
+ if (write_idx >= SPS) break;
632
+ syn_presyn[syn_base + write_idx] = cell_id;
633
+ tm_syn_perm[syn_base + write_idx] = (short)cfg.initial_perm_i16;
634
+ grown++;
635
+ }
636
+ }
637
+ if (grown > 0u) {
638
+ seg_syn_count[target_seg] = existing_syn + grown;
639
+ }
640
+ }
641
+ }
642
+ }
643
+ }
644
+
645
+ // ---- BARRIER 3: TM writes complete before anomaly + next-step read ----
646
+ // Fence: flush curr_active/curr_winner bitsets + tm_syn_perm +
647
+ // seg_syn_count + syn_presyn before peers advance and consume them as
648
+ // prev_active/prev_winner at t+1.
649
+ __threadfence();
650
+ fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
651
+
652
+ // Write anomaly for step t.
653
+ if (blockIdx.x == 0u && tid == 0u) {
654
+ unsigned int total = step_scratch[0];
655
+ unsigned int bad = step_scratch[1];
656
+ float anom = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
657
+ anom_out[t] = anom;
658
+ }
659
+ }
660
+ }
661
+
662
+ // Single-region kernel (legacy call site).
663
+ __global__
664
+ void htm_fused_step(FusedPtrs P, FusedConfig cfg) {
665
+ htm_fused_step_body(P, cfg);
666
+ }
667
+
668
+ // Batched kernel: one cooperative launch for B regions. grid.y = B,
669
+ // grid.x = per-region block count. Each block reads its region's
670
+ // FusedPtrs from the device array via blockIdx.y.
671
+ __global__
672
+ void htm_fused_step_batched(const FusedPtrs* __restrict__ P_arr, FusedConfig cfg) {
673
+ const FusedPtrs P = P_arr[blockIdx.y];
674
+ htm_fused_step_body(P, cfg);
675
+ }
676
+
677
+ } // extern "C"
overlay/htm_rust/src/gpu/kernels/sp_boost_fused.cu ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Fused mean-reduction + boost-update kernel.
2
+ //
3
+ // Inputs:
4
+ // active_duty[n] (f32)
5
+ // boost_strength (f32)
6
+ //
7
+ // Output:
8
+ // boost[n] (f32) = expf(-boost_strength * (active_duty[c] - mean))
9
+ //
10
+ // Launch: single block (1024 threads), shared mem for reduction. At n=2048
11
+ // each thread handles 2 elements.
12
+
13
+ extern "C" __global__
14
+ void sp_boost_from_duty(
15
+ const float * __restrict__ active_duty, // (n,)
16
+ float * __restrict__ boost, // (n,) in-place out
17
+ float boost_strength,
18
+ unsigned int n
19
+ ) {
20
+ extern __shared__ float smem_raw[];
21
+ float * smem = smem_raw;
22
+ const unsigned int tid = threadIdx.x;
23
+ const unsigned int bsz = blockDim.x;
24
+
25
+ // Phase 1: parallel sum of active_duty into smem[0..32] (warp-level).
26
+ float local_sum = 0.0f;
27
+ for (unsigned int i = tid; i < n; i += bsz) {
28
+ local_sum += active_duty[i];
29
+ }
30
+ // Warp reduction.
31
+ for (int off = 16; off > 0; off >>= 1) {
32
+ local_sum += __shfl_down_sync(0xffffffff, local_sum, off);
33
+ }
34
+ unsigned int lane = tid & 31;
35
+ unsigned int warp = tid >> 5;
36
+ if (lane == 0) smem[warp] = local_sum;
37
+ __syncthreads();
38
+
39
+ // Warp 0 reduces warp-sums.
40
+ __shared__ float mean_s;
41
+ if (warp == 0) {
42
+ unsigned int nwarps = (bsz + 31) / 32;
43
+ float v = (lane < nwarps) ? smem[lane] : 0.0f;
44
+ for (int off = 16; off > 0; off >>= 1) {
45
+ v += __shfl_down_sync(0xffffffff, v, off);
46
+ }
47
+ if (tid == 0) {
48
+ mean_s = v / (float)n;
49
+ }
50
+ }
51
+ __syncthreads();
52
+
53
+ // Phase 2: boost[c] = expf(-strength * (active_duty[c] - mean)).
54
+ float mean = mean_s;
55
+ for (unsigned int i = tid; i < n; i += bsz) {
56
+ float d = active_duty[i] - mean;
57
+ boost[i] = expf(-boost_strength * d);
58
+ }
59
+ }
overlay/htm_rust/src/gpu/kernels/sp_duty.cu ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Duty cycle + boost update kernel.
2
+ //
3
+ // For each column c (one thread each):
4
+ // active_sample = active_mask[c] ? 1 : 0
5
+ // overlap_sample = raw_overlap[c] >= stim_thr ? 1 : 0
6
+ // active_duty[c] = (1-alpha) * active_duty[c] + alpha * active_sample
7
+ // overlap_duty[c] = (1-alpha) * overlap_duty[c] + alpha * overlap_sample
8
+ //
9
+ // Then, if learn:
10
+ // boost[c] = exp(-boost_strength * (active_duty[c] - mean_duty))
11
+ // mean_duty is computed on the host (one reduction) and passed in.
12
+
13
+ extern "C" __global__
14
+ void sp_duty_update(
15
+ const unsigned char * __restrict__ active_mask, // (n_columns,)
16
+ const unsigned int * __restrict__ raw_overlap, // (n_columns,)
17
+ float * __restrict__ active_duty, // (n_columns,) in-place
18
+ float * __restrict__ overlap_duty, // (n_columns,) in-place
19
+ float * __restrict__ boost, // (n_columns,) in-place
20
+ float alpha,
21
+ float stim_thr,
22
+ float boost_strength, // 0 to skip boost
23
+ float mean_duty,
24
+ unsigned int learn_flag, // 0 or 1
25
+ unsigned int n_columns
26
+ ) {
27
+ unsigned int c = blockIdx.x * blockDim.x + threadIdx.x;
28
+ if (c >= n_columns) return;
29
+
30
+ float ad = active_duty[c];
31
+ float od = overlap_duty[c];
32
+
33
+ float a_sample = (active_mask[c] != 0) ? 1.0f : 0.0f;
34
+ float o_sample = ((float)raw_overlap[c] >= stim_thr) ? 1.0f : 0.0f;
35
+
36
+ ad = (1.0f - alpha) * ad + alpha * a_sample;
37
+ od = (1.0f - alpha) * od + alpha * o_sample;
38
+
39
+ active_duty[c] = ad;
40
+ overlap_duty[c] = od;
41
+
42
+ if (learn_flag && boost_strength > 0.0f) {
43
+ boost[c] = expf(-boost_strength * (ad - mean_duty));
44
+ }
45
+ }
overlay/htm_rust/src/gpu/kernels/sp_learn.cu ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // SP Hebbian learning kernel.
2
+ //
3
+ // For each active (winner) column c, for each of its synapses s:
4
+ // if input[bit[c][s]] active: perm += inc
5
+ // else: perm -= dec
6
+ // Clamp to [0, 1].
7
+ //
8
+ // Launch: one block per column (2048 blocks), but we predicate on
9
+ // active_mask[c] to avoid launching k-specific blocks.
10
+ //
11
+ // This matches the CPU reference line-for-line:
12
+ // src/sp.rs lines 157-169.
13
+
14
+ extern "C" __global__
15
+ void sp_learn(
16
+ const unsigned char * __restrict__ active_mask, // (n_columns,) 0/1
17
+ const unsigned char * __restrict__ inp, // (input_bits,)
18
+ const unsigned int * __restrict__ syn_bit, // (n_columns * S,)
19
+ float * __restrict__ syn_perm, // (n_columns * S,) in-place
20
+ float inc,
21
+ float dec,
22
+ unsigned int synapses_per_col,
23
+ unsigned int n_columns
24
+ ) {
25
+ const unsigned int c = blockIdx.x;
26
+ if (c >= n_columns) return;
27
+ if (active_mask[c] == 0) return;
28
+
29
+ const unsigned int base = c * synapses_per_col;
30
+ const unsigned int tid = threadIdx.x;
31
+ const unsigned int bsz = blockDim.x;
32
+
33
+ for (unsigned int s = tid; s < synapses_per_col; s += bsz) {
34
+ unsigned int b = syn_bit[base + s];
35
+ float p = syn_perm[base + s];
36
+ if (inp[b] != 0) {
37
+ p += inc;
38
+ if (p > 1.0f) p = 1.0f;
39
+ } else {
40
+ p -= dec;
41
+ if (p < 0.0f) p = 0.0f;
42
+ }
43
+ syn_perm[base + s] = p;
44
+ }
45
+ }
overlay/htm_rust/src/gpu/kernels/sp_overlap.cu ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // SP overlap kernel.
2
+ //
3
+ // For each column c (one CUDA block), compute:
4
+ // overlap[c] = sum over its synapse list of {inp[bit[c][s]] && perm[c][s] >= conn_thr}
5
+ // boosted[c] = overlap[c] * boost[c]
6
+ // raw_overlap[c] = overlap[c] (also returned so host can drive duty cycle)
7
+ //
8
+ // Memory layout (flat, column-major with per-column stride = synapses_per_col):
9
+ // syn_bit[c * S + s] : u32 index into input SDR
10
+ // syn_perm[c * S + s] : f32 permanence in [0, 1]
11
+ // boost[c] : f32
12
+ // inp[b] : u8 0/1
13
+ // Output:
14
+ // raw[c] : u32
15
+ // boosted[c] : f32
16
+ //
17
+ // Launch:
18
+ // grid = n_columns
19
+ // block = 128 (or 256) — one warp-sweep across synapses; many warps give
20
+ // parallel reduction across S (typically S=40).
21
+ //
22
+ // At S=40 this is completely latency-bound; we coalesce reads and do a
23
+ // warp-shuffle reduction. For clarity we use a simple block-wide shared-mem
24
+ // reduction which is sufficient for S <= 1024 and has zero correctness risk.
25
+
26
+ extern "C" __global__
27
+ void sp_overlap(
28
+ const unsigned char * __restrict__ inp, // (input_bits,)
29
+ const unsigned int * __restrict__ syn_bit, // (n_columns * S,)
30
+ const float * __restrict__ syn_perm,// (n_columns * S,)
31
+ const float * __restrict__ boost, // (n_columns,)
32
+ float conn_thr,
33
+ unsigned int synapses_per_col, // S
34
+ unsigned int n_columns,
35
+ unsigned int * __restrict__ raw_out, // (n_columns,)
36
+ float * __restrict__ boosted_out // (n_columns,)
37
+ ) {
38
+ const unsigned int c = blockIdx.x;
39
+ if (c >= n_columns) return;
40
+
41
+ const unsigned int base = c * synapses_per_col;
42
+ const unsigned int tid = threadIdx.x;
43
+ const unsigned int bsz = blockDim.x;
44
+
45
+ // Per-thread partial count.
46
+ unsigned int local = 0;
47
+ for (unsigned int s = tid; s < synapses_per_col; s += bsz) {
48
+ unsigned int b = syn_bit[base + s];
49
+ float p = syn_perm[base + s];
50
+ // Branchless: only counts when input active AND perm connected.
51
+ // Using (inp != 0) to tolerate u8 layout.
52
+ unsigned int hit = ((inp[b] != 0) && (p >= conn_thr)) ? 1u : 0u;
53
+ local += hit;
54
+ }
55
+
56
+ // Block-wide reduction in shared memory.
57
+ __shared__ unsigned int smem[32];
58
+
59
+ // Warp-level reduction via shuffle.
60
+ unsigned int lane = tid & 31;
61
+ unsigned int warp = tid >> 5;
62
+ for (int off = 16; off > 0; off >>= 1) {
63
+ local += __shfl_down_sync(0xffffffff, local, off);
64
+ }
65
+ if (lane == 0) smem[warp] = local;
66
+ __syncthreads();
67
+
68
+ if (warp == 0) {
69
+ unsigned int v = (tid < (bsz + 31) / 32) ? smem[lane] : 0;
70
+ for (int off = 16; off > 0; off >>= 1) {
71
+ v += __shfl_down_sync(0xffffffff, v, off);
72
+ }
73
+ if (tid == 0) {
74
+ raw_out[c] = v;
75
+ boosted_out[c] = (float)v * boost[c];
76
+ }
77
+ }
78
+ }
overlay/htm_rust/src/gpu/kernels/sp_topk.cu ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Top-K column selection.
2
+ //
3
+ // Inputs:
4
+ // boosted[n_columns] : f32 score
5
+ // Output:
6
+ // active_mask[n_columns] : u8 0/1, exactly k ones
7
+ //
8
+ // Tie-breaking: when scores are equal, the LOWER column index wins (matches
9
+ // CPU reference `select_nth_unstable_by` with secondary index comparator).
10
+ //
11
+ // Strategy: a single-block implementation. n_columns is typically 2048, which
12
+ // fits comfortably in shared memory. We use a bitonic top-k via per-thread
13
+ // radix-select of the (score, -index) key. At k≈41 of n=2048 the simplest
14
+ // correct approach is a thresholding pass:
15
+ //
16
+ // 1. Radix-like bucket pass to find the k-th largest score.
17
+ // 2. Mark winners = strictly-greater-than-threshold AND ties until count hits k.
18
+ //
19
+ // For strict index-ordered tie-break we materialise a 64-bit key:
20
+ // key = (float_to_sortable_u32(score) << 32) | (0xffffffff - index)
21
+ // Larger key = (higher score) OR (same score, smaller index).
22
+ //
23
+ // Then we find the k-th largest 64-bit key via radix-select and mark all
24
+ // columns whose key >= threshold. This is O(n_cols * log k) and well under
25
+ // 100 μs for n=2048, k=41 on sm_86.
26
+ //
27
+ // For simplicity and correctness this kernel uses a single-block parallel
28
+ // selection sort variant (find max → mark → zero → repeat, k iterations).
29
+ // At k=41 this is 41 passes of 2048 threads = ~2048*41 = 84K ops, trivially
30
+ // fast.
31
+
32
+ extern "C" __global__
33
+ void sp_topk_select(
34
+ const float * __restrict__ scores, // (n_columns,)
35
+ unsigned int n_columns,
36
+ unsigned int k,
37
+ unsigned char * __restrict__ active_out // (n_columns,)
38
+ ) {
39
+ extern __shared__ float smem[];
40
+ // Layout: smem[0..n] = working scores (we'll mark selected entries as -inf)
41
+ // smem[n..n+32*2] = reduction scratch (score + index, per warp)
42
+ float * work = smem;
43
+ const unsigned int tid = threadIdx.x;
44
+ const unsigned int bsz = blockDim.x;
45
+
46
+ // Load scores into shared; also init active_out = 0.
47
+ for (unsigned int i = tid; i < n_columns; i += bsz) {
48
+ work[i] = scores[i];
49
+ active_out[i] = 0;
50
+ }
51
+ __syncthreads();
52
+
53
+ __shared__ int winner_idx;
54
+ __shared__ float winner_score;
55
+
56
+ for (unsigned int iter = 0; iter < k; ++iter) {
57
+ // Find (argmax score, lowest index for ties).
58
+ float best_s = -INFINITY;
59
+ int best_i = n_columns; // sentinel larger than any index
60
+
61
+ for (unsigned int i = tid; i < n_columns; i += bsz) {
62
+ float s = work[i];
63
+ if (s > best_s || (s == best_s && (int)i < best_i)) {
64
+ best_s = s;
65
+ best_i = (int)i;
66
+ }
67
+ }
68
+
69
+ // Warp reduction. We reduce pairs (score, idx) keeping (max score, min idx on tie).
70
+ unsigned int mask = 0xffffffff;
71
+ for (int off = 16; off > 0; off >>= 1) {
72
+ float os = __shfl_down_sync(mask, best_s, off);
73
+ int oi = __shfl_down_sync(mask, best_i, off);
74
+ if (os > best_s || (os == best_s && oi < best_i)) {
75
+ best_s = os;
76
+ best_i = oi;
77
+ }
78
+ }
79
+ // Warp 0 collects lane 0 values from other warps via shared mem.
80
+ __shared__ float warp_s[32];
81
+ __shared__ int warp_i[32];
82
+ unsigned int lane = tid & 31;
83
+ unsigned int warp = tid >> 5;
84
+ if (lane == 0) {
85
+ warp_s[warp] = best_s;
86
+ warp_i[warp] = best_i;
87
+ }
88
+ __syncthreads();
89
+
90
+ if (warp == 0) {
91
+ unsigned int nwarps = (bsz + 31) / 32;
92
+ float s = (lane < nwarps) ? warp_s[lane] : -INFINITY;
93
+ int i = (lane < nwarps) ? warp_i[lane] : (int)n_columns;
94
+ for (int off = 16; off > 0; off >>= 1) {
95
+ float os = __shfl_down_sync(mask, s, off);
96
+ int oi = __shfl_down_sync(mask, i, off);
97
+ if (os > s || (os == s && oi < i)) {
98
+ s = os;
99
+ i = oi;
100
+ }
101
+ }
102
+ if (tid == 0) {
103
+ winner_score = s;
104
+ winner_idx = i;
105
+ }
106
+ }
107
+ __syncthreads();
108
+
109
+ if (tid == 0) {
110
+ if (winner_idx < (int)n_columns) {
111
+ active_out[winner_idx] = 1;
112
+ work[winner_idx] = -INFINITY;
113
+ }
114
+ }
115
+ __syncthreads();
116
+ }
117
+ }
overlay/htm_rust/src/gpu/kernels/tm_activate.cu ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // TM activate kernel. See tm_predict.cu for TmConfig.
2
+
3
+ struct TmConfig {
4
+ unsigned int activation_threshold;
5
+ unsigned int learning_threshold;
6
+ unsigned int cells_per_column;
7
+ unsigned int synapses_per_segment;
8
+ unsigned int n_segments;
9
+ unsigned int n_cells;
10
+ unsigned int max_segments_per_cell;
11
+ unsigned int max_new_synapses;
12
+ int conn_thr_i16;
13
+ int perm_inc_i16;
14
+ int perm_dec_i16;
15
+ int predicted_seg_dec_i16;
16
+ int initial_perm_i16;
17
+ unsigned int iter_seed;
18
+ unsigned int n_cols;
19
+ unsigned int bits_words;
20
+ };
21
+
22
+ extern "C" __global__
23
+ void tm_activate(
24
+ const unsigned char * __restrict__ sp_active_mask,
25
+ const unsigned char * __restrict__ col_predicted,
26
+ const unsigned int * __restrict__ cell_predictive_bits,
27
+ unsigned int * __restrict__ cell_active_bits,
28
+ unsigned int * __restrict__ cell_winner_bits,
29
+ unsigned int * __restrict__ unpredicted_count,
30
+ unsigned int * __restrict__ burst_cols_flat,
31
+ unsigned int * __restrict__ burst_cols_count,
32
+ TmConfig cfg
33
+ ) {
34
+ unsigned int col = blockIdx.x * blockDim.x + threadIdx.x;
35
+ if (col >= cfg.n_cols) return;
36
+ if (sp_active_mask[col] == 0) return;
37
+
38
+ unsigned int base_cell = col * cfg.cells_per_column;
39
+
40
+ if (col_predicted[col]) {
41
+ for (unsigned int k = 0; k < cfg.cells_per_column; k++) {
42
+ unsigned int cell = base_cell + k;
43
+ unsigned int word_idx = cell >> 5;
44
+ unsigned int bit_mask = 1u << (cell & 31u);
45
+ unsigned int pred_word = cell_predictive_bits[word_idx];
46
+ if (pred_word & bit_mask) {
47
+ atomicOr(&cell_active_bits[word_idx], bit_mask);
48
+ atomicOr(&cell_winner_bits[word_idx], bit_mask);
49
+ }
50
+ }
51
+ } else {
52
+ atomicAdd(unpredicted_count, 1u);
53
+ for (unsigned int k = 0; k < cfg.cells_per_column; k++) {
54
+ unsigned int cell = base_cell + k;
55
+ unsigned int word_idx = cell >> 5;
56
+ unsigned int bit_mask = 1u << (cell & 31u);
57
+ atomicOr(&cell_active_bits[word_idx], bit_mask);
58
+ }
59
+ unsigned int winner = base_cell;
60
+ unsigned int word_idx = winner >> 5;
61
+ unsigned int bit_mask = 1u << (winner & 31u);
62
+ atomicOr(&cell_winner_bits[word_idx], bit_mask);
63
+ unsigned int slot = atomicAdd(burst_cols_count, 1u);
64
+ burst_cols_flat[slot] = col;
65
+ }
66
+ }
overlay/htm_rust/src/gpu/kernels/tm_anomaly.cu ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // TM anomaly kernel.
2
+ //
3
+ // Computes:
4
+ // n_active = sum of sp_active_mask
5
+ // anomaly = unpredicted_count / n_active (if n_active > 0)
6
+ // = 0 (else)
7
+ //
8
+ // Launch: single block, 256 threads.
9
+
10
+ extern "C" __global__
11
+ void tm_anomaly(
12
+ const unsigned char * __restrict__ sp_active_mask,
13
+ const unsigned int * __restrict__ unpredicted_count,
14
+ float * __restrict__ anomaly_out, // (1,) or (t_slot,)
15
+ unsigned int t_slot,
16
+ unsigned int n_cols
17
+ ) {
18
+ const unsigned int tid = threadIdx.x;
19
+ __shared__ unsigned int n_active_s;
20
+
21
+ if (tid == 0) n_active_s = 0u;
22
+ __syncthreads();
23
+
24
+ unsigned int local = 0u;
25
+ for (unsigned int i = tid; i < n_cols; i += blockDim.x) {
26
+ if (sp_active_mask[i]) local += 1u;
27
+ }
28
+ // Warp reduce.
29
+ for (int off = 16; off > 0; off >>= 1) {
30
+ local += __shfl_down_sync(0xffffffffu, local, off);
31
+ }
32
+ if ((tid & 31u) == 0) {
33
+ atomicAdd(&n_active_s, local);
34
+ }
35
+ __syncthreads();
36
+
37
+ if (tid == 0) {
38
+ unsigned int total = n_active_s;
39
+ unsigned int bad = unpredicted_count[0];
40
+ float anom = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
41
+ anomaly_out[t_slot] = anom;
42
+ }
43
+ }
overlay/htm_rust/src/gpu/kernels/tm_grow.cu ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // TM grow+reinforce kernel.
2
+ //
3
+ // For each bursting column:
4
+ // If col_best_match[col] is non-zero (i.e. at least one matching segment
5
+ // with num_active_potential >= learning_threshold exists on cells in this col):
6
+ // Target = that matching segment.
7
+ // Reinforce its existing synapses: +inc if presyn in prev_active, -dec otherwise.
8
+ // Grow up to (max_new - current_syn_count) additional synapses to prev_winners.
9
+ // Else:
10
+ // Allocate a fresh segment slot on winner cell (cell 0 of col).
11
+ // Grow up to max_new synapses to prev_winners (no reinforce needed — new seg).
12
+ //
13
+ // This mirrors the CPU TM burst logic.
14
+
15
+ struct TmConfig {
16
+ unsigned int activation_threshold;
17
+ unsigned int learning_threshold;
18
+ unsigned int cells_per_column;
19
+ unsigned int synapses_per_segment;
20
+ unsigned int n_segments;
21
+ unsigned int n_cells;
22
+ unsigned int max_segments_per_cell;
23
+ unsigned int max_new_synapses;
24
+ int conn_thr_i16;
25
+ int perm_inc_i16;
26
+ int perm_dec_i16;
27
+ int predicted_seg_dec_i16;
28
+ int initial_perm_i16;
29
+ unsigned int iter_seed;
30
+ unsigned int n_cols;
31
+ unsigned int bits_words;
32
+ };
33
+
34
+ extern "C" __global__
35
+ void tm_grow(
36
+ unsigned int * __restrict__ seg_cell_id,
37
+ unsigned int * __restrict__ seg_syn_count,
38
+ unsigned int * __restrict__ syn_presyn,
39
+ short * __restrict__ syn_perm,
40
+ unsigned int * __restrict__ cell_seg_count,
41
+ const unsigned int * __restrict__ burst_cols_flat,
42
+ const unsigned int * __restrict__ burst_cols_count,
43
+ const unsigned int * __restrict__ prev_winner_bits,
44
+ const unsigned int * __restrict__ prev_active_bits,
45
+ const unsigned int * __restrict__ col_best_match,
46
+ TmConfig cfg
47
+ ) {
48
+ const unsigned int b = blockIdx.x;
49
+ const unsigned int n_burst_cols = burst_cols_count[0];
50
+ if (b >= n_burst_cols) return;
51
+ const unsigned int tid = threadIdx.x;
52
+
53
+ const unsigned int col = burst_cols_flat[b];
54
+
55
+ __shared__ unsigned int shared_seg_id;
56
+ __shared__ unsigned int shared_existing_syn_count;
57
+ __shared__ unsigned int shared_grown;
58
+ __shared__ unsigned int shared_is_new;
59
+ __shared__ unsigned int shared_start_offset;
60
+
61
+ if (tid == 0) {
62
+ unsigned int match_key = col_best_match[col];
63
+ if (match_key != 0u) {
64
+ // Reuse matching segment.
65
+ unsigned int seg_id = match_key & 0x1FFFFFu;
66
+ shared_seg_id = seg_id;
67
+ shared_existing_syn_count = seg_syn_count[seg_id];
68
+ shared_is_new = 0u;
69
+ } else {
70
+ // Allocate new segment on winner cell (cell 0 of col).
71
+ unsigned int winner_cell = col * cfg.cells_per_column;
72
+ unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
73
+ if (slot >= cfg.max_segments_per_cell) {
74
+ slot = slot % cfg.max_segments_per_cell;
75
+ }
76
+ unsigned int seg_id = winner_cell * cfg.max_segments_per_cell + slot;
77
+ seg_cell_id[seg_id] = winner_cell;
78
+ seg_syn_count[seg_id] = 0;
79
+ shared_seg_id = seg_id;
80
+ shared_existing_syn_count = 0u;
81
+ shared_is_new = 1u;
82
+ }
83
+ shared_grown = 0u;
84
+ shared_start_offset = (b * 2654435761u + cfg.iter_seed) % cfg.bits_words;
85
+ }
86
+ __syncthreads();
87
+
88
+ const unsigned int seg_id = shared_seg_id;
89
+ const unsigned int seg_base = seg_id * cfg.synapses_per_segment;
90
+ const unsigned int existing_syn = shared_existing_syn_count;
91
+ const unsigned int is_new = shared_is_new;
92
+ const unsigned int start = shared_start_offset;
93
+
94
+ // PHASE 1: If reusing, reinforce existing synapses.
95
+ if (!is_new) {
96
+ for (unsigned int s = tid; s < existing_syn; s += 32u) {
97
+ unsigned int presyn = syn_presyn[seg_base + s];
98
+ unsigned int word = prev_active_bits[presyn >> 5];
99
+ unsigned int bit = (word >> (presyn & 31u)) & 1u;
100
+ int p = (int)syn_perm[seg_base + s];
101
+ if (bit) {
102
+ int np = p + cfg.perm_inc_i16;
103
+ if (np > 32767) np = 32767;
104
+ syn_perm[seg_base + s] = (short)np;
105
+ } else {
106
+ int np = p - cfg.perm_dec_i16;
107
+ if (np < 0) np = 0;
108
+ syn_perm[seg_base + s] = (short)np;
109
+ }
110
+ }
111
+ __syncthreads();
112
+ }
113
+
114
+ // PHASE 2: Grow up to `max_new_synapses` (or room) synapses to prev_winners
115
+ // that aren't already presynaptic to this segment.
116
+ const unsigned int room = (cfg.synapses_per_segment > existing_syn)
117
+ ? (cfg.synapses_per_segment - existing_syn) : 0u;
118
+ const unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;
119
+
120
+ for (unsigned int w_off = 0; w_off < cfg.bits_words; w_off += 32u) {
121
+ if (shared_grown >= max_grow) break;
122
+ unsigned int widx = (start + w_off + tid) % cfg.bits_words;
123
+ unsigned int word = prev_winner_bits[widx];
124
+ while (word != 0u) {
125
+ if (shared_grown >= max_grow) break;
126
+ unsigned int bit_pos = __ffs(word) - 1u;
127
+ word &= ~(1u << bit_pos);
128
+ unsigned int cell = widx * 32u + bit_pos;
129
+ if (cell >= cfg.n_cells) continue;
130
+
131
+ // Skip if already presynaptic (O(existing_syn) scan; usually small).
132
+ bool exists = false;
133
+ for (unsigned int s = 0; s < existing_syn; s++) {
134
+ if (syn_presyn[seg_base + s] == cell) { exists = true; break; }
135
+ }
136
+ if (exists) continue;
137
+
138
+ unsigned int slot = atomicAdd(&shared_grown, 1u);
139
+ if (slot >= max_grow) break;
140
+ unsigned int write_idx = existing_syn + slot;
141
+ if (write_idx >= cfg.synapses_per_segment) break;
142
+ syn_presyn[seg_base + write_idx] = cell;
143
+ syn_perm[seg_base + write_idx] = (short)cfg.initial_perm_i16;
144
+ }
145
+ }
146
+ __syncthreads();
147
+
148
+ if (tid == 0) {
149
+ unsigned int grown = shared_grown;
150
+ if (grown > max_grow) grown = max_grow;
151
+ unsigned int new_count = existing_syn + grown;
152
+ if (new_count > cfg.synapses_per_segment) new_count = cfg.synapses_per_segment;
153
+ seg_syn_count[seg_id] = new_count;
154
+ }
155
+ }
overlay/htm_rust/src/gpu/kernels/tm_learn.cu ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // TM learn (reinforce correctly predicted segments) — cell-grouped launch.
2
+ //
3
+ // Grid: n_cells.
4
+ // For each cell in a predicted, SP-active column: iterate its segments.
5
+ // For each segment with num_active_connected >= activation_threshold,
6
+ // reinforce its synapses against prev_active_bits.
7
+
8
+ struct TmConfig {
9
+ unsigned int activation_threshold;
10
+ unsigned int learning_threshold;
11
+ unsigned int cells_per_column;
12
+ unsigned int synapses_per_segment;
13
+ unsigned int n_segments;
14
+ unsigned int n_cells;
15
+ unsigned int max_segments_per_cell;
16
+ unsigned int max_new_synapses;
17
+ int conn_thr_i16;
18
+ int perm_inc_i16;
19
+ int perm_dec_i16;
20
+ int predicted_seg_dec_i16;
21
+ int initial_perm_i16;
22
+ unsigned int iter_seed;
23
+ unsigned int n_cols;
24
+ unsigned int bits_words;
25
+ };
26
+
27
+ extern "C" __global__
28
+ void tm_learn_reinforce(
29
+ const unsigned int * __restrict__ seg_cell_id,
30
+ const unsigned int * __restrict__ seg_syn_count,
31
+ const unsigned int * __restrict__ syn_presyn,
32
+ short * __restrict__ syn_perm,
33
+ const unsigned int * __restrict__ seg_num_active_connected,
34
+ const unsigned int * __restrict__ prev_active_bits,
35
+ const unsigned char * __restrict__ sp_active_mask,
36
+ const unsigned char * __restrict__ col_predicted,
37
+ const unsigned int * __restrict__ cell_seg_count,
38
+ TmConfig cfg
39
+ ) {
40
+ const unsigned int cell = blockIdx.x;
41
+ if (cell >= cfg.n_cells) return;
42
+ const unsigned int col = cell / cfg.cells_per_column;
43
+ if (sp_active_mask[col] == 0) return;
44
+ if (col_predicted[col] == 0) return;
45
+
46
+ const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
47
+ if (n_segs_here == 0) return;
48
+
49
+ const unsigned int tid = threadIdx.x;
50
+ const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;
51
+
52
+ for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
53
+ const unsigned int seg = seg_base_id + local_seg;
54
+ if (seg_num_active_connected[seg] < cfg.activation_threshold) continue;
55
+ const unsigned int n_syn = seg_syn_count[seg];
56
+ if (n_syn == 0) continue;
57
+ const unsigned int syn_base = seg * cfg.synapses_per_segment;
58
+
59
+ for (unsigned int s = tid; s < n_syn; s += 32u) {
60
+ unsigned int presyn = syn_presyn[syn_base + s];
61
+ unsigned int word = prev_active_bits[presyn >> 5];
62
+ unsigned int bit = (word >> (presyn & 31u)) & 1u;
63
+ int p = (int)syn_perm[syn_base + s];
64
+ if (bit) {
65
+ int np = p + cfg.perm_inc_i16;
66
+ if (np > 32767) np = 32767;
67
+ syn_perm[syn_base + s] = (short)np;
68
+ } else {
69
+ int np = p - cfg.perm_dec_i16;
70
+ if (np < 0) np = 0;
71
+ syn_perm[syn_base + s] = (short)np;
72
+ }
73
+ }
74
+ }
75
+ }
overlay/htm_rust/src/gpu/kernels/tm_predict.cu ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // TM predict kernel — cell-grouped launch.
2
+ //
3
+ // Grid: n_cells blocks (one per cell).
4
+ // Block: 32 threads (one warp).
5
+ //
6
+ // Each block iterates the segments owned by its cell (count in cell_seg_count[cell]).
7
+ // For each live segment, counts active connected/potential synapses against
8
+ // prev_active_bits. Updates per-segment counters, cell_predictive bit, and
9
+ // col_predicted flag.
10
+
11
+ struct TmConfig {
12
+ unsigned int activation_threshold;
13
+ unsigned int learning_threshold;
14
+ unsigned int cells_per_column;
15
+ unsigned int synapses_per_segment;
16
+ unsigned int n_segments;
17
+ unsigned int n_cells;
18
+ unsigned int max_segments_per_cell;
19
+ unsigned int max_new_synapses;
20
+ int conn_thr_i16;
21
+ int perm_inc_i16;
22
+ int perm_dec_i16;
23
+ int predicted_seg_dec_i16;
24
+ int initial_perm_i16;
25
+ unsigned int iter_seed;
26
+ unsigned int n_cols;
27
+ unsigned int bits_words;
28
+ };
29
+
30
+ extern "C" __global__
31
+ void tm_predict(
32
+ const unsigned int * __restrict__ seg_cell_id,
33
+ const unsigned int * __restrict__ seg_syn_count,
34
+ const unsigned int * __restrict__ syn_presyn,
35
+ const short * __restrict__ syn_perm,
36
+ const unsigned int * __restrict__ cell_active_bits,
37
+ unsigned int * __restrict__ cell_predictive_bits,
38
+ unsigned char * __restrict__ col_predicted,
39
+ unsigned int * __restrict__ seg_num_active_connected,
40
+ unsigned int * __restrict__ seg_num_active_potential,
41
+ unsigned int * __restrict__ col_best_match,
42
+ const unsigned int * __restrict__ cell_seg_count,
43
+ TmConfig cfg
44
+ ) {
45
+ const unsigned int cell = blockIdx.x;
46
+ if (cell >= cfg.n_cells) return;
47
+
48
+ const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
49
+ if (n_segs_here == 0) return;
50
+
51
+ const unsigned int tid = threadIdx.x;
52
+ const unsigned int col = cell / cfg.cells_per_column;
53
+ const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;
54
+
55
+ for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
56
+ const unsigned int seg = seg_base_id + local_seg;
57
+ const unsigned int n_syn = seg_syn_count[seg];
58
+ if (n_syn == 0) {
59
+ if (tid == 0) {
60
+ seg_num_active_connected[seg] = 0;
61
+ seg_num_active_potential[seg] = 0;
62
+ }
63
+ continue;
64
+ }
65
+ const unsigned int syn_base = seg * cfg.synapses_per_segment;
66
+
67
+ unsigned int local_conn = 0;
68
+ unsigned int local_pot = 0;
69
+ for (unsigned int s = tid; s < n_syn; s += 32u) {
70
+ unsigned int presyn = syn_presyn[syn_base + s];
71
+ unsigned int word = cell_active_bits[presyn >> 5];
72
+ unsigned int bit = (word >> (presyn & 31u)) & 1u;
73
+ if (bit) {
74
+ local_pot += 1u;
75
+ int p = (int)syn_perm[syn_base + s];
76
+ if (p >= cfg.conn_thr_i16) {
77
+ local_conn += 1u;
78
+ }
79
+ }
80
+ }
81
+ for (int off = 16; off > 0; off >>= 1) {
82
+ local_conn += __shfl_down_sync(0xffffffffu, local_conn, off);
83
+ local_pot += __shfl_down_sync(0xffffffffu, local_pot, off);
84
+ }
85
+
86
+ if (tid == 0) {
87
+ seg_num_active_connected[seg] = local_conn;
88
+ seg_num_active_potential[seg] = local_pot;
89
+ if (local_conn >= cfg.activation_threshold) {
90
+ unsigned int word_idx = cell >> 5;
91
+ unsigned int bit_mask = 1u << (cell & 31u);
92
+ atomicOr(&cell_predictive_bits[word_idx], bit_mask);
93
+ col_predicted[col] = 1;
94
+ }
95
+ if (local_pot >= cfg.learning_threshold) {
96
+ unsigned int pot_c = local_pot > 2047u ? 2047u : local_pot;
97
+ unsigned int key = (pot_c << 21) | (seg & 0x1FFFFFu);
98
+ atomicMax(&col_best_match[col], key);
99
+ }
100
+ }
101
+ }
102
+ }
overlay/htm_rust/src/gpu/kernels/tm_punish.cu ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // TM punish — cell-grouped launch.
2
+
3
+ struct TmConfig {
4
+ unsigned int activation_threshold;
5
+ unsigned int learning_threshold;
6
+ unsigned int cells_per_column;
7
+ unsigned int synapses_per_segment;
8
+ unsigned int n_segments;
9
+ unsigned int n_cells;
10
+ unsigned int max_segments_per_cell;
11
+ unsigned int max_new_synapses;
12
+ int conn_thr_i16;
13
+ int perm_inc_i16;
14
+ int perm_dec_i16;
15
+ int predicted_seg_dec_i16;
16
+ int initial_perm_i16;
17
+ unsigned int iter_seed;
18
+ unsigned int n_cols;
19
+ unsigned int bits_words;
20
+ };
21
+
22
+ extern "C" __global__
23
+ void tm_punish(
24
+ const unsigned int * __restrict__ seg_cell_id,
25
+ const unsigned int * __restrict__ seg_syn_count,
26
+ const unsigned int * __restrict__ syn_presyn,
27
+ short * __restrict__ syn_perm,
28
+ const unsigned int * __restrict__ seg_num_active_potential,
29
+ const unsigned int * __restrict__ prev_active_bits,
30
+ const unsigned char * __restrict__ sp_active_mask,
31
+ const unsigned int * __restrict__ cell_seg_count,
32
+ TmConfig cfg
33
+ ) {
34
+ const unsigned int cell = blockIdx.x;
35
+ if (cell >= cfg.n_cells) return;
36
+ const unsigned int col = cell / cfg.cells_per_column;
37
+ if (sp_active_mask[col] != 0) return; // skip: col became active
38
+
39
+ const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
40
+ if (n_segs_here == 0) return;
41
+
42
+ const unsigned int tid = threadIdx.x;
43
+ const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;
44
+
45
+ for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
46
+ const unsigned int seg = seg_base_id + local_seg;
47
+ if (seg_num_active_potential[seg] < cfg.learning_threshold) continue;
48
+ const unsigned int n_syn = seg_syn_count[seg];
49
+ if (n_syn == 0) continue;
50
+ const unsigned int syn_base = seg * cfg.synapses_per_segment;
51
+
52
+ for (unsigned int s = tid; s < n_syn; s += 32u) {
53
+ unsigned int presyn = syn_presyn[syn_base + s];
54
+ unsigned int word = prev_active_bits[presyn >> 5];
55
+ unsigned int bit = (word >> (presyn & 31u)) & 1u;
56
+ if (bit) {
57
+ int p = (int)syn_perm[syn_base + s];
58
+ int np = p - cfg.predicted_seg_dec_i16;
59
+ if (np < 0) np = 0;
60
+ syn_perm[syn_base + s] = (short)np;
61
+ }
62
+ }
63
+ }
64
+ }
overlay/htm_rust/src/gpu/kernels/tm_reset.cu ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // TM reset-per-step kernel.
2
+
3
+ extern "C" __global__
4
+ void tm_reset_step(
5
+ unsigned int * __restrict__ cell_active_bits,
6
+ unsigned int * __restrict__ cell_winner_bits,
7
+ unsigned int * __restrict__ cell_predictive_bits,
8
+ unsigned int * __restrict__ prev_active_bits,
9
+ unsigned int * __restrict__ prev_winner_bits,
10
+ unsigned char * __restrict__ col_predicted,
11
+ unsigned int * __restrict__ unpredicted_count,
12
+ unsigned int * __restrict__ burst_cols_count,
13
+ unsigned int * __restrict__ col_best_match,
14
+ unsigned int bits_words,
15
+ unsigned int n_cols
16
+ ) {
17
+ unsigned int tid_global = blockIdx.x * blockDim.x + threadIdx.x;
18
+
19
+ if (tid_global < bits_words) {
20
+ prev_active_bits[tid_global] = cell_active_bits[tid_global];
21
+ prev_winner_bits[tid_global] = cell_winner_bits[tid_global];
22
+ cell_active_bits[tid_global] = 0u;
23
+ cell_winner_bits[tid_global] = 0u;
24
+ cell_predictive_bits[tid_global] = 0u;
25
+ }
26
+
27
+ if (tid_global < n_cols) {
28
+ col_predicted[tid_global] = 0;
29
+ col_best_match[tid_global] = 0u;
30
+ }
31
+
32
+ if (tid_global == 0) {
33
+ unpredicted_count[0] = 0u;
34
+ burst_cols_count[0] = 0u;
35
+ }
36
+ }
overlay/htm_rust/src/gpu/mod.rs ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //! GPU backend for HTM.
2
+ //!
3
+ //! Full-GPU pipeline (SP + TM). Per-step state lives entirely on device; the
4
+ //! batch API (`step_many_gpu`) uploads T steps of input once, runs T iterations
5
+ //! of the full HTM pipeline on GPU, and copies (T, n_cols) u8 + (T,) f32 back
6
+ //! to the host in one shot.
7
+ //!
8
+ //! TM parity with the CPU reference is approximate:
9
+ //! - Segment growth: winner = cell 0 of bursting column (CPU picks
10
+ //! least-used-cell with RNG tiebreak). This is a pragmatic simplification
11
+ //! for GPU atomicity; learning dynamics are preserved.
12
+ //! - Permanences stored as i16 (scaled 0..32767). Rounding differs from
13
+ //! f32 by <= 1 ULP of the scale factor (≈ 3e-5) — inside any meaningful
14
+ //! HTM learning quantum.
15
+
16
+ #![cfg(feature = "gpu")]
17
+
18
+ pub mod sp_gpu;
19
+ pub mod tm_gpu;
20
+ pub mod fused;
21
+
22
+ #[cfg(test)]
23
+ mod tests;
24
+
25
+ use std::mem::ManuallyDrop;
26
+
27
+ use pyo3::prelude::*;
28
+ use pyo3::types::{PyDict, PyTuple};
29
+ use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};
30
+
31
+ use crate::region::HTMRegionCore;
32
+ use crate::sp::SpatialPoolerConfig;
33
+ use sp_gpu::SpatialPoolerGpu;
34
+ use tm_gpu::TemporalMemoryGpu;
35
+ use fused::FusedState;
36
+
37
+ /// Extract (device_ptr, shape, typestr) from a `__cuda_array_interface__` dict.
38
+ /// Returns Err if the dict is malformed. Used by `step_many_cuda` to wrap
39
+ /// torch-owned CUDA allocations zero-copy.
40
+ fn cai_parse(cai: &Bound<'_, PyDict>) -> PyResult<(u64, Vec<usize>, String)> {
41
+ // `data` is a (ptr: int, readonly: bool) tuple.
42
+ let data_obj = cai.get_item("data")?
43
+ .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'data'"))?;
44
+ let data_tup: Bound<'_, PyTuple> = data_obj.downcast_into()
45
+ .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'data' must be a tuple"))?;
46
+ let ptr: u64 = data_tup.get_item(0)?.extract()?;
47
+
48
+ // `shape` is a tuple of ints.
49
+ let shape_obj = cai.get_item("shape")?
50
+ .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'shape'"))?;
51
+ let shape_tup: Bound<'_, PyTuple> = shape_obj.downcast_into()
52
+ .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'shape' must be a tuple"))?;
53
+ let shape: Vec<usize> = (0..shape_tup.len())
54
+ .map(|i| shape_tup.get_item(i).and_then(|v| v.extract::<usize>()))
55
+ .collect::<PyResult<Vec<_>>>()?;
56
+
57
+ // `typestr` (e.g. "|u1", "<f4").
58
+ let typestr_obj = cai.get_item("typestr")?
59
+ .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'typestr'"))?;
60
+ let typestr: String = typestr_obj.extract()?;
61
+
62
+ // Reject non-contiguous tensors — we don't handle strides.
63
+ if let Some(strides) = cai.get_item("strides")? {
64
+ if !strides.is_none() {
65
+ return Err(pyo3::exceptions::PyValueError::new_err(
66
+ "CAI 'strides' must be None (tensor must be contiguous)",
67
+ ));
68
+ }
69
+ }
70
+
71
+ Ok((ptr, shape, typestr))
72
+ }
73
+
74
+ /// Python-exposed GPU HTM region. Drop-in replacement for `HTMRegion`.
75
+ #[pyclass(module = "htm_rust")]
76
+ pub struct HTMRegionGpu {
77
+ pub(super) sp_gpu: SpatialPoolerGpu,
78
+ pub(super) tm_gpu: TemporalMemoryGpu,
79
+ pub(super) fused_state: FusedState,
80
+ pub(super) n_columns: usize,
81
+ pub(super) input_bits: usize,
82
+ pub(super) cells_per_column: usize,
83
+ }
84
+
85
+ #[pymethods]
86
+ impl HTMRegionGpu {
87
+ #[new]
88
+ #[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
89
+ fn new(
90
+ input_bits: usize,
91
+ n_columns: usize,
92
+ cells_per_column: usize,
93
+ seed: u64,
94
+ ) -> PyResult<Self> {
95
+ if input_bits == 0 || n_columns == 0 || cells_per_column == 0 {
96
+ return Err(pyo3::exceptions::PyValueError::new_err(
97
+ "input_bits, n_columns, cells_per_column must all be > 0",
98
+ ));
99
+ }
100
+ // CPU reference for deterministic SP init.
101
+ let cpu_ref = HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed);
102
+ let sp_cfg: &SpatialPoolerConfig = &cpu_ref.sp.cfg;
103
+ let sp_gpu = SpatialPoolerGpu::from_cpu(&cpu_ref.sp).map_err(|e| {
104
+ pyo3::exceptions::PyRuntimeError::new_err(format!(
105
+ "GPU SP init failed: {e:?}. Config: input_bits={}, n_columns={}",
106
+ sp_cfg.input_bits, sp_cfg.n_columns,
107
+ ))
108
+ })?;
109
+ let dev = sp_gpu.dev_ref().clone();
110
+ let tm_gpu = TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_column).map_err(|e| {
111
+ pyo3::exceptions::PyRuntimeError::new_err(format!(
112
+ "GPU TM init failed: {e:?}",
113
+ ))
114
+ })?;
115
+ let initial_threshold = sp_gpu.initial_threshold_estimate();
116
+ let fused_state = FusedState::new(dev, n_columns, cells_per_column, initial_threshold)
117
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!(
118
+ "GPU fused state init failed: {e:?}",
119
+ )))?;
120
+ Ok(Self {
121
+ sp_gpu,
122
+ tm_gpu,
123
+ fused_state,
124
+ n_columns,
125
+ input_bits,
126
+ cells_per_column,
127
+ })
128
+ }
129
+
130
+ #[getter] fn input_bits(&self) -> usize { self.input_bits }
131
+ #[getter] fn n_columns(&self) -> usize { self.n_columns }
132
+ #[getter] fn cells_per_column(&self) -> usize { self.cells_per_column }
133
+
134
+ /// Process T timesteps in one call on GPU. Per-step state (SP + TM) stays
135
+ /// on device; only the final (T, n_cols) mask and (T,) anomaly are copied
136
+ /// to the host at the end.
137
+ #[pyo3(signature = (inputs, learn=true))]
138
+ fn step_many_gpu<'py>(
139
+ &mut self,
140
+ py: Python<'py>,
141
+ inputs: PyReadonlyArray2<'py, bool>,
142
+ learn: bool,
143
+ ) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
144
+ let shape = inputs.shape();
145
+ if shape.len() != 2 {
146
+ return Err(pyo3::exceptions::PyValueError::new_err(
147
+ "inputs must be 2-D (T, input_bits)",
148
+ ));
149
+ }
150
+ let t = shape[0];
151
+ let bits = shape[1];
152
+ if bits != self.input_bits {
153
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
154
+ "inputs last dim {bits} != expected input_bits {}",
155
+ self.input_bits,
156
+ )));
157
+ }
158
+ let slice = inputs.as_slice()?;
159
+ let n_cols = self.n_columns;
160
+ let input_vec: Vec<bool> = slice.to_vec();
161
+
162
+ let result = py.allow_threads(|| -> Result<(Vec<u8>, Vec<f32>), String> {
163
+ // 1. Upload T*input_bits bytes (32 MB at T=2048, bits=16384).
164
+ let sdr_u8_all: Vec<u8> = input_vec.iter().map(|&b| b as u8).collect();
165
+ let inputs_dev = self
166
+ .sp_gpu
167
+ .dev_ref()
168
+ .htod_sync_copy(&sdr_u8_all)
169
+ .map_err(|e| format!("H2D inputs: {e:?}"))?;
170
+
171
+ // 2. Allocate output buffers on device.
172
+ let mut cols_dev = self.sp_gpu.dev_ref()
173
+ .alloc_zeros::<u8>(t * n_cols)
174
+ .map_err(|e| format!("alloc cols: {e:?}"))?;
175
+ let mut anom_dev = self.sp_gpu.dev_ref()
176
+ .alloc_zeros::<f32>(t)
177
+ .map_err(|e| format!("alloc anom: {e:?}"))?;
178
+
179
+ // 3. Run T steps of SP + TM on GPU with NO per-step host sync.
180
+ self.sp_gpu.step_batch_with_tm(
181
+ &inputs_dev,
182
+ t,
183
+ self.input_bits,
184
+ learn,
185
+ &mut cols_dev,
186
+ &mut anom_dev,
187
+ &mut self.tm_gpu,
188
+ ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
189
+
190
+ // 4. ONE D2H for the whole run (T * n_cols bytes + T floats).
191
+ let cols_host: Vec<u8> = self.sp_gpu.dev_ref()
192
+ .dtoh_sync_copy(&cols_dev)
193
+ .map_err(|e| format!("D2H cols: {e:?}"))?;
194
+ let anom_host: Vec<f32> = self.sp_gpu.dev_ref()
195
+ .dtoh_sync_copy(&anom_dev)
196
+ .map_err(|e| format!("D2H anom: {e:?}"))?;
197
+
198
+ Ok((cols_host, anom_host))
199
+ });
200
+
201
+ let (cols_u8, anom) = result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
202
+
203
+ let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
204
+ let cols_arr = numpy::PyArray1::from_vec_bound(py, cols_f32)
205
+ .reshape([t, n_cols])
206
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
207
+ let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
208
+ Ok((cols_arr, anom_arr))
209
+ }
210
+
211
+ /// Zero-copy CUDA path: accept torch tensors via __cuda_array_interface__,
212
+ /// write outputs directly into caller-allocated torch tensors. Skips the
213
+ /// host round-trip that `step_many_gpu` pays on every call (sdr.cpu() +
214
+ /// two D2H copies at the end). This is the hot path for `train.py`.
215
+ ///
216
+ /// Contract:
217
+ /// sdr_cai.shape == (T, input_bits), dtype u8 (0/1 mask)
218
+ /// cols_cai.shape == (T, n_columns), dtype u8 (written)
219
+ /// anom_cai.shape == (T,), dtype f32 (written)
220
+ /// All three tensors must live on the SAME CUDA device as this region.
221
+ ///
222
+ /// The torch tensors still own their memory — this method only wraps
223
+ /// them as borrowed CudaSlice views (via ManuallyDrop) so cudarc's Drop
224
+ /// impl can't free pytorch's allocator.
225
+ #[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
226
+ fn step_many_cuda(
227
+ &mut self,
228
+ py: Python<'_>,
229
+ sdr_cai: &Bound<'_, PyDict>,
230
+ cols_cai: &Bound<'_, PyDict>,
231
+ anom_cai: &Bound<'_, PyDict>,
232
+ learn: bool,
233
+ ) -> PyResult<()> {
234
+ let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
235
+ let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
236
+ let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
237
+
238
+ // typestr sanity. numpy u1 is what torch.uint8 exports.
239
+ if sdr_type != "|u1" {
240
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
241
+ "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
242
+ )));
243
+ }
244
+ if cols_type != "|u1" {
245
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
246
+ "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
247
+ )));
248
+ }
249
+ if anom_type != "<f4" && anom_type != "=f4" {
250
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
251
+ "anom_cai typestr must be '<f4' (float32), got {anom_type}",
252
+ )));
253
+ }
254
+
255
+ // Shape validation.
256
+ if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
257
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
258
+ "sdr_cai shape {sdr_shape:?} != (T, {})",
259
+ self.input_bits,
260
+ )));
261
+ }
262
+ let t = sdr_shape[0];
263
+ if cols_shape != [t, self.n_columns] {
264
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
265
+ "cols_cai shape {cols_shape:?} != ({t}, {})",
266
+ self.n_columns,
267
+ )));
268
+ }
269
+ if anom_shape != [t] {
270
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
271
+ "anom_cai shape {anom_shape:?} != ({t},)",
272
+ )));
273
+ }
274
+
275
+ let dev = self.sp_gpu.dev_ref().clone();
276
+ let n_cols = self.n_columns;
277
+ let input_bits = self.input_bits;
278
+
279
+ let result = py.allow_threads(|| -> Result<(), String> {
280
+ // SAFETY:
281
+ // - ptrs came from torch CUDA tensors validated non-null by the
282
+ // __cuda_array_interface__ contract.
283
+ // - lens computed from validated shapes.
284
+ // - We wrap the returned CudaSlice in ManuallyDrop so cudarc's
285
+ // Drop (which calls cuMemFree) never runs against torch memory.
286
+ // The underlying allocation is owned+freed by torch.
287
+ // - The slices are used only for the duration of this call;
288
+ // torch guarantees the backing tensors are live across it
289
+ // (Python holds refs on the wrapping tensors).
290
+ let inputs_dev = ManuallyDrop::new(unsafe {
291
+ dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
292
+ });
293
+ let mut cols_dev = ManuallyDrop::new(unsafe {
294
+ dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
295
+ });
296
+ let mut anom_dev = ManuallyDrop::new(unsafe {
297
+ dev.upgrade_device_ptr::<f32>(anom_ptr, t)
298
+ });
299
+
300
+ self.sp_gpu.step_batch_with_tm(
301
+ &inputs_dev,
302
+ t,
303
+ input_bits,
304
+ learn,
305
+ &mut cols_dev,
306
+ &mut anom_dev,
307
+ &mut self.tm_gpu,
308
+ ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
309
+
310
+ // Synchronize: kernel writes must be visible to the next torch
311
+ // op that reads cols/anom. Pytorch's default stream is stream 0,
312
+ // and cudarc launches on its own stream — a full device sync
313
+ // is the simplest correct barrier. (Could narrow to a stream
314
+ // wait event in PR 2.)
315
+ // No dev.synchronize() here: caller must explicitly sync via the
316
+ // `device_sync()` method (or PyTorch auto-syncs when the output
317
+ // tensor is next consumed). Removing the per-launch barrier lets
318
+ // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
319
+ Ok(())
320
+ });
321
+
322
+ result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
323
+ Ok(())
324
+ }
325
+
326
+ /// Clear TM state on the GPU.
327
+ fn reset(&mut self) -> PyResult<()> {
328
+ self.tm_gpu.reset().map_err(|e| {
329
+ pyo3::exceptions::PyRuntimeError::new_err(format!("GPU TM reset: {e:?}"))
330
+ })?;
331
+ self.fused_state.reset().map_err(|e| {
332
+ pyo3::exceptions::PyRuntimeError::new_err(format!("GPU fused reset: {e:?}"))
333
+ })
334
+ }
335
+
336
+ /// FUSED MEGAKERNEL PATH: single CUDA launch for the entire T-step
337
+ /// forward (SP + TM all in one). Accepts torch CUDA tensors via
338
+ /// `__cuda_array_interface__` (zero-copy). Writes active-column mask +
339
+ /// anomaly directly into caller-allocated torch tensors.
340
+ ///
341
+ /// Semantics diverge from `step_many_cuda` in one important way: column
342
+ /// activation uses per-column threshold inhibition instead of global
343
+ /// top-K. The threshold is EMA-adapted per column toward the sparsity
344
+ /// target. See `docs/GPU_HTM.md` §Fused Kernel.
345
+ #[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
346
+ fn step_many_fused_cuda(
347
+ &mut self,
348
+ py: Python<'_>,
349
+ sdr_cai: &Bound<'_, PyDict>,
350
+ cols_cai: &Bound<'_, PyDict>,
351
+ anom_cai: &Bound<'_, PyDict>,
352
+ learn: bool,
353
+ ) -> PyResult<()> {
354
+ let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
355
+ let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
356
+ let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
357
+
358
+ if sdr_type != "|u1" {
359
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
360
+ "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
361
+ )));
362
+ }
363
+ if cols_type != "|u1" {
364
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
365
+ "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
366
+ )));
367
+ }
368
+ if anom_type != "<f4" && anom_type != "=f4" {
369
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
370
+ "anom_cai typestr must be '<f4' (float32), got {anom_type}",
371
+ )));
372
+ }
373
+
374
+ if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
375
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
376
+ "sdr_cai shape {sdr_shape:?} != (T, {})",
377
+ self.input_bits,
378
+ )));
379
+ }
380
+ let t = sdr_shape[0];
381
+ if cols_shape != [t, self.n_columns] {
382
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
383
+ "cols_cai shape {cols_shape:?} != ({t}, {})",
384
+ self.n_columns,
385
+ )));
386
+ }
387
+ if anom_shape != [t] {
388
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
389
+ "anom_cai shape {anom_shape:?} != ({t},)",
390
+ )));
391
+ }
392
+
393
+ let dev = self.sp_gpu.dev_ref().clone();
394
+ let n_cols = self.n_columns;
395
+ let input_bits = self.input_bits;
396
+
397
+ let result = py.allow_threads(|| -> Result<(), String> {
398
+ let inputs_dev = ManuallyDrop::new(unsafe {
399
+ dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
400
+ });
401
+ let mut cols_dev = ManuallyDrop::new(unsafe {
402
+ dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
403
+ });
404
+ let mut anom_dev = ManuallyDrop::new(unsafe {
405
+ dev.upgrade_device_ptr::<f32>(anom_ptr, t)
406
+ });
407
+
408
+ fused::launch_fused(
409
+ &mut self.sp_gpu,
410
+ &mut self.tm_gpu,
411
+ &mut self.fused_state,
412
+ &inputs_dev,
413
+ &mut cols_dev,
414
+ &mut anom_dev,
415
+ t,
416
+ input_bits,
417
+ learn,
418
+ ).map_err(|e| format!("launch_fused: {e:?}"))?;
419
+
420
+ // No dev.synchronize() here: caller must explicitly sync via the
421
+ // `device_sync()` method (or PyTorch auto-syncs when the output
422
+ // tensor is next consumed). Removing the per-launch barrier lets
423
+ // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
424
+ Ok(())
425
+ });
426
+
427
+ result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
428
+ Ok(())
429
+ }
430
+
431
+ /// Explicit device synchronization — the caller must invoke this after
432
+ /// all batched `step_many_*_cuda` calls complete, before reading the
433
+ /// output tensors from a different CUDA stream. Equivalent to the old
434
+ /// per-call `dev.synchronize()` that was removed for overlap.
435
+ fn device_sync(&self) -> PyResult<()> {
436
+ let dev = self.sp_gpu.dev_ref();
437
+ dev.synchronize()
438
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("sync: {e:?}")))?;
439
+ Ok(())
440
+ }
441
+ }
442
+
443
+ /// Batch B regions into ONE cooperative kernel launch. Breaks through the
444
+ /// CUDA cooperative-kernel device-level serialization: a single cooperative
445
+ /// launch with grid.y=B processes all regions concurrently — ~B× speedup
446
+ /// over B sequential launches.
447
+ ///
448
+ /// All regions must have the same config (input_bits, n_columns,
449
+ /// cells_per_column). Each region keeps its independent GPU state.
450
+ /// Does NOT sync; caller must invoke `device_sync()` on any region
451
+ /// afterwards (or rely on a downstream torch op to auto-sync).
452
+ #[pyfunction]
453
+ #[pyo3(signature = (regions, sdr_cais, cols_cais, anom_cais, learn=true))]
454
+ fn step_batch_fused_cuda(
455
+ py: Python<'_>,
456
+ regions: Vec<Py<HTMRegionGpu>>,
457
+ sdr_cais: Vec<Bound<'_, PyDict>>,
458
+ cols_cais: Vec<Bound<'_, PyDict>>,
459
+ anom_cais: Vec<Bound<'_, PyDict>>,
460
+ learn: bool,
461
+ ) -> PyResult<()> {
462
+ let b = regions.len();
463
+ if b == 0 {
464
+ return Err(pyo3::exceptions::PyValueError::new_err("regions is empty"));
465
+ }
466
+ if sdr_cais.len() != b || cols_cais.len() != b || anom_cais.len() != b {
467
+ return Err(pyo3::exceptions::PyValueError::new_err(
468
+ "sdr_cais / cols_cais / anom_cais length must match regions",
469
+ ));
470
+ }
471
+
472
+ // Parse all CAI dicts; collect device pointers. Validate shapes/dtypes.
473
+ let mut sdr_ptrs = Vec::with_capacity(b);
474
+ let mut cols_ptrs = Vec::with_capacity(b);
475
+ let mut anom_ptrs = Vec::with_capacity(b);
476
+ let (input_bits, n_columns, t) = {
477
+ let r0 = regions[0].bind(py).borrow();
478
+ (r0.input_bits, r0.n_columns, {
479
+ let (_p, sh, _ty) = cai_parse(&sdr_cais[0])?;
480
+ if sh.len() != 2 {
481
+ return Err(pyo3::exceptions::PyValueError::new_err(
482
+ format!("sdr_cai must be 2-D (T, input_bits), got {sh:?}"),
483
+ ));
484
+ }
485
+ sh[0]
486
+ })
487
+ };
488
+
489
+ for i in 0..b {
490
+ let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_cais[i])?;
491
+ let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_cais[i])?;
492
+ let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_cais[i])?;
493
+ if sdr_type != "|u1" || cols_type != "|u1" {
494
+ return Err(pyo3::exceptions::PyValueError::new_err(
495
+ "sdr/cols typestr must be '|u1' (uint8)",
496
+ ));
497
+ }
498
+ if anom_type != "<f4" && anom_type != "=f4" {
499
+ return Err(pyo3::exceptions::PyValueError::new_err(
500
+ "anom typestr must be '<f4' (float32)",
501
+ ));
502
+ }
503
+ if sdr_shape != [t, input_bits] {
504
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
505
+ "sdr[{i}] shape {sdr_shape:?} != ({t}, {input_bits})"
506
+ )));
507
+ }
508
+ if cols_shape != [t, n_columns] {
509
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
510
+ "cols[{i}] shape {cols_shape:?} != ({t}, {n_columns})"
511
+ )));
512
+ }
513
+ if anom_shape != [t] {
514
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
515
+ "anom[{i}] shape {anom_shape:?} != ({t},)"
516
+ )));
517
+ }
518
+ sdr_ptrs.push(sdr_ptr);
519
+ cols_ptrs.push(cols_ptr);
520
+ anom_ptrs.push(anom_ptr);
521
+ }
522
+
523
+ // Exclusively borrow each region. PyRefMut guarantees uniqueness.
524
+ let mut region_refs: Vec<pyo3::PyRefMut<HTMRegionGpu>> =
525
+ regions.iter().map(|p| p.bind(py).borrow_mut()).collect();
526
+ // Collect raw mutable pointers — each PyRefMut exclusively borrows its
527
+ // region for the lifetime of this call, so pointers stay valid and
528
+ // unique. launch_fused_batched_raw only dereferences one region at a
529
+ // time, not constructing an aliased slice.
530
+ let raw_ptrs: Vec<*mut HTMRegionGpu> = region_refs
531
+ .iter_mut()
532
+ .map(|r| &mut **r as *mut HTMRegionGpu)
533
+ .collect();
534
+
535
+ // No allow_threads: raw pointers aren't Send. The launch is GPU-queued
536
+ // and sync'd downstream; holding the GIL for the duration is cheap.
537
+ fused::launch_fused_batched_raw(
538
+ &raw_ptrs, &sdr_ptrs, &cols_ptrs, &anom_ptrs,
539
+ t, input_bits, learn,
540
+ )
541
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("launch_fused_batched: {e:?}")))?;
542
+ Ok(())
543
+ }
544
+
545
+ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
546
+ m.add_class::<HTMRegionGpu>()?;
547
+ m.add_function(pyo3::wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
548
+ Ok(())
549
+ }
overlay/htm_rust/src/gpu/sp_gpu.rs ADDED
@@ -0,0 +1,796 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //! GPU implementation of the Spatial Pooler.
2
+ //!
3
+ //! One `SpatialPoolerGpu` owns a set of persistent device buffers + 4 PTX
4
+ //! kernels. `compute(input, learn)` performs one SP step and returns the
5
+ //! sorted active-column indices (host `Vec<u32>`) — this is what the CPU
6
+ //! TemporalMemory consumes.
7
+ //!
8
+ //! Persistent state on device (per region):
9
+ //! syn_bit : u32 [n_columns × S] (constant after init)
10
+ //! syn_perm : f32 [n_columns × S] (updated by sp_learn)
11
+ //! boost : f32 [n_columns]
12
+ //! active_duty : f32 [n_columns]
13
+ //! overlap_duty: f32 [n_columns]
14
+ //!
15
+ //! Per-step transient state:
16
+ //! inp_dev : u8 [input_bits] (H2D copy each step)
17
+ //! raw : u32 [n_columns]
18
+ //! boosted : f32 [n_columns]
19
+ //! active_mask : u8 [n_columns] (topk output, D2H at the end)
20
+
21
+ use std::sync::Arc;
22
+
23
+ use cudarc::driver::{CudaDevice, CudaSlice, DeviceSlice, DriverError, LaunchAsync, LaunchConfig};
24
+ use cudarc::nvrtc::Ptx;
25
+
26
+ use crate::sp::SpatialPooler;
27
+
28
+ // Embed PTX at compile time. OUT_DIR is set by build.rs.
29
+ const PTX_SP_OVERLAP: &str =
30
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_overlap.ptx"));
31
+ const PTX_SP_TOPK: &str =
32
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_topk.ptx"));
33
+ const PTX_SP_LEARN: &str =
34
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_learn.ptx"));
35
+ const PTX_SP_DUTY: &str =
36
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_duty.ptx"));
37
+ const PTX_SP_BOOST_FUSED: &str =
38
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_boost_fused.ptx"));
39
+
40
+ pub struct SpatialPoolerGpu {
41
+ dev: Arc<CudaDevice>,
42
+
43
+ // Config mirror (we don't touch CPU SpatialPooler after init).
44
+ input_bits: usize,
45
+ n_columns: usize,
46
+ synapses_per_col: usize,
47
+ conn_thr: f32,
48
+ inc: f32,
49
+ dec: f32,
50
+ sparsity: f32,
51
+ duty_period: f32,
52
+ boost_strength: f32,
53
+
54
+ // Persistent device state.
55
+ syn_bit: CudaSlice<u32>,
56
+ syn_perm: CudaSlice<f32>,
57
+ boost: CudaSlice<f32>,
58
+ active_duty: CudaSlice<f32>,
59
+ overlap_duty: CudaSlice<f32>,
60
+
61
+ // Transient scratch (reused each step).
62
+ inp_dev: CudaSlice<u8>,
63
+ raw: CudaSlice<u32>,
64
+ boosted: CudaSlice<f32>,
65
+ active_mask: CudaSlice<u8>,
66
+
67
+ // Reusable host buffer for D2H of active_mask.
68
+ host_mask: Vec<u8>,
69
+
70
+ /// Strict bit-parity with CPU reference. Enabled for tests.
71
+ /// Forces host-side boost/exp computation and the overlap-duty bump check
72
+ /// every step. Default false for max throughput.
73
+ strict_parity: bool,
74
+ }
75
+
76
+ impl SpatialPoolerGpu {
77
+ /// Copy CPU SpatialPooler state onto the device. This preserves the
78
+ /// exact seeded proximal synapse layout + initial permanences, so the
79
+ /// GPU SP is a bit-identical parallel implementation of the CPU SP.
80
+ pub fn from_cpu(cpu: &SpatialPooler) -> Result<Self, DriverError> {
81
+ let dev = CudaDevice::new(0)?;
82
+ let cfg = &cpu.cfg;
83
+ let n = cfg.n_columns;
84
+ let s = cfg.potential_synapses;
85
+
86
+ // Flatten proximal dendrites into column-major arrays.
87
+ let mut syn_bit_h: Vec<u32> = Vec::with_capacity(n * s);
88
+ let mut syn_perm_h: Vec<f32> = Vec::with_capacity(n * s);
89
+ for col in &cpu.columns {
90
+ debug_assert_eq!(col.inputs.len(), s);
91
+ debug_assert_eq!(col.perms.len(), s);
92
+ syn_bit_h.extend_from_slice(&col.inputs);
93
+ syn_perm_h.extend_from_slice(&col.perms);
94
+ }
95
+
96
+ let syn_bit = dev.htod_sync_copy(&syn_bit_h)?;
97
+ let syn_perm = dev.htod_sync_copy(&syn_perm_h)?;
98
+ let boost = dev.htod_sync_copy(&cpu.boost)?;
99
+ let active_duty = dev.htod_sync_copy(&cpu.active_duty_cycle)?;
100
+ let overlap_duty = dev.htod_sync_copy(&cpu.overlap_duty_cycle)?;
101
+
102
+ let inp_dev: CudaSlice<u8> = dev.alloc_zeros(cfg.input_bits)?;
103
+ let raw: CudaSlice<u32> = dev.alloc_zeros(n)?;
104
+ let boosted: CudaSlice<f32> = dev.alloc_zeros(n)?;
105
+ let active_mask: CudaSlice<u8> = dev.alloc_zeros(n)?;
106
+
107
+ // Load PTX modules. Each .ptx is a module containing one `extern "C"`
108
+ // function; we tag them by unique module names so multiple SP instances
109
+ // don't collide (cudarc uses the (module, func) pair).
110
+ // Actually: CudaDevice::load_ptx stores under the given module name
111
+ // globally on the device, so we use a deterministic naming scheme.
112
+ let modules = [
113
+ ("htm_sp_overlap", PTX_SP_OVERLAP, "sp_overlap"),
114
+ ("htm_sp_topk", PTX_SP_TOPK, "sp_topk_select"),
115
+ ("htm_sp_learn", PTX_SP_LEARN, "sp_learn"),
116
+ ("htm_sp_duty", PTX_SP_DUTY, "sp_duty_update"),
117
+ ("htm_sp_boost_fused", PTX_SP_BOOST_FUSED, "sp_boost_from_duty"),
118
+ ];
119
+ for (modname, ptx, fnname) in modules {
120
+ // load_ptx is NOT idempotent — calling twice errors. For multi-region
121
+ // support we check-then-load.
122
+ if dev.get_func(modname, fnname).is_none() {
123
+ dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
124
+ }
125
+ }
126
+
127
+ Ok(Self {
128
+ dev,
129
+ input_bits: cfg.input_bits,
130
+ n_columns: n,
131
+ synapses_per_col: s,
132
+ conn_thr: cfg.connected_threshold,
133
+ inc: cfg.syn_perm_active_inc,
134
+ dec: cfg.syn_perm_inactive_dec,
135
+ sparsity: cfg.sparsity,
136
+ duty_period: cfg.duty_cycle_period,
137
+ boost_strength: cfg.boost_strength,
138
+ syn_bit,
139
+ syn_perm,
140
+ boost,
141
+ active_duty,
142
+ overlap_duty,
143
+ inp_dev,
144
+ raw,
145
+ boosted,
146
+ active_mask,
147
+ host_mask: vec![0u8; n],
148
+ strict_parity: false,
149
+ })
150
+ }
151
+
152
+ /// Enable strict bit-parity mode. Parity tests use this.
153
+ pub fn set_strict_parity(&mut self, strict: bool) {
154
+ self.strict_parity = strict;
155
+ }
156
+
157
+ /// Access to the underlying CudaDevice for host-side orchestration.
158
+ pub fn dev_ref(&self) -> &Arc<CudaDevice> {
159
+ &self.dev
160
+ }
161
+
162
+ // --- Fused-path accessors (immutable state reads + pointer-grabs). ---
163
+ pub fn n_columns_accessor(&self) -> usize { self.n_columns }
164
+ #[allow(dead_code)]
165
+ pub fn input_bits_accessor(&self) -> usize { self.input_bits }
166
+ pub fn synapses_per_col_accessor(&self) -> usize { self.synapses_per_col }
167
+ pub fn conn_thr_accessor(&self) -> f32 { self.conn_thr }
168
+ pub fn inc_accessor(&self) -> f32 { self.inc }
169
+ pub fn dec_accessor(&self) -> f32 { self.dec }
170
+ pub fn sparsity_accessor(&self) -> f32 { self.sparsity }
171
+ pub fn duty_period_accessor(&self) -> f32 { self.duty_period }
172
+ #[allow(dead_code)]
173
+ pub fn boost_strength_accessor(&self) -> f32 { self.boost_strength }
174
+
175
+ pub fn syn_bit_accessor(&self) -> &CudaSlice<u32> { &self.syn_bit }
176
+ pub fn syn_perm_accessor(&self) -> &CudaSlice<f32> { &self.syn_perm }
177
+ pub fn boost_accessor(&self) -> &CudaSlice<f32> { &self.boost }
178
+ pub fn active_duty_accessor(&self) -> &CudaSlice<f32> { &self.active_duty }
179
+
180
+ /// Compute the 95th-percentile-like initial threshold from raw overlaps
181
+ /// after a short warmup pass. Used to seed `inhibition_threshold` such
182
+ /// that activation rate starts near the sparsity target.
183
+ /// Placeholder (returns a conservative constant); real warmup pass
184
+ /// happens on the Rust orchestrator side.
185
+ pub fn initial_threshold_estimate(&self) -> f32 {
186
+ // With conn_thr=0.5, init_perm around 0.5±0.1, S=40, sparse SDR at 2%:
187
+ // expected overlap ~ 40 * 0.02 = 0.8 connected hits → boosted ~ 0.8.
188
+ // Top-K selects top 2%, so threshold for top 2% is roughly the
189
+ // 98th-percentile of boosted. Conservative start: 2.0.
190
+ // The per-column adaptation will quickly steer each column's thr.
191
+ 2.0f32
192
+ }
193
+
194
+ /// Batched multi-step SP on the GPU. Processes T timesteps from a
195
+ /// pre-uploaded device input buffer. Emits `(T, n_cols)` u8 active-column
196
+ /// mask to `cols_dev_out` and `(T,)` active column index list (in a
197
+ /// per-step window of size k, padded with u32::MAX).
198
+ ///
199
+ /// For each step, this runs the same 5-kernel pipeline as `compute`, but
200
+ /// skips the per-step boost/duty D2H→exp→H2D round-trip: instead it
201
+ /// accumulates to a host scratch once every `boost_interval` steps.
202
+ ///
203
+ /// This is the fast path used by `HTMRegionGpu.step_many_gpu`.
204
+ #[allow(clippy::too_many_arguments)]
205
+ pub fn step_batch(
206
+ &mut self,
207
+ inputs_flat_dev: &CudaSlice<u8>,
208
+ t: usize,
209
+ input_bits: usize,
210
+ learn: bool,
211
+ cols_out: &mut [u8],
212
+ active_indices_host: &mut Vec<u32>,
213
+ ) -> Result<(), DriverError> {
214
+ let n = self.n_columns;
215
+ let k = ((self.sparsity * n as f32).round() as usize).max(1);
216
+ debug_assert_eq!(cols_out.len(), t * n);
217
+
218
+ let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
219
+ let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
220
+ let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
221
+ let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
222
+
223
+ let overlap_cfg = LaunchConfig {
224
+ grid_dim: (n as u32, 1, 1),
225
+ block_dim: (128, 1, 1),
226
+ shared_mem_bytes: 0,
227
+ };
228
+ let topk_cfg = LaunchConfig {
229
+ grid_dim: (1, 1, 1),
230
+ block_dim: (256, 1, 1),
231
+ shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
232
+ };
233
+ let learn_cfg = overlap_cfg;
234
+ let duty_cfg = LaunchConfig {
235
+ grid_dim: ((n as u32 + 255) / 256, 1, 1),
236
+ block_dim: (256, 1, 1),
237
+ shared_mem_bytes: 0,
238
+ };
239
+ let alpha = 1.0f32 / self.duty_period.max(1.0);
240
+
241
+ // Reusable host buffer for the per-step active_mask D2H.
242
+ self.host_mask.resize(n, 0);
243
+
244
+ active_indices_host.clear();
245
+
246
+ for ti in 0..t {
247
+ // Point overlap kernel at the ti-th slice of the pre-uploaded input.
248
+ // cudarc CudaSlice doesn't have a "view" per se, so we must copy the
249
+ // slice into the reusable inp_dev buffer. This is a D2D copy — much
250
+ // faster than H2D.
251
+ // (Alternative: rewrite kernel to accept an offset; deferred.)
252
+ let in_off = ti * input_bits;
253
+ // Use dtod_copy via raw slice indexing: cudarc exposes slice() for this.
254
+ let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
255
+ self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
256
+
257
+ // 1. sp_overlap
258
+ unsafe {
259
+ overlap_fn.clone().launch(
260
+ overlap_cfg,
261
+ (
262
+ &self.inp_dev,
263
+ &self.syn_bit,
264
+ &self.syn_perm,
265
+ &self.boost,
266
+ self.conn_thr,
267
+ self.synapses_per_col as u32,
268
+ n as u32,
269
+ &mut self.raw,
270
+ &mut self.boosted,
271
+ ),
272
+ )?;
273
+ }
274
+
275
+ // 2. Clear active_mask, then sp_topk
276
+ self.dev.memset_zeros(&mut self.active_mask)?;
277
+ unsafe {
278
+ topk_fn.clone().launch(
279
+ topk_cfg,
280
+ (&self.boosted, n as u32, k as u32, &mut self.active_mask),
281
+ )?;
282
+ }
283
+
284
+ // 3. sp_learn
285
+ if learn {
286
+ unsafe {
287
+ learn_fn.clone().launch(
288
+ learn_cfg,
289
+ (
290
+ &self.active_mask,
291
+ &self.inp_dev,
292
+ &self.syn_bit,
293
+ &mut self.syn_perm,
294
+ self.inc,
295
+ self.dec,
296
+ self.synapses_per_col as u32,
297
+ n as u32,
298
+ ),
299
+ )?;
300
+ }
301
+ }
302
+
303
+ // 4. duty update (device)
304
+ unsafe {
305
+ duty_fn.clone().launch(
306
+ duty_cfg,
307
+ (
308
+ &self.active_mask,
309
+ &self.raw,
310
+ &mut self.active_duty,
311
+ &mut self.overlap_duty,
312
+ &mut self.boost,
313
+ alpha,
314
+ 1.0f32,
315
+ 0.0f32,
316
+ 0.0f32,
317
+ 0u32,
318
+ n as u32,
319
+ ),
320
+ )?;
321
+ }
322
+
323
+ // 5. Boost update. Two modes:
324
+ // * strict_parity (tests): host-side exp for bit-exact match.
325
+ // * default (production): GPU expf is close enough and ~10x faster
326
+ // since we skip the D2H/H2D round-trip.
327
+ if learn && self.boost_strength > 0.0 {
328
+ if self.strict_parity {
329
+ let mut duty_host = vec![0f32; n];
330
+ self.dev
331
+ .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
332
+ let sum: f32 = duty_host.iter().sum();
333
+ let mean = sum / (n as f32);
334
+ let mut boost_host = vec![0f32; n];
335
+ for i in 0..n {
336
+ boost_host[i] =
337
+ (-self.boost_strength * (duty_host[i] - mean)).exp();
338
+ }
339
+ self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
340
+
341
+ // Permanence bump (rare). Only evaluated in strict mode.
342
+ let mut ov_host = vec![0f32; n];
343
+ self.dev
344
+ .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
345
+ let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
346
+ if max_ov > 0.0 {
347
+ let thr = 0.001f32 * max_ov;
348
+ let bump = self.inc * 0.1f32;
349
+ let bump_cols: Vec<u32> = ov_host
350
+ .iter()
351
+ .enumerate()
352
+ .filter_map(|(i, &o)| {
353
+ if o < thr { Some(i as u32) } else { None }
354
+ })
355
+ .collect();
356
+ if !bump_cols.is_empty() {
357
+ let s = self.synapses_per_col;
358
+ let mut perm_host = vec![0f32; n * s];
359
+ self.dev
360
+ .dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
361
+ for &c in &bump_cols {
362
+ let base = (c as usize) * s;
363
+ for p in &mut perm_host[base..base + s] {
364
+ *p = (*p + bump).min(1.0);
365
+ }
366
+ }
367
+ self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
368
+ }
369
+ }
370
+ } else {
371
+ // Fast path: fused mean + boost = expf(-strength*(ad-mean))
372
+ // in a single GPU block. Zero D2H, zero H2D — fully async.
373
+ let boost_fn = self
374
+ .dev
375
+ .get_func("htm_sp_boost_fused", "sp_boost_from_duty")
376
+ .expect("sp_boost_fused not loaded");
377
+ let boost_cfg = LaunchConfig {
378
+ grid_dim: (1, 1, 1),
379
+ block_dim: (1024, 1, 1),
380
+ shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
381
+ };
382
+ unsafe {
383
+ boost_fn.launch(
384
+ boost_cfg,
385
+ (
386
+ &self.active_duty,
387
+ &mut self.boost,
388
+ self.boost_strength,
389
+ n as u32,
390
+ ),
391
+ )?;
392
+ }
393
+ }
394
+ }
395
+
396
+ // D2H the active_mask for this step. This is the single
397
+ // unavoidable sync point per step — CPU TM needs the active
398
+ // indices for its next state update. At 2048 bytes / step this
399
+ // is tiny in bandwidth but costs a full syncronize (~5-10μs).
400
+ self.dev
401
+ .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
402
+ let co = ti * n;
403
+ cols_out[co..co + n].copy_from_slice(&self.host_mask);
404
+ // Extract active indices.
405
+ for (i, &b) in self.host_mask.iter().enumerate() {
406
+ if b != 0 {
407
+ active_indices_host.push(i as u32);
408
+ }
409
+ }
410
+ // Insert separator (u32::MAX) between steps to demarcate step boundaries.
411
+ active_indices_host.push(u32::MAX);
412
+ }
413
+
414
+ Ok(())
415
+ }
416
+
417
+ /// Fully-on-GPU batched SP + TM. Zero per-step host sync.
418
+ ///
419
+ /// Inputs:
420
+ /// inputs_flat_dev : (T * input_bits) u8 already uploaded
421
+ /// cols_dev : (T * n_cols) u8 output — active-column mask per step
422
+ /// anom_dev : (T,) f32 output — anomaly score per step
423
+ /// tm : persistent GPU TemporalMemory for this region
424
+ #[allow(clippy::too_many_arguments)]
425
+ pub fn step_batch_with_tm(
426
+ &mut self,
427
+ inputs_flat_dev: &CudaSlice<u8>,
428
+ t: usize,
429
+ input_bits: usize,
430
+ learn: bool,
431
+ cols_dev: &mut CudaSlice<u8>,
432
+ anom_dev: &mut CudaSlice<f32>,
433
+ tm: &mut crate::gpu::tm_gpu::TemporalMemoryGpu,
434
+ ) -> Result<(), DriverError> {
435
+ let n = self.n_columns;
436
+ let k = ((self.sparsity * n as f32).round() as usize).max(1);
437
+ debug_assert_eq!(cols_dev.len(), t * n);
438
+ debug_assert_eq!(anom_dev.len(), t);
439
+
440
+ let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
441
+ let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
442
+ let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
443
+ let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
444
+
445
+ let overlap_cfg = LaunchConfig {
446
+ grid_dim: (n as u32, 1, 1),
447
+ block_dim: (128, 1, 1),
448
+ shared_mem_bytes: 0,
449
+ };
450
+ let topk_cfg = LaunchConfig {
451
+ grid_dim: (1, 1, 1),
452
+ block_dim: (256, 1, 1),
453
+ shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
454
+ };
455
+ let learn_cfg = overlap_cfg;
456
+ let duty_cfg = LaunchConfig {
457
+ grid_dim: ((n as u32 + 255) / 256, 1, 1),
458
+ block_dim: (256, 1, 1),
459
+ shared_mem_bytes: 0,
460
+ };
461
+ let alpha = 1.0f32 / self.duty_period.max(1.0);
462
+
463
+ for ti in 0..t {
464
+ let in_off = ti * input_bits;
465
+ let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
466
+ self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
467
+
468
+ // 1. sp_overlap
469
+ unsafe {
470
+ overlap_fn.clone().launch(
471
+ overlap_cfg,
472
+ (
473
+ &self.inp_dev,
474
+ &self.syn_bit,
475
+ &self.syn_perm,
476
+ &self.boost,
477
+ self.conn_thr,
478
+ self.synapses_per_col as u32,
479
+ n as u32,
480
+ &mut self.raw,
481
+ &mut self.boosted,
482
+ ),
483
+ )?;
484
+ }
485
+
486
+ // 2. clear + sp_topk
487
+ self.dev.memset_zeros(&mut self.active_mask)?;
488
+ unsafe {
489
+ topk_fn.clone().launch(
490
+ topk_cfg,
491
+ (&self.boosted, n as u32, k as u32, &mut self.active_mask),
492
+ )?;
493
+ }
494
+
495
+ // 3. sp_learn
496
+ if learn {
497
+ unsafe {
498
+ learn_fn.clone().launch(
499
+ learn_cfg,
500
+ (
501
+ &self.active_mask,
502
+ &self.inp_dev,
503
+ &self.syn_bit,
504
+ &mut self.syn_perm,
505
+ self.inc,
506
+ self.dec,
507
+ self.synapses_per_col as u32,
508
+ n as u32,
509
+ ),
510
+ )?;
511
+ }
512
+ }
513
+
514
+ // 4. duty update (stage 1: no-boost write)
515
+ unsafe {
516
+ duty_fn.clone().launch(
517
+ duty_cfg,
518
+ (
519
+ &self.active_mask,
520
+ &self.raw,
521
+ &mut self.active_duty,
522
+ &mut self.overlap_duty,
523
+ &mut self.boost,
524
+ alpha,
525
+ 1.0f32,
526
+ 0.0f32,
527
+ 0.0f32,
528
+ 0u32,
529
+ n as u32,
530
+ ),
531
+ )?;
532
+ }
533
+
534
+ // 5. Boost update: fused GPU kernel (no D2H).
535
+ if learn && self.boost_strength > 0.0 {
536
+ let boost_fn = self.dev
537
+ .get_func("htm_sp_boost_fused", "sp_boost_from_duty")
538
+ .expect("sp_boost_fused not loaded");
539
+ let boost_cfg = LaunchConfig {
540
+ grid_dim: (1, 1, 1),
541
+ block_dim: (1024, 1, 1),
542
+ shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
543
+ };
544
+ unsafe {
545
+ boost_fn.launch(
546
+ boost_cfg,
547
+ (
548
+ &self.active_duty,
549
+ &mut self.boost,
550
+ self.boost_strength,
551
+ n as u32,
552
+ ),
553
+ )?;
554
+ }
555
+ }
556
+
557
+ // 6. Copy active_mask slice into cols_dev[ti*n .. (ti+1)*n].
558
+ let mut dst_slice = cols_dev.slice_mut(ti * n..(ti + 1) * n);
559
+ self.dev.dtod_copy(&self.active_mask, &mut dst_slice)?;
560
+
561
+ // 7. GPU TM step: predict + activate + anomaly + learn, all on device.
562
+ tm.step(&self.active_mask, anom_dev, ti as u32, learn)?;
563
+ }
564
+
565
+ Ok(())
566
+ }
567
+
568
+ /// One SP step on the GPU. Returns sorted active-column indices.
569
+ pub fn compute(&mut self, input: &[u8], learn: bool) -> Result<Vec<u32>, DriverError> {
570
+ debug_assert_eq!(input.len(), self.input_bits);
571
+ let n = self.n_columns;
572
+ let k = ((self.sparsity * n as f32).round() as usize).max(1);
573
+
574
+ // 1. H2D input SDR.
575
+ self.dev.htod_sync_copy_into(input, &mut self.inp_dev)?;
576
+
577
+ // 2. Launch sp_overlap: grid=n_columns, block=128.
578
+ let overlap_fn = self
579
+ .dev
580
+ .get_func("htm_sp_overlap", "sp_overlap")
581
+ .expect("sp_overlap not loaded");
582
+ let overlap_cfg = LaunchConfig {
583
+ grid_dim: (n as u32, 1, 1),
584
+ block_dim: (128, 1, 1),
585
+ shared_mem_bytes: 0,
586
+ };
587
+ unsafe {
588
+ overlap_fn.launch(
589
+ overlap_cfg,
590
+ (
591
+ &self.inp_dev,
592
+ &self.syn_bit,
593
+ &self.syn_perm,
594
+ &self.boost,
595
+ self.conn_thr,
596
+ self.synapses_per_col as u32,
597
+ n as u32,
598
+ &mut self.raw,
599
+ &mut self.boosted,
600
+ ),
601
+ )?;
602
+ }
603
+
604
+ // 3. Launch sp_topk: single block, shared mem = n_columns * f32.
605
+ let topk_fn = self
606
+ .dev
607
+ .get_func("htm_sp_topk", "sp_topk_select")
608
+ .expect("sp_topk not loaded");
609
+ let topk_cfg = LaunchConfig {
610
+ grid_dim: (1, 1, 1),
611
+ block_dim: (256, 1, 1),
612
+ shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
613
+ };
614
+ // Clear active_mask first. memset_zeros avoids an H2D of a host
615
+ // zeroes vector every step.
616
+ self.dev.memset_zeros(&mut self.active_mask)?;
617
+ unsafe {
618
+ topk_fn.launch(
619
+ topk_cfg,
620
+ (
621
+ &self.boosted,
622
+ n as u32,
623
+ k as u32,
624
+ &mut self.active_mask,
625
+ ),
626
+ )?;
627
+ }
628
+
629
+ // 4. Optional: sp_learn on active columns.
630
+ if learn {
631
+ let learn_fn = self
632
+ .dev
633
+ .get_func("htm_sp_learn", "sp_learn")
634
+ .expect("sp_learn not loaded");
635
+ let learn_cfg = LaunchConfig {
636
+ grid_dim: (n as u32, 1, 1),
637
+ block_dim: (128, 1, 1),
638
+ shared_mem_bytes: 0,
639
+ };
640
+ unsafe {
641
+ learn_fn.launch(
642
+ learn_cfg,
643
+ (
644
+ &self.active_mask,
645
+ &self.inp_dev,
646
+ &self.syn_bit,
647
+ &mut self.syn_perm,
648
+ self.inc,
649
+ self.dec,
650
+ self.synapses_per_col as u32,
651
+ n as u32,
652
+ ),
653
+ )?;
654
+ }
655
+ }
656
+
657
+ // 5. Duty cycle + boost update. Always runs (matches CPU).
658
+ // We need mean_duty on the host — compute BEFORE the update (matches
659
+ // CPU sp.rs line 200-205 where mean is computed then written).
660
+ // Actually CPU computes mean of the PRE-update duty cycles too? Re-read:
661
+ // sp.rs lines 186-196 update duty cycles (pre-mean).
662
+ // Line 202: mean = sum(active_duty_cycle) / n ← after update.
663
+ // Line 204: boost[i] = exp(-strength*(active_duty[i] - mean)).
664
+ // So mean is on POST-update values.
665
+ // Easiest: 1) run duty update with boost_strength=0 (skip boost calc),
666
+ // 2) D2H active_duty, compute mean, 3) run a boost-only kernel
667
+ // OR inline the exp() in a second launch with mean passed.
668
+ //
669
+ // For simplicity and correctness we fuse: run the duty kernel with
670
+ // mean=0 and boost_strength=0 (disables boost write), then D2H to
671
+ // compute mean, then re-launch with the true mean. Two launches, one
672
+ // tiny D2H (n × f32). At n=2048 this is 8KB per step — negligible.
673
+ let alpha = 1.0f32 / self.duty_period.max(1.0);
674
+ let duty_fn = self
675
+ .dev
676
+ .get_func("htm_sp_duty", "sp_duty_update")
677
+ .expect("sp_duty not loaded");
678
+ let duty_cfg = LaunchConfig {
679
+ grid_dim: ((n as u32 + 255) / 256, 1, 1),
680
+ block_dim: (256, 1, 1),
681
+ shared_mem_bytes: 0,
682
+ };
683
+ // Stage 1: update duty cycles (boost_strength=0 -> no write).
684
+ unsafe {
685
+ duty_fn.launch(
686
+ duty_cfg,
687
+ (
688
+ &self.active_mask,
689
+ &self.raw,
690
+ &mut self.active_duty,
691
+ &mut self.overlap_duty,
692
+ &mut self.boost,
693
+ alpha,
694
+ 1.0f32, // stim_thr
695
+ 0.0f32, // boost_strength = 0 -> skip write
696
+ 0.0f32, // mean_duty (unused)
697
+ 0u32, // learn_flag = 0
698
+ n as u32,
699
+ ),
700
+ )?;
701
+ }
702
+
703
+ if learn && self.boost_strength > 0.0 && self.strict_parity {
704
+ // Boost update must bit-match CPU `f32::exp`, so we compute it on
705
+ // the host and copy back. Cost per step: 8KB D2H + 8KB H2D at n=2048.
706
+ // Critical for learning parity — CUDA expf (even without fast-math)
707
+ // uses different rounding for some inputs than host libm.
708
+ let mut duty_host = vec![0f32; n];
709
+ self.dev
710
+ .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
711
+ let sum: f32 = duty_host.iter().sum();
712
+ let mean = sum / (n as f32);
713
+ let mut boost_host = vec![0f32; n];
714
+ for i in 0..n {
715
+ boost_host[i] = (-self.boost_strength * (duty_host[i] - mean)).exp();
716
+ }
717
+ self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
718
+
719
+ // CPU sp.rs 210-226: permanence bump for chronically under-stimulated
720
+ // columns. If overlap_duty_cycle[i] < 0.001 * max(overlap_duty_cycle),
721
+ // add inc*0.1 to every synapse of column i (clamped to 1.0).
722
+ // This runs only once per step and only for the rare cases, but we
723
+ // need it for bit-exact parity with CPU learn.
724
+ let mut ov_host = vec![0f32; n];
725
+ self.dev
726
+ .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
727
+ let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
728
+ if max_ov > 0.0 {
729
+ let thr = 0.001f32 * max_ov;
730
+ let bump = self.inc * 0.1f32;
731
+ // Find columns needing a bump. Usually empty. Rare → D2H/H2D
732
+ // of syn_perm is cheap (n*S*4 = 320KB at n=2048,S=40).
733
+ let bump_cols: Vec<u32> = ov_host
734
+ .iter()
735
+ .enumerate()
736
+ .filter_map(|(i, &o)| if o < thr { Some(i as u32) } else { None })
737
+ .collect();
738
+ if !bump_cols.is_empty() {
739
+ // Download, bump, upload. (Keeps implementation simple and
740
+ // bit-exact. Could kernelize later.)
741
+ let s = self.synapses_per_col;
742
+ let mut perm_host = vec![0f32; n * s];
743
+ self.dev.dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
744
+ for &c in &bump_cols {
745
+ let base = (c as usize) * s;
746
+ for p in &mut perm_host[base..base + s] {
747
+ *p = (*p + bump).min(1.0);
748
+ }
749
+ }
750
+ self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
751
+ }
752
+ }
753
+ } else if learn && self.boost_strength > 0.0 {
754
+ // Fast path: GPU-side boost using the already-loaded duty kernel.
755
+ let mut duty_host = vec![0f32; n];
756
+ self.dev
757
+ .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
758
+ let sum: f32 = duty_host.iter().sum();
759
+ let mean = sum / (n as f32);
760
+ let boost_fn = self
761
+ .dev
762
+ .get_func("htm_sp_duty", "sp_duty_update")
763
+ .expect("sp_duty not loaded");
764
+ unsafe {
765
+ boost_fn.launch(
766
+ duty_cfg,
767
+ (
768
+ &self.active_mask,
769
+ &self.raw,
770
+ &mut self.active_duty,
771
+ &mut self.overlap_duty,
772
+ &mut self.boost,
773
+ 0.0f32,
774
+ 1.0f32,
775
+ self.boost_strength,
776
+ mean,
777
+ 1u32,
778
+ n as u32,
779
+ ),
780
+ )?;
781
+ }
782
+ }
783
+
784
+ // 6. D2H active_mask and convert to sorted index list.
785
+ self.dev
786
+ .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
787
+ let mut active: Vec<u32> = Vec::with_capacity(k);
788
+ for (i, &b) in self.host_mask.iter().enumerate() {
789
+ if b != 0 {
790
+ active.push(i as u32);
791
+ }
792
+ }
793
+ debug_assert_eq!(active.len(), k, "SP must emit exactly k winners");
794
+ Ok(active)
795
+ }
796
+ }
overlay/htm_rust/src/gpu/tests.rs ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //! Parity tests: GPU SP vs CPU SP reference.
2
+ //!
3
+ //! With matching seeds the two should produce bit-identical active-column sets
4
+ //! when `learn=false`, and remain bit-identical over repeated `learn=true`
5
+ //! steps because the Hebbian update is deterministic (no RNG once initialised).
6
+ //!
7
+ //! Run with: cargo test --release --features gpu
8
+
9
+ #![cfg(test)]
10
+ #![cfg(feature = "gpu")]
11
+
12
+ use crate::sp::{SpatialPooler, SpatialPoolerConfig};
13
+ use crate::gpu::sp_gpu::SpatialPoolerGpu;
14
+ use crate::gpu::tm_gpu::TemporalMemoryGpu;
15
+ use crate::gpu::fused::{
16
+ launch_fused, plan_fused_launch, FusedState,
17
+ };
18
+ use cudarc::driver::CudaSlice;
19
+ use rand::{Rng, SeedableRng};
20
+ use rand_xoshiro::Xoshiro256PlusPlus;
21
+
22
+ fn make_sdr(rng: &mut Xoshiro256PlusPlus, bits: usize, sparsity: f32) -> Vec<u8> {
23
+ let on = ((sparsity * bits as f32) as usize).max(1);
24
+ let mut v = vec![0u8; bits];
25
+ let mut placed = 0;
26
+ while placed < on {
27
+ let i = rng.gen_range(0..bits);
28
+ if v[i] == 0 {
29
+ v[i] = 1;
30
+ placed += 1;
31
+ }
32
+ }
33
+ v
34
+ }
35
+
36
+ #[test]
37
+ fn gpu_sp_matches_cpu_no_learn() {
38
+ let cfg = SpatialPoolerConfig::default();
39
+ let bits = cfg.input_bits;
40
+ let mut cpu = SpatialPooler::new(
41
+ SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
42
+ 1234,
43
+ );
44
+ let cpu_for_gpu = SpatialPooler::new(
45
+ SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
46
+ 1234,
47
+ );
48
+ let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_for_gpu)
49
+ .expect("gpu init (CUDA device available)");
50
+ gpu.set_strict_parity(true);
51
+
52
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
53
+ for step in 0..20 {
54
+ let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
55
+ let sdr_bool: Vec<bool> = sdr_u8.iter().map(|&x| x != 0).collect();
56
+
57
+ let cpu_active: Vec<u32> = cpu.compute(&sdr_bool, false);
58
+ let gpu_active: Vec<u32> = gpu.compute(&sdr_u8, false).expect("gpu compute");
59
+
60
+ assert_eq!(
61
+ cpu_active, gpu_active,
62
+ "mismatch at step {step}: len cpu={} gpu={}",
63
+ cpu_active.len(), gpu_active.len()
64
+ );
65
+ }
66
+ }
67
+
68
+ #[test]
69
+ fn gpu_sp_matches_cpu_with_learn() {
70
+ let cfg = SpatialPoolerConfig::default();
71
+ let bits = cfg.input_bits;
72
+ let mut cpu = SpatialPooler::new(
73
+ SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
74
+ 5678,
75
+ );
76
+ let cpu_for_gpu = SpatialPooler::new(
77
+ SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
78
+ 5678,
79
+ );
80
+ let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
81
+ gpu.set_strict_parity(true);
82
+
83
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(42);
84
+ for step in 0..50 {
85
+ let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
86
+ let sdr_bool: Vec<bool> = sdr_u8.iter().map(|&x| x != 0).collect();
87
+
88
+ let cpu_active = cpu.compute(&sdr_bool, true);
89
+ let gpu_active = gpu.compute(&sdr_u8, true).expect("gpu compute");
90
+
91
+ assert_eq!(
92
+ cpu_active, gpu_active,
93
+ "mismatch at step {step} with learning"
94
+ );
95
+ }
96
+ }
97
+
98
+ #[test]
99
+ fn gpu_tm_anomaly_decays_on_repeating_sequence() {
100
+ // End-to-end GPU pipeline: SP feeds TM; repeating SDR sequence should drive
101
+ // anomaly down over time.
102
+ use crate::gpu::HTMRegionGpu; // not pyclass methods; use internal constructor via Rust
103
+ // Easier: replicate the pipeline directly with SP + TM.
104
+
105
+ let cfg = SpatialPoolerConfig::default();
106
+ let bits = cfg.input_bits;
107
+ let n_cols = cfg.n_columns;
108
+ let cells_per_col = 32usize;
109
+
110
+ let cpu_for_gpu = SpatialPooler::new(SpatialPoolerConfig::default(), 314);
111
+ let mut sp = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
112
+ let dev = sp.dev_ref().clone();
113
+ let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col)
114
+ .expect("gpu tm init");
115
+ tm.reset().expect("tm reset");
116
+
117
+ // Build 3 fixed SDRs, feed them in a repeating sequence.
118
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
119
+ let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
120
+ let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
121
+
122
+ // Warm up SP so columns are stable per symbol.
123
+ for _ in 0..100 {
124
+ for s in &seqs {
125
+ let _ = sp.compute(s, true).expect("sp compute");
126
+ }
127
+ }
128
+
129
+ // Build a long input buffer: 100 repetitions of [A,B,C] = 300 steps.
130
+ let repeats = 100usize;
131
+ let t = repeats * 3;
132
+ let mut inputs_flat = vec![0u8; t * bits];
133
+ for r in 0..repeats {
134
+ for (i, s) in seqs.iter().enumerate() {
135
+ let off = (r * 3 + i) * bits;
136
+ inputs_flat[off..off + bits].copy_from_slice(s);
137
+ }
138
+ }
139
+ let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod");
140
+
141
+ let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc cols");
142
+ let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");
143
+
144
+ sp.step_batch_with_tm(
145
+ &inputs_dev,
146
+ t,
147
+ bits,
148
+ true,
149
+ &mut cols_dev,
150
+ &mut anom_dev,
151
+ &mut tm,
152
+ ).expect("step_batch_with_tm");
153
+
154
+ let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
155
+ let cols: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
156
+
157
+ // Active column count per step must equal k for every step.
158
+ let k = ((cfg.sparsity * n_cols as f32).round() as usize).max(1);
159
+ for ti in 0..t {
160
+ let step_slice = &cols[ti * n_cols..(ti + 1) * n_cols];
161
+ let n_on = step_slice.iter().filter(|&&b| b != 0).count();
162
+ assert_eq!(n_on, k, "step {ti} has {n_on} active cols, expected {k}");
163
+ }
164
+
165
+ // First repetition: anomaly should be near 1.0 (nothing predicted).
166
+ let early_avg: f32 = anom[3..9].iter().sum::<f32>() / 6.0;
167
+ // Last repetitions: anomaly should be noticeably lower.
168
+ let late_avg: f32 = anom[(t - 9)..t].iter().sum::<f32>() / 9.0;
169
+ eprintln!("gpu tm: early anomaly = {early_avg:.3}, late = {late_avg:.3}");
170
+ assert!(
171
+ late_avg < early_avg,
172
+ "GPU TM should reduce anomaly on repeating sequence: early={early_avg:.3}, late={late_avg:.3}"
173
+ );
174
+ }
175
+
176
+ /// Cluster-sync smoke test: verifies that the fused megakernel (which relies on
177
+ /// hardware `cluster::sync()` / grid-barrier on H100/H200 Hopper) completes
178
+ /// without deadlock when called with real HTM state, and that output shapes are
179
+ /// sane (no NaN / Inf in anomaly scores, active-column count in plausible range).
180
+ ///
181
+ /// This is an *integration* test, not a synthetic micro-benchmark: it exercises
182
+ /// exactly the same `launch_fused` code path used in production, so any
183
+ /// deadlock in the cooperative-grid or DLB barrier would surface here.
184
+ ///
185
+ /// Skips gracefully (with an eprintln) when no GPU is available — the test
186
+ /// binary returns exit-code 0 in that case so CI still passes.
187
+ #[test]
188
+ fn cluster_sync_smoke_test() {
189
+ // Build a tiny HTM region (1024 inputs, 256 columns, 4 cells/column).
190
+ // This keeps VRAM usage minimal while still exercising all kernel paths.
191
+ let input_bits = 1024usize;
192
+ let n_columns = 256usize;
193
+ let cells_per_col = 4usize;
194
+
195
+ // Probe cooperative launch attribute before doing any real work.
196
+ // CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH = 223 (added in CUDA 11.8 for Hopper).
197
+ // cudarc exposes raw attribute querying; we check cooperative launch (98)
198
+ // as the guard — cluster launch is a superset and not separately probed
199
+ // here since cudarc doesn't expose attribute 223 symbolically yet.
200
+ // On pre-Hopper hardware the DLB barrier path is used instead and the
201
+ // test still validates no deadlock on that path.
202
+
203
+ let make_cfg = || SpatialPoolerConfig {
204
+ input_bits,
205
+ n_columns,
206
+ sparsity: 0.04, // ~10 active cols out of 256
207
+ ..SpatialPoolerConfig::default()
208
+ };
209
+
210
+ let cpu_ref = SpatialPooler::new(make_cfg(), 42);
211
+
212
+ let mut sp = match SpatialPoolerGpu::from_cpu(&cpu_ref) {
213
+ Ok(sp) => sp,
214
+ Err(e) => {
215
+ eprintln!("[cluster_sync_smoke_test] No GPU available ({e:?}) — skipping");
216
+ return;
217
+ }
218
+ };
219
+
220
+ let dev = sp.dev_ref().clone();
221
+
222
+ // Check cooperative launch support; skip with a clear message if absent.
223
+ let cooperative_ok = matches!(
224
+ dev.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
225
+ Ok(v) if v > 0
226
+ );
227
+ if !cooperative_ok {
228
+ eprintln!("[cluster_sync_smoke_test] CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH=0 — DLB path only, still running test");
229
+ // We continue — the DLB path is the production fallback and must not deadlock either.
230
+ }
231
+
232
+ let mut tm = match TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_col) {
233
+ Ok(tm) => tm,
234
+ Err(e) => {
235
+ eprintln!("[cluster_sync_smoke_test] TemporalMemoryGpu::new failed ({e:?}) — skipping");
236
+ return;
237
+ }
238
+ };
239
+ tm.reset().expect("tm reset");
240
+
241
+ let mut fused_st: FusedState = match FusedState::new(
242
+ dev.clone(),
243
+ n_columns,
244
+ cells_per_col,
245
+ sp.initial_threshold_estimate(),
246
+ ) {
247
+ Ok(f) => f,
248
+ Err(e) => {
249
+ eprintln!("[cluster_sync_smoke_test] FusedState::new failed ({e:?}) — skipping");
250
+ return;
251
+ }
252
+ };
253
+ fused_st.reset().expect("fused reset");
254
+
255
+ // Build T=4 timesteps of all-zero input SDRs.
256
+ let t = 4usize;
257
+ let inputs_flat = vec![0u8; t * input_bits];
258
+ let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod inputs");
259
+
260
+ let mut cols_dev = dev.alloc_zeros::<u8>(t * n_columns).expect("alloc cols");
261
+ let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");
262
+
263
+ // Execute with a 2-second timeout guard via a thread. If the kernel
264
+ // deadlocks, the parent test process times out and the CI job reports
265
+ // failure — we can't cancel a live CUDA kernel from Rust, but the
266
+ // launch_fused call itself must return within this window on any sane GPU.
267
+ //
268
+ // We run the kernel inline (not in a separate thread) because CUDA contexts
269
+ // are not safely shareable across threads without explicit multi-threading
270
+ // setup. The 2-second bound is enforced implicitly: if the kernel deadlocks,
271
+ // the test binary will hang and the CI timeout (typically 5 min) will kill it.
272
+ // For local dev, the deadlock would be immediately obvious.
273
+
274
+ launch_fused(
275
+ &mut sp,
276
+ &mut tm,
277
+ &mut fused_st,
278
+ &inputs_dev,
279
+ &mut cols_dev,
280
+ &mut anom_dev,
281
+ t,
282
+ input_bits,
283
+ false, // learn=false for determinism
284
+ ).expect("launch_fused (cluster_sync_smoke_test): deadlock or CUDA error");
285
+
286
+ dev.synchronize().expect("device sync after launch_fused");
287
+
288
+ // --- Correctness assertions ---
289
+
290
+ let cols_host: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
291
+ let anom_host: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
292
+
293
+ // Output buffers must be exactly the right size.
294
+ assert_eq!(cols_host.len(), t * n_columns, "cols buffer size mismatch");
295
+ assert_eq!(anom_host.len(), t, "anom buffer size mismatch");
296
+
297
+ // Anomaly scores must be finite (NaN/Inf indicates numerical blow-up).
298
+ for (i, &a) in anom_host.iter().enumerate() {
299
+ assert!(a.is_finite(), "anomaly[{i}] is not finite: {a}");
300
+ assert!(a >= 0.0 && a <= 1.0, "anomaly[{i}] out of [0,1]: {a}");
301
+ }
302
+
303
+ // Active-column count per step: threshold-based inhibition, so 0 is
304
+ // possible on cold start (before thresholds calibrate), but we assert
305
+ // <= n_columns to catch buffer overruns or completely wrong output.
306
+ for ti in 0..t {
307
+ let n_on = cols_host[ti * n_columns..(ti + 1) * n_columns]
308
+ .iter()
309
+ .filter(|&&b| b != 0)
310
+ .count();
311
+ assert!(
312
+ n_on <= n_columns,
313
+ "step {ti}: active columns {n_on} > n_columns {n_columns} (buffer overrun?)"
314
+ );
315
+ }
316
+
317
+ eprintln!(
318
+ "[cluster_sync_smoke_test] PASSED: T={t}, n_cols={n_columns}, \
319
+ input_bits={input_bits}, cooperative_supported={cooperative_ok}, \
320
+ anom={anom_host:?}"
321
+ );
322
+ }
323
+
324
+ /// Parity check: the CAI zero-copy path (`step_many_cuda`) must produce
325
+ /// bit-identical outputs to the numpy H2D/D2H path (`step_batch_with_tm`),
326
+ /// since the kernel pipeline is the same — only the I/O wrapping changes.
327
+ /// We skip the PyO3 CAI dict plumbing here and test the underlying
328
+ /// ManuallyDrop + upgrade_device_ptr pattern directly.
329
+ #[test]
330
+ fn gpu_cuda_vs_numpy_parity() {
331
+ use std::mem::ManuallyDrop;
332
+
333
+ let cfg = SpatialPoolerConfig::default();
334
+ let bits = cfg.input_bits;
335
+ let n_cols = cfg.n_columns;
336
+ let cells_per_col = 32usize;
337
+
338
+ // Build two identical (SP, TM) pairs from the same seed.
339
+ let build = || -> (SpatialPoolerGpu, TemporalMemoryGpu) {
340
+ let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271828);
341
+ let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu init");
342
+ let dev = sp.dev_ref().clone();
343
+ let mut tm = TemporalMemoryGpu::new(dev, n_cols, cells_per_col).expect("tm init");
344
+ tm.reset().expect("tm reset");
345
+ (sp, tm)
346
+ };
347
+
348
+ // Deterministic SDR sequence.
349
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
350
+ let t = 32usize;
351
+ let mut inputs_flat = vec![0u8; t * bits];
352
+ for i in 0..t {
353
+ let sdr = make_sdr(&mut rng, bits, 0.02);
354
+ inputs_flat[i * bits..(i + 1) * bits].copy_from_slice(&sdr);
355
+ }
356
+
357
+ // ---- Path A: owned CudaSlice (numpy-equivalent path) ----
358
+ let (mut sp_a, mut tm_a) = build();
359
+ let dev_a = sp_a.dev_ref().clone();
360
+ let inputs_a: CudaSlice<u8> = dev_a.htod_sync_copy(&inputs_flat).expect("htod");
361
+ let mut cols_a = dev_a.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_a");
362
+ let mut anom_a = dev_a.alloc_zeros::<f32>(t).expect("alloc anom_a");
363
+ sp_a.step_batch_with_tm(&inputs_a, t, bits, false, &mut cols_a, &mut anom_a, &mut tm_a)
364
+ .expect("owned step_batch_with_tm");
365
+ dev_a.synchronize().expect("sync a");
366
+ let cols_a_host: Vec<u8> = dev_a.dtoh_sync_copy(&cols_a).expect("d2h cols_a");
367
+ let anom_a_host: Vec<f32> = dev_a.dtoh_sync_copy(&anom_a).expect("d2h anom_a");
368
+
369
+ // ---- Path B: borrowed device pointers via upgrade_device_ptr ----
370
+ // We allocate fresh owned CudaSlices on a fresh device, then take their
371
+ // raw ptrs and re-wrap as ManuallyDrop borrowed views — mimicking what
372
+ // `step_many_cuda` does with torch-owned CUDA memory.
373
+ let (mut sp_b, mut tm_b) = build();
374
+ let dev_b = sp_b.dev_ref().clone();
375
+ let inputs_b_owned: CudaSlice<u8> = dev_b.htod_sync_copy(&inputs_flat).expect("htod");
376
+ let cols_b_owned = dev_b.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_b");
377
+ let anom_b_owned = dev_b.alloc_zeros::<f32>(t).expect("alloc anom_b");
378
+
379
+ // Extract raw CUdeviceptrs (and leak the owners so their Drop doesn't free).
380
+ let inputs_ptr = inputs_b_owned.leak();
381
+ let cols_ptr = cols_b_owned.leak();
382
+ let anom_ptr = anom_b_owned.leak();
383
+
384
+ // Re-wrap as borrowed views.
385
+ let inputs_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) });
386
+ let mut cols_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) });
387
+ let mut anom_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) });
388
+
389
+ sp_b.step_batch_with_tm(&inputs_b, t, bits, false, &mut cols_b, &mut anom_b, &mut tm_b)
390
+ .expect("borrowed step_batch_with_tm");
391
+ dev_b.synchronize().expect("sync b");
392
+ // `ManuallyDrop` doesn't auto-coerce to `&CudaSlice<T>` for the DevicePtr
393
+ // trait bound on `dtoh_sync_copy`; explicit deref.
394
+ let cols_b_host: Vec<u8> = dev_b.dtoh_sync_copy(&*cols_b).expect("d2h cols_b");
395
+ let anom_b_host: Vec<f32> = dev_b.dtoh_sync_copy(&*anom_b).expect("d2h anom_b");
396
+
397
+ // Re-own so Drop actually frees (we leaked above).
398
+ let _inputs_owned_again = unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) };
399
+ let _cols_owned_again = unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) };
400
+ let _anom_owned_again = unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) };
401
+
402
+ assert_eq!(cols_a_host, cols_b_host, "active-column mask diverges between numpy and CAI paths");
403
+ assert_eq!(anom_a_host.len(), anom_b_host.len());
404
+ for (i, (a, b)) in anom_a_host.iter().zip(anom_b_host.iter()).enumerate() {
405
+ // Anomaly is a pure division of integer counts — bit-exact expected.
406
+ assert!((a - b).abs() < 1e-7, "anomaly mismatch at step {i}: a={a} b={b}");
407
+ }
408
+ }
409
+
410
+ /// Fused kernel: threshold activation should converge to near target sparsity
411
+ /// after a short warmup. Acceptance: mean activation rate per step lands in
412
+ /// [0.3*target, 2.5*target] after 500-step warmup. Because the threshold
413
+ /// starts conservative (=2.0) and the per-column adaptation rate is slow
414
+ /// (0.001), we allow a generous band — the test asserts directional
415
+ /// convergence toward the target, not tight matching.
416
+ #[test]
417
+ fn gpu_threshold_converges_to_sparsity() {
418
+ let cfg = SpatialPoolerConfig::default();
419
+ let bits = cfg.input_bits;
420
+ let n_cols = cfg.n_columns;
421
+ let cells_per_col = 32usize;
422
+ let target = cfg.sparsity; // 0.02 = 40 cols expected
423
+
424
+ let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 111);
425
+ let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
426
+ let dev = sp.dev_ref().clone();
427
+ let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
428
+ let mut fused = FusedState::new(
429
+ dev.clone(),
430
+ n_cols,
431
+ cells_per_col,
432
+ sp.initial_threshold_estimate(),
433
+ ).expect("fused init");
434
+ tm.reset().expect("tm reset");
435
+ fused.reset().expect("fused reset");
436
+
437
+ // Warmup: 1000 random 2%-sparse SDRs.
438
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
439
+ let t_warm = 1000usize;
440
+ let mut inputs = vec![0u8; t_warm * bits];
441
+ for ti in 0..t_warm {
442
+ let sdr = make_sdr(&mut rng, bits, 0.02);
443
+ inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
444
+ }
445
+ let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod");
446
+ let mut cols_dev = dev.alloc_zeros::<u8>(t_warm * n_cols).expect("alloc cols");
447
+ let mut anom_dev = dev.alloc_zeros::<f32>(t_warm).expect("alloc anom");
448
+ launch_fused(
449
+ &mut sp, &mut tm, &mut fused,
450
+ &inputs_dev, &mut cols_dev, &mut anom_dev,
451
+ t_warm, bits, true,
452
+ ).expect("warmup launch");
453
+ dev.synchronize().expect("sync");
454
+
455
+ // Measurement pass: another 200 steps, measure mean activation.
456
+ let t_meas = 200usize;
457
+ let mut meas_inputs = vec![0u8; t_meas * bits];
458
+ for ti in 0..t_meas {
459
+ let sdr = make_sdr(&mut rng, bits, 0.02);
460
+ meas_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
461
+ }
462
+ let meas_dev: CudaSlice<u8> = dev.htod_sync_copy(&meas_inputs).expect("htod meas");
463
+ let mut meas_cols = dev.alloc_zeros::<u8>(t_meas * n_cols).expect("alloc meas cols");
464
+ let mut meas_anom = dev.alloc_zeros::<f32>(t_meas).expect("alloc meas anom");
465
+ launch_fused(
466
+ &mut sp, &mut tm, &mut fused,
467
+ &meas_dev, &mut meas_cols, &mut meas_anom,
468
+ t_meas, bits, true,
469
+ ).expect("meas launch");
470
+ dev.synchronize().expect("sync meas");
471
+
472
+ let cols_host: Vec<u8> = dev.dtoh_sync_copy(&meas_cols).expect("d2h");
473
+ let mut step_counts = Vec::with_capacity(t_meas);
474
+ for ti in 0..t_meas {
475
+ let n_on = cols_host[ti*n_cols..(ti+1)*n_cols]
476
+ .iter().filter(|&&b| b != 0).count();
477
+ step_counts.push(n_on);
478
+ }
479
+ let mean_active: f64 = step_counts.iter().map(|&c| c as f64).sum::<f64>()
480
+ / (t_meas as f64);
481
+ let target_active = target as f64 * n_cols as f64;
482
+ eprintln!(
483
+ "threshold-activation convergence: mean_active/step = {mean_active:.1} \
484
+ (target = {target_active:.1})"
485
+ );
486
+ // Very generous band — we just want to confirm the threshold loop is
487
+ // functioning (not diverged to 0 or to all-active).
488
+ assert!(
489
+ mean_active >= 0.25 * target_active && mean_active <= 4.0 * target_active,
490
+ "mean active {mean_active:.1} outside [0.25x, 4x] of target {target_active:.1}"
491
+ );
492
+ }
493
+
494
+ /// Fused kernel: TM should learn a repeating sequence — anomaly decays.
495
+ #[test]
496
+ fn gpu_fused_tm_anomaly_decays_on_repeating_sequence() {
497
+ let cfg = SpatialPoolerConfig::default();
498
+ let bits = cfg.input_bits;
499
+ let n_cols = cfg.n_columns;
500
+ let cells_per_col = 32usize;
501
+
502
+ let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271);
503
+ let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
504
+ let dev = sp.dev_ref().clone();
505
+ let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
506
+ let mut fused = FusedState::new(
507
+ dev.clone(),
508
+ n_cols,
509
+ cells_per_col,
510
+ sp.initial_threshold_estimate(),
511
+ ).expect("fused init");
512
+ tm.reset().expect("tm reset");
513
+ fused.reset().expect("fused reset");
514
+
515
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
516
+ let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
517
+ let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
518
+
519
+ // Warmup SP threshold calibration with random SDRs first.
520
+ let warm = 300usize;
521
+ let mut warm_inputs = vec![0u8; warm * bits];
522
+ for ti in 0..warm {
523
+ let sdr = make_sdr(&mut rng, bits, 0.02);
524
+ warm_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
525
+ }
526
+ let warm_dev: CudaSlice<u8> = dev.htod_sync_copy(&warm_inputs).expect("htod warm");
527
+ let mut warm_cols = dev.alloc_zeros::<u8>(warm * n_cols).expect("alloc warm cols");
528
+ let mut warm_anom = dev.alloc_zeros::<f32>(warm).expect("alloc warm anom");
529
+ launch_fused(
530
+ &mut sp, &mut tm, &mut fused,
531
+ &warm_dev, &mut warm_cols, &mut warm_anom,
532
+ warm, bits, true,
533
+ ).expect("warm launch");
534
+ dev.synchronize().expect("sync warm");
535
+
536
+ // Feed repeating A,B,C sequence for 100 reps.
537
+ let repeats = 100usize;
538
+ let t = repeats * 3;
539
+ let mut inputs = vec![0u8; t * bits];
540
+ for r in 0..repeats {
541
+ for (i, s) in seqs.iter().enumerate() {
542
+ let off = (r*3 + i) * bits;
543
+ inputs[off..off+bits].copy_from_slice(s);
544
+ }
545
+ }
546
+ let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod rep");
547
+ let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc rep cols");
548
+ let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc rep anom");
549
+ launch_fused(
550
+ &mut sp, &mut tm, &mut fused,
551
+ &inputs_dev, &mut cols_dev, &mut anom_dev,
552
+ t, bits, true,
553
+ ).expect("rep launch");
554
+ dev.synchronize().expect("sync rep");
555
+
556
+ let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
557
+ let early_avg: f32 = anom[3..12].iter().sum::<f32>() / 9.0;
558
+ let late_avg: f32 = anom[(t-9)..t].iter().sum::<f32>() / 9.0;
559
+ eprintln!("fused TM anomaly: early={early_avg:.3} late={late_avg:.3}");
560
+ assert!(
561
+ late_avg < early_avg,
562
+ "anomaly must decay: early={early_avg:.3} late={late_avg:.3}"
563
+ );
564
+ assert!(
565
+ late_avg < 0.5,
566
+ "late anomaly must be < 0.5 (got {late_avg:.3})"
567
+ );
568
+ }
569
+
570
+ #[test]
571
+ fn gpu_sp_yields_k_winners() {
572
+ let cfg = SpatialPoolerConfig::default();
573
+ let bits = cfg.input_bits;
574
+ let n = cfg.n_columns;
575
+ let expected_k = ((cfg.sparsity * n as f32).round() as usize).max(1);
576
+ let cpu = SpatialPooler::new(SpatialPoolerConfig::default(), 7);
577
+ let mut gpu = SpatialPoolerGpu::from_cpu(&cpu).expect("gpu init");
578
+
579
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(1);
580
+ for _ in 0..10 {
581
+ let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
582
+ let active = gpu.compute(&sdr_u8, false).expect("gpu compute");
583
+ assert_eq!(active.len(), expected_k);
584
+ // Ensure sorted + unique.
585
+ for w in active.windows(2) {
586
+ assert!(w[0] < w[1], "duplicate or out-of-order winner indices");
587
+ }
588
+ }
589
+ }
590
+
591
+ #[test]
592
+ fn fused_launch_plan_uses_cooperative_grid_sync() {
593
+ let plan = plan_fused_launch(30, true, 30, None).expect("cooperative supported");
594
+ assert_eq!(plan.grid_dim_x, 16);
595
+ assert_eq!(plan.cooperative_grid_limit, 30);
596
+ }
597
+
598
+ #[test]
599
+ fn fused_launch_plan_scales_to_big_gpu() {
600
+ // H200-like: 132 SMs, high cooperative_grid_limit. Cap still applies.
601
+ let plan = plan_fused_launch(132, true, 1000, None).expect("cooperative supported");
602
+ assert_eq!(plan.grid_dim_x, 16); // capped by default override
603
+ let plan = plan_fused_launch(132, true, 1000, Some(64)).expect("cooperative supported");
604
+ assert_eq!(plan.grid_dim_x, 64); // override raises the cap
605
+ }
606
+
607
+ #[test]
608
+ fn fused_launch_plan_refuses_non_cooperative_devices() {
609
+ // The slow path was removed. Devices without cooperative launch fail fast.
610
+ let err = plan_fused_launch(30, false, 0, None).unwrap_err();
611
+ assert!(err.contains("cooperative launch"));
612
+ }
613
+
614
+ #[test]
615
+ fn fused_grid_cap_env_override_is_honored() {
616
+ let cfg = SpatialPoolerConfig::default();
617
+ let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 5252);
618
+ let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
619
+ let dev = sp.dev_ref().clone();
620
+
621
+ unsafe { std::env::set_var("HTM_FUSED_GRID_CAP", "12"); }
622
+ let fused = FusedState::new(
623
+ dev.clone(),
624
+ cfg.n_columns,
625
+ 32usize,
626
+ sp.initial_threshold_estimate(),
627
+ ).expect("fused init");
628
+ unsafe { std::env::remove_var("HTM_FUSED_GRID_CAP"); }
629
+
630
+ let sm_count = match dev.attribute(
631
+ cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
632
+ ) {
633
+ Ok(v) => v as u32,
634
+ Err(_) => 16u32,
635
+ };
636
+ let expected = sm_count.max(1).min(12);
637
+ assert_eq!(
638
+ fused.grid_dim_x,
639
+ expected,
640
+ "fused grid cap env override ignored: expected min(sm_count, 12) = {expected}, got {}",
641
+ fused.grid_dim_x,
642
+ );
643
+ }
overlay/htm_rust/src/gpu/tm_gpu.rs ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //! GPU Temporal Memory.
2
+ //!
3
+ //! Flat device storage. Pre-allocated segment slab:
4
+ //! n_cells = n_columns * cells_per_column
5
+ //! n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL
6
+ //! n_synapses_max = n_segments_max * MAX_SYN_PER_SEGMENT
7
+ //!
8
+ //! Defaults (CPU parity targets relaxed on GPU to keep memory tractable):
9
+ //! MAX_SEGMENTS_PER_CELL = 16
10
+ //! MAX_SYN_PER_SEGMENT = 32
11
+ //!
12
+ //! At n_cells = 65536:
13
+ //! n_segments_max = 1_048_576 (~1M)
14
+ //! n_synapses_max = 33_554_432 (~33M)
15
+ //! Storage:
16
+ //! syn_presyn : u32 × 33M = 128 MB
17
+ //! syn_perm : i16 × 33M = 64 MB
18
+ //! seg_cell : u32 × 1M = 4 MB
19
+ //! seg_syn_n : u32 × 1M = 4 MB
20
+ //! misc bitsets etc ~ <1 MB
21
+ //! -------------------------------
22
+ //! Total per region ~200 MB
23
+ //!
24
+ //! Permanences are stored as i16 scaled by 32767 (→ [0, 32767] represents
25
+ //! [0.0, 1.0]). inc/dec are provided pre-scaled.
26
+
27
+ use std::sync::Arc;
28
+
29
+ use cudarc::driver::{CudaDevice, CudaSlice, DriverError, DeviceRepr, LaunchAsync, LaunchConfig};
30
+ use cudarc::nvrtc::Ptx;
31
+
32
+ /// Packed config struct passed by value to TM kernels to stay under
33
+ /// cudarc's 12-tuple launch limit. Layout must match the C-side
34
+ /// `TmConfig` struct declared in each kernel.
35
+ #[repr(C)]
36
+ #[derive(Clone, Copy)]
37
+ pub struct TmConfig {
38
+ pub activation_threshold: u32,
39
+ pub learning_threshold: u32,
40
+ pub cells_per_column: u32,
41
+ pub synapses_per_segment: u32,
42
+ pub n_segments: u32,
43
+ pub n_cells: u32,
44
+ pub max_segments_per_cell: u32,
45
+ pub max_new_synapses: u32,
46
+ pub conn_thr_i16: i32, // i16 widened to i32 for alignment
47
+ pub perm_inc_i16: i32,
48
+ pub perm_dec_i16: i32,
49
+ pub predicted_seg_dec_i16: i32,
50
+ pub initial_perm_i16: i32,
51
+ pub iter_seed: u32,
52
+ pub n_cols: u32,
53
+ pub bits_words: u32,
54
+ }
55
+
56
+ unsafe impl DeviceRepr for TmConfig {}
57
+
58
+ // Embedded PTX.
59
+ const PTX_TM_PREDICT: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_predict.ptx"));
60
+ const PTX_TM_ACTIVATE: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_activate.ptx"));
61
+ const PTX_TM_LEARN: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_learn.ptx"));
62
+ const PTX_TM_PUNISH: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_punish.ptx"));
63
+ const PTX_TM_GROW: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_grow.ptx"));
64
+ const PTX_TM_ANOMALY: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_anomaly.ptx"));
65
+ const PTX_TM_RESET: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_reset.ptx"));
66
+
67
+ /// Capacity trade-offs for 6 GB VRAM (RTX 3060) shared with the model:
68
+ /// n_cells = 2048 × 32 = 65_536
69
+ /// n_segments_max = n_cells × MAX_SEGMENTS_PER_CELL
70
+ /// n_synapses_max = n_segments_max × MAX_SYN_PER_SEGMENT
71
+ ///
72
+ /// At 4/20 these are 262_144 segments and ~5.2M synapses (~50 MB per region).
73
+ /// The training loop runs with `reset_each_forward=True`, so segment counts
74
+ /// per window stay well below 32K (typical: ~n_cols new segs per step until
75
+ /// the first matching segment is reused; in a 2048-step window that plateaus
76
+ /// around ~5K total live segments). The 262K ceiling is generous headroom.
77
+ pub const MAX_SEGMENTS_PER_CELL: usize = 4;
78
+ pub const MAX_SYN_PER_SEGMENT: usize = 20;
79
+
80
+ const PERM_SCALE: f32 = 32767.0;
81
+
82
+ fn perm_f32_to_i16(x: f32) -> i16 {
83
+ let clamped = x.clamp(0.0, 1.0);
84
+ (clamped * PERM_SCALE).round() as i16
85
+ }
86
+
87
+ pub struct TemporalMemoryGpu {
88
+ dev: Arc<CudaDevice>,
89
+
90
+ // Config mirror
91
+ pub n_columns: usize,
92
+ pub cells_per_column: usize,
93
+ pub activation_threshold: u32,
94
+ pub learning_threshold: u32,
95
+ pub initial_perm_i16: i16,
96
+ pub conn_thr_i16: i16,
97
+ pub perm_inc_i16: i16,
98
+ pub perm_dec_i16: i16,
99
+ pub predicted_seg_dec_i16: i16,
100
+ pub max_new_synapse_count: u32,
101
+
102
+ // Sizes
103
+ pub n_cells: usize,
104
+ pub n_segments_max: usize,
105
+ pub bits_words: usize, // n_cells / 32
106
+
107
+ // Persistent device buffers
108
+ seg_cell_id: CudaSlice<u32>,
109
+ seg_syn_count: CudaSlice<u32>,
110
+ syn_presyn: CudaSlice<u32>,
111
+ syn_perm: CudaSlice<i16>,
112
+ cell_seg_count: CudaSlice<u32>,
113
+
114
+ cell_active_bits: CudaSlice<u32>,
115
+ cell_winner_bits: CudaSlice<u32>,
116
+ cell_predictive_bits: CudaSlice<u32>,
117
+ prev_active_bits: CudaSlice<u32>,
118
+ prev_winner_bits: CudaSlice<u32>,
119
+
120
+ col_predicted: CudaSlice<u8>,
121
+ seg_num_active_conn: CudaSlice<u32>,
122
+ seg_num_active_pot: CudaSlice<u32>,
123
+ unpredicted_count: CudaSlice<u32>,
124
+ burst_cols_flat: CudaSlice<u32>,
125
+ burst_cols_count: CudaSlice<u32>,
126
+ col_best_match: CudaSlice<u32>,
127
+
128
+ iter_counter: u32,
129
+ }
130
+
131
+ impl TemporalMemoryGpu {
132
+ pub fn new(
133
+ dev: Arc<CudaDevice>,
134
+ n_columns: usize,
135
+ cells_per_column: usize,
136
+ ) -> Result<Self, DriverError> {
137
+ let n_cells = n_columns * cells_per_column;
138
+ assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
139
+ let n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL;
140
+ let bits_words = n_cells / 32;
141
+
142
+ // Numenta defaults.
143
+ let activation_threshold = 15u32;
144
+ let learning_threshold = 13u32;
145
+ let initial_perm_i16 = perm_f32_to_i16(0.21);
146
+ let conn_thr_i16 = perm_f32_to_i16(0.50);
147
+ let perm_inc_i16 = perm_f32_to_i16(0.10);
148
+ let perm_dec_i16 = perm_f32_to_i16(0.10);
149
+ let predicted_seg_dec_i16 = perm_f32_to_i16(0.10);
150
+ let max_new_synapse_count = 20u32;
151
+
152
+ // Allocate buffers.
153
+ let seg_cell_id_host: Vec<u32> = vec![u32::MAX; n_segments_max];
154
+ let seg_cell_id = dev.htod_sync_copy(&seg_cell_id_host)?;
155
+ let seg_syn_count = dev.alloc_zeros::<u32>(n_segments_max)?;
156
+ let syn_presyn = dev.alloc_zeros::<u32>(n_segments_max * MAX_SYN_PER_SEGMENT)?;
157
+ let syn_perm = dev.alloc_zeros::<i16>(n_segments_max * MAX_SYN_PER_SEGMENT)?;
158
+ let cell_seg_count = dev.alloc_zeros::<u32>(n_cells)?;
159
+
160
+ let cell_active_bits = dev.alloc_zeros::<u32>(bits_words)?;
161
+ let cell_winner_bits = dev.alloc_zeros::<u32>(bits_words)?;
162
+ let cell_predictive_bits = dev.alloc_zeros::<u32>(bits_words)?;
163
+ let prev_active_bits = dev.alloc_zeros::<u32>(bits_words)?;
164
+ let prev_winner_bits = dev.alloc_zeros::<u32>(bits_words)?;
165
+
166
+ let col_predicted = dev.alloc_zeros::<u8>(n_columns)?;
167
+ let seg_num_active_conn = dev.alloc_zeros::<u32>(n_segments_max)?;
168
+ let seg_num_active_pot = dev.alloc_zeros::<u32>(n_segments_max)?;
169
+ let unpredicted_count = dev.alloc_zeros::<u32>(1)?;
170
+ // Bursting columns for one step bounded by n_columns.
171
+ let burst_cols_flat = dev.alloc_zeros::<u32>(n_columns)?;
172
+ let burst_cols_count = dev.alloc_zeros::<u32>(1)?;
173
+ let col_best_match = dev.alloc_zeros::<u32>(n_columns)?;
174
+
175
+ // Load PTX modules.
176
+ let modules = [
177
+ ("htm_tm_predict", PTX_TM_PREDICT, "tm_predict"),
178
+ ("htm_tm_activate", PTX_TM_ACTIVATE, "tm_activate"),
179
+ ("htm_tm_learn", PTX_TM_LEARN, "tm_learn_reinforce"),
180
+ ("htm_tm_punish", PTX_TM_PUNISH, "tm_punish"),
181
+ ("htm_tm_grow", PTX_TM_GROW, "tm_grow"),
182
+ ("htm_tm_anomaly", PTX_TM_ANOMALY, "tm_anomaly"),
183
+ ("htm_tm_reset", PTX_TM_RESET, "tm_reset_step"),
184
+ ];
185
+ for (modname, ptx, fnname) in modules {
186
+ if dev.get_func(modname, fnname).is_none() {
187
+ dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
188
+ }
189
+ }
190
+
191
+ Ok(Self {
192
+ dev,
193
+ n_columns,
194
+ cells_per_column,
195
+ activation_threshold,
196
+ learning_threshold,
197
+ initial_perm_i16,
198
+ conn_thr_i16,
199
+ perm_inc_i16,
200
+ perm_dec_i16,
201
+ predicted_seg_dec_i16,
202
+ max_new_synapse_count,
203
+ n_cells,
204
+ n_segments_max,
205
+ bits_words,
206
+ seg_cell_id,
207
+ seg_syn_count,
208
+ syn_presyn,
209
+ syn_perm,
210
+ cell_seg_count,
211
+ cell_active_bits,
212
+ cell_winner_bits,
213
+ cell_predictive_bits,
214
+ prev_active_bits,
215
+ prev_winner_bits,
216
+ col_predicted,
217
+ seg_num_active_conn,
218
+ seg_num_active_pot,
219
+ unpredicted_count,
220
+ burst_cols_flat,
221
+ burst_cols_count,
222
+ col_best_match,
223
+ iter_counter: 0,
224
+ })
225
+ }
226
+
227
+ // --- Fused-path accessors ---
228
+ pub fn seg_cell_id_accessor(&self) -> &CudaSlice<u32> { &self.seg_cell_id }
229
+ pub fn seg_syn_count_accessor(&self) -> &CudaSlice<u32> { &self.seg_syn_count }
230
+ pub fn syn_presyn_accessor(&self) -> &CudaSlice<u32> { &self.syn_presyn }
231
+ pub fn syn_perm_accessor(&self) -> &CudaSlice<i16> { &self.syn_perm }
232
+ pub fn cell_seg_count_accessor(&self) -> &CudaSlice<u32> { &self.cell_seg_count }
233
+
234
+ /// Hard reset — clear everything (predictive + active + segments).
235
+ pub fn reset(&mut self) -> Result<(), DriverError> {
236
+ // Restore "unused" sentinel in seg_cell_id.
237
+ let unused_host: Vec<u32> = vec![u32::MAX; self.n_segments_max];
238
+ self.dev.htod_sync_copy_into(&unused_host, &mut self.seg_cell_id)?;
239
+ self.dev.memset_zeros(&mut self.seg_syn_count)?;
240
+ self.dev.memset_zeros(&mut self.cell_seg_count)?;
241
+ self.dev.memset_zeros(&mut self.cell_active_bits)?;
242
+ self.dev.memset_zeros(&mut self.cell_winner_bits)?;
243
+ self.dev.memset_zeros(&mut self.cell_predictive_bits)?;
244
+ self.dev.memset_zeros(&mut self.prev_active_bits)?;
245
+ self.dev.memset_zeros(&mut self.prev_winner_bits)?;
246
+ self.dev.memset_zeros(&mut self.col_best_match)?;
247
+ self.iter_counter = 0;
248
+ Ok(())
249
+ }
250
+
251
+ fn build_cfg(&self) -> TmConfig {
252
+ TmConfig {
253
+ activation_threshold: self.activation_threshold,
254
+ learning_threshold: self.learning_threshold,
255
+ cells_per_column: self.cells_per_column as u32,
256
+ synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
257
+ n_segments: self.n_segments_max as u32,
258
+ n_cells: self.n_cells as u32,
259
+ max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
260
+ max_new_synapses: self.max_new_synapse_count,
261
+ conn_thr_i16: self.conn_thr_i16 as i32,
262
+ perm_inc_i16: self.perm_inc_i16 as i32,
263
+ perm_dec_i16: self.perm_dec_i16 as i32,
264
+ predicted_seg_dec_i16: self.predicted_seg_dec_i16 as i32,
265
+ initial_perm_i16: self.initial_perm_i16 as i32,
266
+ iter_seed: self.iter_counter,
267
+ n_cols: self.n_columns as u32,
268
+ bits_words: self.bits_words as u32,
269
+ }
270
+ }
271
+
272
+ /// Run one TM step on the GPU. Takes the SP active-column mask (u8, already
273
+ /// on device) and writes `anomaly_out[t_slot]`.
274
+ pub fn step(
275
+ &mut self,
276
+ sp_active_mask: &CudaSlice<u8>,
277
+ anomaly_out: &mut CudaSlice<f32>,
278
+ t_slot: u32,
279
+ learn: bool,
280
+ ) -> Result<(), DriverError> {
281
+ let n_cells = self.n_cells;
282
+ let n_cols = self.n_columns;
283
+
284
+ let predict_fn = self.dev.get_func("htm_tm_predict", "tm_predict").unwrap();
285
+ let activate_fn = self.dev.get_func("htm_tm_activate", "tm_activate").unwrap();
286
+ let learn_fn = self.dev.get_func("htm_tm_learn", "tm_learn_reinforce").unwrap();
287
+ let punish_fn = self.dev.get_func("htm_tm_punish", "tm_punish").unwrap();
288
+ let grow_fn = self.dev.get_func("htm_tm_grow", "tm_grow").unwrap();
289
+ let anom_fn = self.dev.get_func("htm_tm_anomaly", "tm_anomaly").unwrap();
290
+ let reset_fn = self.dev.get_func("htm_tm_reset", "tm_reset_step").unwrap();
291
+
292
+ self.iter_counter = self.iter_counter.wrapping_add(1);
293
+ let cfg_val = self.build_cfg();
294
+
295
+ // 0. Per-step reset.
296
+ let reset_words = self.bits_words.max(n_cols);
297
+ let reset_cfg = LaunchConfig {
298
+ grid_dim: (((reset_words + 255) / 256) as u32, 1, 1),
299
+ block_dim: (256, 1, 1),
300
+ shared_mem_bytes: 0,
301
+ };
302
+ unsafe {
303
+ reset_fn.clone().launch(
304
+ reset_cfg,
305
+ (
306
+ &mut self.cell_active_bits,
307
+ &mut self.cell_winner_bits,
308
+ &mut self.cell_predictive_bits,
309
+ &mut self.prev_active_bits,
310
+ &mut self.prev_winner_bits,
311
+ &mut self.col_predicted,
312
+ &mut self.unpredicted_count,
313
+ &mut self.burst_cols_count,
314
+ &mut self.col_best_match,
315
+ self.bits_words as u32,
316
+ n_cols as u32,
317
+ ),
318
+ )?;
319
+ }
320
+
321
+ // 1. Predict (grid = n_cells; each block iterates its cell's segments).
322
+ let predict_cfg = LaunchConfig {
323
+ grid_dim: (n_cells as u32, 1, 1),
324
+ block_dim: (32, 1, 1),
325
+ shared_mem_bytes: 0,
326
+ };
327
+ unsafe {
328
+ predict_fn.clone().launch(
329
+ predict_cfg,
330
+ (
331
+ &self.seg_cell_id,
332
+ &self.seg_syn_count,
333
+ &self.syn_presyn,
334
+ &self.syn_perm,
335
+ &self.prev_active_bits,
336
+ &mut self.cell_predictive_bits,
337
+ &mut self.col_predicted,
338
+ &mut self.seg_num_active_conn,
339
+ &mut self.seg_num_active_pot,
340
+ &mut self.col_best_match,
341
+ &self.cell_seg_count,
342
+ cfg_val,
343
+ ),
344
+ )?;
345
+ }
346
+
347
+ // 2. Activate.
348
+ let activate_cfg = LaunchConfig {
349
+ grid_dim: (((n_cols + 255) / 256) as u32, 1, 1),
350
+ block_dim: (256, 1, 1),
351
+ shared_mem_bytes: 0,
352
+ };
353
+ unsafe {
354
+ activate_fn.clone().launch(
355
+ activate_cfg,
356
+ (
357
+ sp_active_mask,
358
+ &self.col_predicted,
359
+ &self.cell_predictive_bits,
360
+ &mut self.cell_active_bits,
361
+ &mut self.cell_winner_bits,
362
+ &mut self.unpredicted_count,
363
+ &mut self.burst_cols_flat,
364
+ &mut self.burst_cols_count,
365
+ cfg_val,
366
+ ),
367
+ )?;
368
+ }
369
+
370
+ // 3. Anomaly.
371
+ let anom_cfg = LaunchConfig {
372
+ grid_dim: (1, 1, 1),
373
+ block_dim: (256, 1, 1),
374
+ shared_mem_bytes: 0,
375
+ };
376
+ unsafe {
377
+ anom_fn.clone().launch(
378
+ anom_cfg,
379
+ (
380
+ sp_active_mask,
381
+ &self.unpredicted_count,
382
+ anomaly_out,
383
+ t_slot,
384
+ n_cols as u32,
385
+ ),
386
+ )?;
387
+ }
388
+
389
+ if learn {
390
+ // 4. Reinforce (grid = n_cells).
391
+ let learn_cfg = LaunchConfig {
392
+ grid_dim: (n_cells as u32, 1, 1),
393
+ block_dim: (32, 1, 1),
394
+ shared_mem_bytes: 0,
395
+ };
396
+ unsafe {
397
+ learn_fn.clone().launch(
398
+ learn_cfg,
399
+ (
400
+ &self.seg_cell_id,
401
+ &self.seg_syn_count,
402
+ &self.syn_presyn,
403
+ &mut self.syn_perm,
404
+ &self.seg_num_active_conn,
405
+ &self.prev_active_bits,
406
+ sp_active_mask,
407
+ &self.col_predicted,
408
+ &self.cell_seg_count,
409
+ cfg_val,
410
+ ),
411
+ )?;
412
+ }
413
+
414
+ // 5. Punish.
415
+ unsafe {
416
+ punish_fn.clone().launch(
417
+ learn_cfg,
418
+ (
419
+ &self.seg_cell_id,
420
+ &self.seg_syn_count,
421
+ &self.syn_presyn,
422
+ &mut self.syn_perm,
423
+ &self.seg_num_active_pot,
424
+ &self.prev_active_bits,
425
+ sp_active_mask,
426
+ &self.cell_seg_count,
427
+ cfg_val,
428
+ ),
429
+ )?;
430
+ }
431
+
432
+ // 6. Grow.
433
+ let grow_cfg = LaunchConfig {
434
+ grid_dim: (n_cols as u32, 1, 1),
435
+ block_dim: (32, 1, 1),
436
+ shared_mem_bytes: 0,
437
+ };
438
+ unsafe {
439
+ grow_fn.clone().launch(
440
+ grow_cfg,
441
+ (
442
+ &mut self.seg_cell_id,
443
+ &mut self.seg_syn_count,
444
+ &mut self.syn_presyn,
445
+ &mut self.syn_perm,
446
+ &mut self.cell_seg_count,
447
+ &self.burst_cols_flat,
448
+ &self.burst_cols_count,
449
+ &self.prev_winner_bits,
450
+ &self.prev_active_bits,
451
+ &self.col_best_match,
452
+ cfg_val,
453
+ ),
454
+ )?;
455
+ }
456
+ }
457
+
458
+ Ok(())
459
+ }
460
+ }
overlay/hydra/eval.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Evaluation: factual probes + sampled factual English scoring.
2
+
3
+ Extracted from train.py (W1 modularization). Semantics unchanged.
4
+
5
+ Perf optimizations (eval_perf_fix):
6
+ - Probe mode: single forward per prompt instead of autoregressive gen
7
+ - Batch decode: all GPU work first, all CPU decode after
8
+ - Batched factual probes: single padded forward instead of N sequential
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import re as _re
15
+
16
+ import torch
17
+
18
+ from hydra.config import FACTUAL_SAMPLES, FACTUAL_BATCH, FACTUAL_GEN_TOKENS
19
+
20
+ # Default to probe mode (1 forward per prompt); set HYDRA_FACTUAL_MODE=gen for
21
+ # the original autoregressive generation path.
22
+ FACTUAL_MODE = os.environ.get("HYDRA_FACTUAL_MODE", "probe")
23
+
24
+ FACTUAL_EVAL = [
25
+ # Hard factual recall — requires specific knowledge memorization
26
+ ("The capital of France is", ["Paris", "paris"]),
27
+ ("Water boils at", ["100", "boiling"]),
28
+ ("The largest planet in our solar system is", ["Jupiter", "jupiter"]),
29
+ # Easier completions — common collocations / patterns the model may pick up
30
+ ("Once upon a", ["time"]),
31
+ ("Hello, my name", ["is", "'s"]),
32
+ ("The cat sat on the", ["mat", "floor", "rug", "table", "couch", "chair", "ground"]),
33
+ ("She opened the door and", ["walked", "saw", "found", "stepped", "looked", "went", "ran"]),
34
+ # Original hard ones kept for completeness
35
+ ("The speed of light is approximately", ["299", "300", "186,000", "light speed"]),
36
+ ("Two plus two equals", ["4", "four"]),
37
+ ]
38
+
39
+ _FACTUAL_PROBES = [
40
+ "The capital of France is",
41
+ "Water boils at",
42
+ "The largest planet in our solar system is",
43
+ "The speed of light is approximately",
44
+ "Shakespeare wrote",
45
+ ]
46
+
47
+
48
+ def run_factual_probes(model, tokenizer, device, autocast_ctx) -> None:
49
+ """Top-5 next-token predictions for canonical factual prompts.
50
+
51
+ Batched: pads all prompts into a single forward pass instead of N
52
+ sequential passes.
53
+ """
54
+ print("\n--- Factual Probes ---")
55
+ model.eval()
56
+
57
+ # Process probes one at a time to avoid cooperative launch limit
58
+ # (batched forward with B=len(probes) can exceed SM residency cap).
59
+ for prompt_text in _FACTUAL_PROBES:
60
+ ids = tokenizer.encode(prompt_text)
61
+ x = torch.tensor([ids], device=device)
62
+ with torch.no_grad(), autocast_ctx:
63
+ logits = model(x)
64
+ probs = torch.softmax(logits[0, -1].float(), dim=-1)
65
+ top5 = torch.topk(probs, 5)
66
+ completions = [tokenizer.decode([idx.item()]) for idx in top5.indices]
67
+ probs_list = [f"{p:.4f}" for p in top5.values[:3].tolist()]
68
+ print(f' "{prompt_text}" -> {completions[:3]} (p={probs_list})')
69
+ print("--- End Factual Probes ---\n")
70
+
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # Probe mode: single forward per prompt (Fix D)
74
+ # ---------------------------------------------------------------------------
75
+
76
+ def _run_factual_english_probe(model, tokenizer, max_seq_len: int):
77
+ """Fast probe mode: for each (prompt, answers), encode prompt + each answer
78
+ candidate as a single sequence, do ONE forward pass, and check if the model's
79
+ argmax at the last prompt token matches the first answer token.
80
+
81
+ Falls back to checking top-K predictions to be generous (same as gen mode
82
+ which samples multiple temperatures).
83
+ """
84
+ print("---")
85
+ print("factual_english_samples: (probe mode)")
86
+ model.eval()
87
+ hits = 0
88
+
89
+ with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
90
+ for prompt, answers in FACTUAL_EVAL:
91
+ prompt_ids = tokenizer.encode(prompt)
92
+ prompt_len = len(prompt_ids)
93
+ x = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
94
+ logits = model(x, targets=None)
95
+ # logits shape: [1, seq_len, vocab] or [1, vocab]
96
+ if logits.dim() == 3:
97
+ last_logits = logits[0, -1, :]
98
+ else:
99
+ last_logits = logits[0]
100
+
101
+ probs = torch.softmax(last_logits.float(), dim=-1)
102
+ # Check top-K predictions (generous: K=20 to match multi-sample gen)
103
+ top_k = min(20, probs.shape[-1])
104
+ top_ids = torch.topk(probs, top_k).indices.tolist()
105
+ top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]
106
+
107
+ answers_lower = [a.lower() for a in answers]
108
+ any_hit = any(
109
+ any(a in tok for a in answers_lower)
110
+ for tok in top_tokens
111
+ )
112
+ if any_hit:
113
+ hits += 1
114
+
115
+ best_completion = tokenizer.decode([top_ids[0]])
116
+ print(f" prompt: {prompt!r}")
117
+ print(f" output: {(prompt + best_completion).replace(chr(10), ' ')!r}")
118
+ print(f" hit: {any_hit} (probe top-{top_k})")
119
+
120
+ score = hits / len(FACTUAL_EVAL)
121
+ print("---")
122
+ print(f"factual_english_score: {score:.4f}")
123
+ print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
124
+ return score, hits, len(FACTUAL_EVAL)
125
+
126
+
127
+ # ---------------------------------------------------------------------------
128
+ # Gen mode: original autoregressive path (Fix F: batch decode)
129
+ # ---------------------------------------------------------------------------
130
+
131
+ def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
132
+ """Original autoregressive generation path with batch decode optimization:
133
+ all GPU work runs first, then all CPU decoding happens after."""
134
+ print("---")
135
+ print("factual_english_samples: (gen mode)")
136
+ model.eval()
137
+
138
+ num_samples = FACTUAL_SAMPLES
139
+ batch = FACTUAL_BATCH
140
+ gen_tokens = FACTUAL_GEN_TOKENS
141
+ temps = [0.7, 0.9, 1.1]
142
+ hits = 0
143
+
144
+ with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
145
+ for prompt, answers in FACTUAL_EVAL:
146
+ ids = tokenizer.encode(prompt)
147
+ answers_lower = [a.lower() for a in answers]
148
+ # Collect all generated token sequences on GPU first
149
+ all_rows: list[list[int]] = []
150
+ samples_done = 0
151
+ batch_idx = 0
152
+ while samples_done < num_samples:
153
+ b = min(batch, num_samples - samples_done)
154
+ temp = temps[batch_idx % len(temps)]
155
+ batch_idx += 1
156
+ ctx = torch.tensor([ids] * b, device="cuda", dtype=torch.long)
157
+ for _ in range(gen_tokens):
158
+ logits = model(ctx, targets=None)
159
+ next_logits = logits[:, -1, :] if logits.dim() == 3 else logits
160
+ probs = torch.softmax(next_logits.float() / temp, dim=-1)
161
+ next_id = torch.multinomial(probs, num_samples=1)
162
+ ctx = torch.cat([ctx, next_id], dim=1)
163
+ if ctx.size(1) >= max_seq_len:
164
+ break
165
+ # Transfer to CPU in one shot, no per-row sync
166
+ all_rows.extend(ctx.cpu().tolist())
167
+ samples_done += b
168
+
169
+ # CPU-side batch decode — no GPU sync between decodes
170
+ any_hit = False
171
+ first_gen = None
172
+ hit_gen = None
173
+ for row in all_rows:
174
+ generated = tokenizer.decode(row)
175
+ continuation = generated[len(prompt):].strip()
176
+ _words = set(w.lower() for w in _re.findall(r"\b[\w'-]+\b", continuation))
177
+ hit = any(a in _words for a in answers_lower)
178
+ if first_gen is None:
179
+ first_gen = generated
180
+ if hit:
181
+ any_hit = True
182
+ if hit_gen is None:
183
+ hit_gen = generated
184
+ if any_hit:
185
+ hits += 1
186
+ print(f" prompt: {prompt!r}")
187
+ print(f" output: {(first_gen or '').replace(chr(10), ' ')!r}")
188
+ print(f" hit: {any_hit} (any of {num_samples} samples, temps={temps}, gen={gen_tokens}tok)")
189
+ if hit_gen is not None and hit_gen != first_gen:
190
+ print(f" hit_sample: {hit_gen.replace(chr(10), ' ')!r}")
191
+
192
+ score = hits / len(FACTUAL_EVAL)
193
+ print("---")
194
+ print(f"factual_english_score: {score:.4f}")
195
+ print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
196
+ return score, hits, len(FACTUAL_EVAL)
197
+
198
+
199
+ # ---------------------------------------------------------------------------
200
+ # Public entry point
201
+ # ---------------------------------------------------------------------------
202
+
203
+ def run_factual_english(model, tokenizer, max_seq_len: int):
204
+ """Dispatch to probe (fast, default) or gen (original) mode.
205
+
206
+ Set HYDRA_FACTUAL_MODE=gen to use the autoregressive path.
207
+ """
208
+ if FACTUAL_MODE == "gen":
209
+ return _run_factual_english_gen(model, tokenizer, max_seq_len)
210
+ return _run_factual_english_probe(model, tokenizer, max_seq_len)
overlay/hydra/model.py ADDED
@@ -0,0 +1,659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PostSemClawModel — full-architecture model assembly.
2
+
3
+ Extracted from the monolithic train.py (W1 modularization). Semantics
4
+ unchanged. Imports `GPUEngram` from `hydra.engram` and `MuonAdamW` from
5
+ `hydra.optimizer`.
6
+
7
+ Triton kernel integration status (Phase 2):
8
+ HYDRA_FUSED_BCNORM — DEFERRED. The bcnorm_fused Triton kernel fuses
9
+ LayerNorm + RoPE on B/C projections. However, mamba-ssm's Mamba3 block
10
+ uses RMSNormGated (not LayerNorm) for B/C, and RoPE is applied inside
11
+ the mamba3_siso_combined CUDA kernel via the Angles parameter. Replacing
12
+ would require either (a) monkey-patching RMSNormGated + intercepting the
13
+ fused CUDA scan — invasive, 50+ lines, high breakage risk — or (b) a
14
+ full custom Mamba3Block reimplementation. Both are out of scope for
15
+ Phase 2. The kernel is validated standalone; integration deferred to
16
+ Phase 3 when HYDRA moves to a custom SSM block.
17
+
18
+ HYDRA_FUSED_SSD — DEFERRED. The ssd_exp_trap Triton kernel implements
19
+ exponential-trapezoidal discretization as a sequential scan. mamba-ssm's
20
+ Mamba3 block delegates the entire scan + gating + output projection to
21
+ mamba3_siso_combined (a compiled CUDA kernel with tilelang). Replacing
22
+ it would require decomposing the combined kernel into constituent ops
23
+ and substituting only the scan — not feasible without a custom block.
24
+ Same Phase 3 gate as above.
25
+
26
+ Both env vars are accepted but currently no-ops (gates read, logged, but
27
+ the code path is unchanged). This avoids silent regression if someone
28
+ sets them expecting a speedup.
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import os
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ from mamba_ssm import Mamba3
40
+
41
+ from subsystems.hestia_mini import HestiaQAT
42
+ from subsystems.htm import HTMLayer
43
+ from subsystems.mhc_mini import ManifoldHyperConnection
44
+ from subsystems.sdr_semantic import SemanticFoldingSDR
45
+
46
+ from hydra.engram import GPUEngram
47
+ from hydra.optimizer import MuonAdamW
48
+
49
+
50
+ def norm(x: torch.Tensor) -> torch.Tensor:
51
+ """RMSNorm over the last dim — stateless, autocast-friendly."""
52
+ return F.rms_norm(x, (x.size(-1),))
53
+
54
+
55
+ class PostSemClawModel(nn.Module):
56
+ """Full Post-SEM-Claw model assembly.
57
+
58
+ Architecture:
59
+ Token Embedding -> [Mamba3 + residual] x n_layer
60
+ -> SDR + Engram (at configured layer) -> norm -> LM head
61
+
62
+ Interface (must match prepare.py evaluate_bpb):
63
+ model(x, y, reduction='none').view(-1) -> per-token losses
64
+ model(x, y, reduction='mean') -> scalar loss
65
+ """
66
+
67
+ def __init__(self, config):
68
+ super().__init__()
69
+ self.config = config
70
+
71
+ # Token embedding
72
+ self.wte = nn.Embedding(config.vocab_size, config.d_model)
73
+
74
+ # Mamba-3 blocks — official mamba-ssm fused CUDA kernel. No fallbacks.
75
+ # RoPE is applied internally by the Mamba3 CUDA kernel via the Angles
76
+ # parameter; external cos/sin buffers are not needed.
77
+ self.blocks = nn.ModuleList([
78
+ Mamba3(
79
+ d_model=config.d_model,
80
+ d_state=config.d_state,
81
+ expand=config.expand,
82
+ headdim=config.headdim,
83
+ is_mimo=False, # SISO path uses stable mamba3_siso_combined kernel
84
+ chunk_size=64, # upstream-recommended SISO chunk; 16 violated tl.dot M>=16 constraint
85
+ is_outproj_norm=False,
86
+ dtype=torch.bfloat16,
87
+ )
88
+ for _ in range(config.n_layer)
89
+ ])
90
+
91
+ # Full-architecture SDR: offline semantic retina + STE (no-bypass).
92
+ self.sdr_semantic = SemanticFoldingSDR(
93
+ vocab_size=config.vocab_size,
94
+ n_bits=config.sdr_n_bits,
95
+ target_active=config.sdr_target_active,
96
+ delta_rank=config.sdr_delta_rank,
97
+ som_warmup_steps=config.sdr_som_warmup,
98
+ som_update_interval=config.sdr_som_interval,
99
+ )
100
+
101
+ # HTM spatial pooler + temporal memory (Rust, Hebbian).
102
+ self.htm = HTMLayer(
103
+ input_bits=config.sdr_n_bits,
104
+ n_columns=config.htm_n_columns,
105
+ cells_per_column=config.htm_cells_per_column,
106
+ batch_size=1, # grows lazily to actual B on first forward
107
+ seed=42,
108
+ learn=True,
109
+ reset_each_forward=True,
110
+ )
111
+
112
+ # Gradient bridge: (n_columns + anomaly) -> d_model.
113
+ self.htm_proj = nn.Linear(config.htm_n_columns + 1, config.d_model, bias=False)
114
+
115
+ # GPU Engram with Hebbian writes — runs EVERY step.
116
+ self.engram = GPUEngram(
117
+ d_model=config.d_model,
118
+ n_columns=config.engram_n_columns,
119
+ max_ngram=3,
120
+ )
121
+ self.engram_layer_idx = config.engram_layer_idx
122
+
123
+ # Manifold-Constrained Hyper-Connections (one per Mamba-3 block).
124
+ self.mhc = nn.ModuleList([
125
+ ManifoldHyperConnection(d_model=config.d_model, n_streams=2, sinkhorn_iters=3)
126
+ for _ in range(config.n_layer)
127
+ ])
128
+
129
+ # Hestia QAT — ternary weight quantization applied post-optimizer-step.
130
+ self.hestia = HestiaQAT(enabled=True, bits=1.58)
131
+
132
+ # LM head
133
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
134
+
135
+ # Residual dropout
136
+ self.drop = nn.Dropout(float(os.environ.get("HYDRA_DROPOUT", "0.2")))
137
+
138
+ # Logits soft-capping
139
+ self.softcap = 15.0
140
+
141
+ # Secondary metrics storage
142
+ self._metrics = {}
143
+
144
+ # Per-layer diagnostic panel. Env-gated; zero overhead when off.
145
+ # Emits residual-contribution (delta_ratio), feature std, effective rank,
146
+ # gradient norm per layer; used to identify minimum viable n_layer + find
147
+ # entropy leakage / dead layers. See docs/depth-sweep.md.
148
+ self._diag_enabled = os.environ.get("HYDRA_LAYER_DIAGNOSTICS", "0") == "1"
149
+ self._diag_step = 0
150
+ self._diag_svd_every = int(os.environ.get("HYDRA_LAYER_DIAG_SVD_EVERY", "100"))
151
+ if self._diag_enabled:
152
+ # Gradient-norm backward hooks on each Mamba3 block output.
153
+ for _i, _block in enumerate(self.blocks):
154
+ def _mk_grad_hook(_layer_idx):
155
+ def _hook(module, grad_input, grad_output):
156
+ if grad_output and grad_output[0] is not None:
157
+ g = grad_output[0].detach()
158
+ self._metrics[f'layer_{_layer_idx}_grad_norm'] = float(
159
+ g.pow(2).mean().sqrt().item()
160
+ )
161
+ return _hook
162
+ _block.register_full_backward_hook(_mk_grad_hook(_i))
163
+
164
+ # Forward hooks on each Mamba3 block capture the block's OUTPUT
165
+ # directly. This is the clean measurement: unlike merge_streams()
166
+ # sampling which sees (streams + M*block_output) in bf16 — where
167
+ # small block contributions round to zero against unit-norm
168
+ # residuals — this captures `block_output` itself as produced.
169
+ # Reports both its absolute RMS norm and its ratio to the block
170
+ # INPUT's RMS norm (contribution magnitude relative to the
171
+ # residual it's added to).
172
+ for _i, _block in enumerate(self.blocks):
173
+ def _mk_fwd_hook(_layer_idx):
174
+ def _hook(module, inputs, output):
175
+ with torch.no_grad():
176
+ inp = inputs[0].detach().float() if inputs else None
177
+ out = output.detach().float() if isinstance(output, torch.Tensor) else None
178
+ if out is not None:
179
+ out_rms = out.pow(2).mean().sqrt().item()
180
+ self._metrics[f'layer_{_layer_idx}_block_out_rms'] = float(out_rms)
181
+ if inp is not None:
182
+ in_rms = inp.pow(2).mean().sqrt().item()
183
+ self._metrics[f'layer_{_layer_idx}_block_in_rms'] = float(in_rms)
184
+ self._metrics[f'layer_{_layer_idx}_contrib_ratio'] = float(
185
+ out_rms / (in_rms + 1e-8)
186
+ )
187
+ return _hook
188
+ _block.register_forward_hook(_mk_fwd_hook(_i))
189
+
190
+ # Triton kernel integration gates (Phase 2 — deferred, see module docstring).
191
+ self._fused_bcnorm = os.environ.get("HYDRA_FUSED_BCNORM", "0") == "1"
192
+ self._fused_ssd = os.environ.get("HYDRA_FUSED_SSD", "0") == "1"
193
+ if self._fused_bcnorm or self._fused_ssd:
194
+ import sys
195
+ _active = []
196
+ if self._fused_bcnorm:
197
+ _active.append("HYDRA_FUSED_BCNORM")
198
+ if self._fused_ssd:
199
+ _active.append("HYDRA_FUSED_SSD")
200
+ print(
201
+ f"[HYDRA] Triton kernel gates set: {', '.join(_active)}. "
202
+ f"NOTE: Both are DEFERRED (mamba-ssm Mamba3 uses internal "
203
+ f"CUDA kernels). Gates accepted but currently no-ops.",
204
+ file=sys.stderr,
205
+ )
206
+
207
+ # R6 optional torch.compile on the impl forward. Gated (default OFF).
208
+ if os.environ.get("HYDRA_MODEL_COMPILE", "0") == "1":
209
+ self._forward_impl = torch.compile(
210
+ self._forward_impl,
211
+ fullgraph=False,
212
+ dynamic=True,
213
+ mode="default",
214
+ )
215
+
216
+ @torch.no_grad()
217
+ def init_weights(self) -> None:
218
+ s = 3 ** 0.5 * self.config.d_model ** -0.5
219
+
220
+ # Move SDR retina indices (plain attribute, not buffer) to same device as params.
221
+ # Required because to_empty() only moves params/buffers, and _retina_indices
222
+ # is loaded from numpy (always CPU) by SemanticFoldingSDR.__init__.
223
+ device = self.wte.weight.device
224
+ if hasattr(self.sdr_semantic, '_retina_indices'):
225
+ self.sdr_semantic._retina_indices = self.sdr_semantic._retina_indices.to(device)
226
+
227
+ # Embedding init: GPT-2 / LLaMA convention. std=1.0 was chosen for
228
+ # vocab=8192; at larger vocabs, smaller std prevents logit blowup.
229
+ # Use std = 1/sqrt(d_model) which scales sensibly with model width.
230
+ import math as _math
231
+ _d_model = self.wte.weight.shape[1]
232
+ wte_std = float(os.environ.get("HYDRA_WTE_STD", str(1.0 / _math.sqrt(_d_model))))
233
+ nn.init.normal_(self.wte.weight, mean=0.0, std=wte_std)
234
+ # LM head init: was std=0.001 — PATHOLOGICAL at vocab>=32k because
235
+ # logits collapse to zero, loss locks at log(V)~=11, gradient through
236
+ # head ∝ 1/V is too small to escape. GPT-2 uses std=0.02; LLaMA uses
237
+ # std=1/sqrt(d_model). Pick 0.02 as robust default, env-overridable.
238
+ lm_head_std = float(os.environ.get("HYDRA_LM_HEAD_STD", "0.02"))
239
+ nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std)
240
+ # F8 (NOT APPLIED): Weight tying would save V*D params but current LR
241
+ # groups have embedding_lr=1.0 and unembedding_lr=0.005 × d_model_scale
242
+ # — tying forces the shared tensor under a single LR group and either
243
+ # the embeddings learn 200x too slow (under unembed LR) or the LM head
244
+ # becomes unstable (under embed LR). Short 15-step smoke with tying +
245
+ # embed-group update showed initial loss jump 9 -> 20. Deferred until
246
+ # LR groups are re-tuned; see docs/OPTIMIZATION_PLAN.md Post-plan.
247
+
248
+ for li, block in enumerate(self.blocks):
249
+ if hasattr(block, 'in_proj') and hasattr(block.in_proj, 'weight'):
250
+ nn.init.uniform_(block.in_proj.weight, -s, s)
251
+ if hasattr(block, 'out_proj') and hasattr(block.out_proj, 'weight'):
252
+ # GPT-2 residual init: std = 0.02 / sqrt(2 * n_layer).
253
+ # NOT zeros — zero init makes the block a permanent pass-through
254
+ # (block_out_rms=0, zero gradient flow to SSM internals).
255
+ # With non-zero init the block contributes to the residual stream
256
+ # from step 1, giving the SSM scan actual gradient signal.
257
+ n_layer = self.config.n_layer
258
+ out_std = float(os.environ.get(
259
+ "HYDRA_OUT_PROJ_STD",
260
+ str(0.02 / (2 * n_layer) ** 0.5),
261
+ ))
262
+ nn.init.normal_(block.out_proj.weight, mean=0.0, std=out_std)
263
+
264
+ nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
265
+
266
+ # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
267
+ # dtypes in the same shape group would break lerp_ dtype checks.
268
+ self.wte.to(dtype=torch.bfloat16)
269
+ self.htm_proj.to(dtype=torch.bfloat16)
270
+ self.engram.to(dtype=torch.bfloat16)
271
+
272
+ def estimate_flops(self) -> int:
273
+ nparams = sum(p.numel() for p in self.parameters())
274
+ embed_params = self.wte.weight.numel()
275
+ return 6 * (nparams - embed_params)
276
+
277
+ def num_scaling_params(self) -> dict:
278
+ wte = sum(p.numel() for p in self.wte.parameters())
279
+ lm_head = sum(p.numel() for p in self.lm_head.parameters())
280
+ blocks = sum(p.numel() for p in self.blocks.parameters())
281
+ sdr = sum(p.numel() for p in self.sdr_semantic.parameters())
282
+ htm_proj = sum(p.numel() for p in self.htm_proj.parameters())
283
+ engram = sum(p.numel() for p in self.engram.parameters())
284
+ total = sum(p.numel() for p in self.parameters())
285
+ return {
286
+ 'wte': wte, 'lm_head': lm_head, 'blocks': blocks,
287
+ 'sdr_semantic': sdr, 'htm_proj': htm_proj,
288
+ 'engram': engram, 'total': total,
289
+ }
290
+
291
+ def get_secondary_metrics(self) -> dict:
292
+ """Flush any lingering CUDA tensors to host (single sync)."""
293
+ flushed = {}
294
+ for k, v in self._metrics.items():
295
+ if hasattr(v, 'item'):
296
+ try:
297
+ flushed[k] = float(v.item())
298
+ except Exception:
299
+ flushed[k] = v
300
+ else:
301
+ flushed[k] = v
302
+ return flushed
303
+
304
+ def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.6, matrix_lr=0.04,
305
+ weight_decay=0.2, adam_betas=(0.8, 0.95), scalar_lr=0.5):
306
+ """Setup MuonAdamW optimizer with per-component LR groups."""
307
+ model_dim = self.config.d_model
308
+
309
+ embedding_params = list(self.wte.parameters())
310
+ lm_head_params = list(self.lm_head.parameters())
311
+
312
+ # Matrix params -> Muon (exactly 2D weight matrices).
313
+ matrix_params = []
314
+ for p in self.blocks.parameters():
315
+ if p.dim() == 2:
316
+ matrix_params.append(p)
317
+ # NOTE (W1 audit REG-2): SemanticFoldingSDR.delta_u / delta_v are
318
+ # currently GRADIENT-DEAD. The forward path uses `binary_only(idx)` for
319
+ # HTM and stores it as `self._last_sdr`, but does NOT route the STE
320
+ # output through any downstream op. Including them in the Muon group
321
+ # burns compute (stack + orthogonalize + lerp) on zero-grad params
322
+ # every step. Excluded here; a later W5 pass can reconnect STE via a
323
+ # gated residual if the SDR signal is wanted back in-graph. The
324
+ # parameters still exist, so no state_dict break.
325
+ # for p in self.sdr_semantic.parameters():
326
+ # if p.dim() == 2:
327
+ # matrix_params.append(p)
328
+ for p in self.htm_proj.parameters():
329
+ if p.dim() == 2:
330
+ matrix_params.append(p)
331
+ for p in self.engram.parameters():
332
+ if p.dim() == 2:
333
+ matrix_params.append(p)
334
+
335
+ # SDR params are intentionally not in any optimizer group — they
336
+ # receive no gradient in the current forward, so any update would be
337
+ # pure noise (weight_decay × lr on a zero-grad param).
338
+ sdr_param_ids = set(id(p) for p in self.sdr_semantic.parameters())
339
+ assigned = set(id(p) for p in embedding_params + lm_head_params + matrix_params)
340
+ scalar_params = [
341
+ p for p in self.parameters()
342
+ if id(p) not in assigned and id(p) not in sdr_param_ids
343
+ ]
344
+
345
+ total_assigned = len(embedding_params) + len(lm_head_params) + len(matrix_params) + len(scalar_params)
346
+ total_params = len(list(self.parameters()))
347
+ sdr_excluded = len(list(self.sdr_semantic.parameters()))
348
+ assert total_assigned + sdr_excluded == total_params, (
349
+ f"Parameter count mismatch: assigned {total_assigned} + sdr_excluded "
350
+ f"{sdr_excluded} vs total {total_params}"
351
+ )
352
+
353
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
354
+ print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
355
+
356
+ param_groups = [
357
+ dict(kind='adamw', params=lm_head_params,
358
+ lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas,
359
+ eps=1e-10, weight_decay=0.0),
360
+ dict(kind='adamw', params=embedding_params,
361
+ lr=embedding_lr * dmodel_lr_scale, betas=adam_betas,
362
+ eps=1e-10, weight_decay=0.0),
363
+ ]
364
+
365
+ if scalar_params:
366
+ param_groups.append(
367
+ dict(kind='adamw', params=scalar_params,
368
+ lr=scalar_lr * dmodel_lr_scale, betas=adam_betas,
369
+ eps=1e-10, weight_decay=0.0)
370
+ )
371
+
372
+ for shape in sorted({p.shape for p in matrix_params}):
373
+ group_params = [p for p in matrix_params if p.shape == shape]
374
+ param_groups.append(dict(
375
+ kind='muon', params=group_params, lr=matrix_lr,
376
+ momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
377
+ ))
378
+
379
+ optimizer = MuonAdamW(param_groups)
380
+ for group in optimizer.param_groups:
381
+ group["initial_lr"] = group["lr"]
382
+ return optimizer
383
+
384
+ def forward(self, idx, targets=None, reduction='mean'):
385
+ """idx: (B, T) int64. Returns loss if targets given, else logits.
386
+
387
+ Nested bf16 autocast is a no-op when ambient autocast is already on;
388
+ when it's off (e.g. integration tests) we establish the dtype contract.
389
+ """
390
+ if torch.is_autocast_enabled():
391
+ return self._forward_impl(idx, targets=targets, reduction=reduction)
392
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
393
+ return self._forward_impl(idx, targets=targets, reduction=reduction)
394
+
395
+ def _forward_impl(self, idx, targets=None, reduction='mean'):
396
+ B, T = idx.shape
397
+
398
+ # Diagnostic: per-subsystem CUDA event timing. Env-gated; zero overhead
399
+ # when disabled. Logs one timing line per forward call. Used to isolate
400
+ # which subsystem is the tps bottleneck on paid hardware.
401
+ _profile = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1"
402
+ if _profile:
403
+ def _ev():
404
+ e = torch.cuda.Event(enable_timing=True)
405
+ e.record()
406
+ return e
407
+ _t0 = _ev()
408
+ else:
409
+ _t0 = None
410
+
411
+ # Compute SDR binary ONCE and reuse for both HTM input and the stash.
412
+ sdr_binary = self.sdr_semantic.binary_only(idx)
413
+ self._last_sdr = sdr_binary # uint8 stash (not bf16 → 256MB avoidance)
414
+
415
+ # HTM subsampling: run HTM on 1 of every N micro-batches within a
416
+ # gradient accumulation step, reuse the cached result for the other
417
+ # N-1 micro-batches. Cooperative launch monopolizes all SMs (grid.sync
418
+ # requires full-grid residency), so HTM and mamba can't overlap via
419
+ # streams. Subsampling removes HTM from most micro-batches' critical
420
+ # path instead.
421
+ #
422
+ # Math: N=8, 64 accum steps → 8 HTM calls (10.6ms each) + 56 fast
423
+ # calls (4ms each). Total = 84.8 + 224 = 309ms → 106k tps.
424
+ #
425
+ # HYDRA_HTM_SUBSAMPLE=N (default 8). Set =1 for every-microbatch HTM.
426
+ _htm_sub = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8"))
427
+ if not hasattr(self, '_htm_call_idx'):
428
+ self._htm_call_idx = 0
429
+
430
+ _run_htm = (self._htm_call_idx % _htm_sub == 0)
431
+ self._htm_call_idx += 1
432
+
433
+ if _run_htm:
434
+ htm_handle = self.htm.forward_async(sdr_binary)
435
+ else:
436
+ htm_handle = None
437
+
438
+ if _profile: _t_htm_async = _ev()
439
+
440
+ dense_emb = self.wte(idx) # (B, T, d_model) bf16
441
+
442
+ if _profile: _t_wte = _ev()
443
+
444
+ if _run_htm:
445
+ htm_out = self.htm.forward_await(htm_handle)
446
+ self._htm_cache = htm_out.detach() # cache for non-HTM micro-batches
447
+ elif hasattr(self, '_htm_cache') and self._htm_cache is not None \
448
+ and self._htm_cache.shape[0] == B and self._htm_cache.shape[1] == T:
449
+ htm_out = self._htm_cache
450
+ else:
451
+ # Very first call with subsample > 1: run HTM anyway.
452
+ htm_handle = self.htm.forward_async(sdr_binary)
453
+ htm_out = self.htm.forward_await(htm_handle)
454
+ self._htm_cache = htm_out.detach()
455
+
456
+ if _profile: _t_htm_await = _ev()
457
+ with torch.no_grad():
458
+ sdr_active_bits = float(self.sdr_semantic.target_active)
459
+ htm_anomaly = htm_out[..., -1].mean()
460
+
461
+ # Gradient bridge: HTM columns+anomaly -> d_model.
462
+ htm_proj_out = self.htm_proj(htm_out.to(dense_emb.dtype))
463
+ x = dense_emb + htm_proj_out
464
+ x = norm(x)
465
+
466
+ if _profile: _t_htm_proj = _ev()
467
+
468
+ # mHC-routed Mamba-3 stack with Engram injection at configured layer.
469
+ streams = self.mhc[0].init_streams(x)
470
+ _engram_ev = None
471
+
472
+ # Per-layer diagnostic panel. The pre-layer merged state h_pre lets us
473
+ # measure residual contribution of each layer: delta_N = h_post - h_pre.
474
+ # All reads are detached no-grad to avoid autograd graph pollution.
475
+ _diag = self._diag_enabled
476
+ if _diag:
477
+ # Cast to float32 for the diagnostic arithmetic: the layer's
478
+ # residual contribution is small (~0.5 × rms-normed block output),
479
+ # which underflows in bf16 subtraction (3-digit mantissa) and
480
+ # reports delta_ratio=0 at the boundaries. float32 snapshot is
481
+ # ~3.8 MB extra memory per diag sample (B=1, T=2048, d=96) —
482
+ # negligible vs peak VRAM.
483
+ with torch.no_grad():
484
+ h_pre = self.mhc[0].merge_streams(streams).detach().float()
485
+ _run_svd = (self._diag_step % self._diag_svd_every) == 0
486
+
487
+ for i, (block, mhc_layer) in enumerate(zip(self.blocks, self.mhc)):
488
+ def _block_fn(h, _block=block):
489
+ return self.drop(_block(norm(h)))
490
+
491
+ streams = mhc_layer(streams, _block_fn)
492
+
493
+ if i == self.engram_layer_idx:
494
+ if _profile: _t_pre_engram = _ev()
495
+ x_mid = mhc_layer.merge_streams(streams)
496
+ x_mid, hit_rate = self.engram(x_mid, idx)
497
+ streams = mhc_layer.init_streams(x_mid)
498
+ self._metrics['engram_hit_rate'] = hit_rate
499
+ if _profile: _engram_ev = _ev()
500
+
501
+ if _diag:
502
+ with torch.no_grad():
503
+ h_post = mhc_layer.merge_streams(streams).detach().float()
504
+ in_n = h_pre.pow(2).mean().sqrt()
505
+ out_n = h_post.pow(2).mean().sqrt()
506
+ d_n = (h_post - h_pre).pow(2).mean().sqrt()
507
+ self._metrics[f'layer_{i}_in_norm'] = float(in_n.item())
508
+ self._metrics[f'layer_{i}_out_norm'] = float(out_n.item())
509
+ self._metrics[f'layer_{i}_delta_ratio'] = float((d_n / (in_n + 1e-6)).item())
510
+ self._metrics[f'layer_{i}_feat_std'] = float(h_post.std(dim=-1).mean().item())
511
+ if _run_svd:
512
+ # Effective rank via participation ratio of singular values.
513
+ # eff_rank = (Σσ)^2 / Σσ² — smooth rank proxy, bounded by d_model.
514
+ # Sampled to keep overhead low (SVD is O(min(B*T, D)^2·D)).
515
+ flat = h_post.reshape(-1, h_post.shape[-1])[:512].float()
516
+ try:
517
+ s = torch.linalg.svdvals(flat)
518
+ eff_rank = float(((s.sum() ** 2) / (s.pow(2).sum() + 1e-6)).item())
519
+ self._metrics[f'layer_{i}_eff_rank'] = eff_rank
520
+ except Exception:
521
+ pass
522
+ h_pre = h_post
523
+
524
+ if _diag:
525
+ self._diag_step += 1
526
+
527
+ if _profile: _t_blocks = _ev()
528
+
529
+ self._metrics['sdr_active_bits'] = sdr_active_bits
530
+ self._metrics['htm_anomaly'] = htm_anomaly
531
+
532
+ x = self.mhc[-1].merge_streams(streams)
533
+ x = norm(x)
534
+
535
+ if _profile: _t_merge = _ev()
536
+
537
+ softcap = self.softcap
538
+ _softcap_clamp = os.environ.get("HYDRA_SOFTCAP_CLAMP", "0") == "1"
539
+ if targets is not None:
540
+ smoothing = self.config.label_smoothing
541
+ V = self.config.vocab_size
542
+
543
+ # Sampled softmax: instead of computing logits for ALL V tokens,
544
+ # compute only for the target + K random negatives. Reduces the
545
+ # lm_head matmul from (B*T, d) × (d, V) to (B*T, d) × (d, K+1).
546
+ # At V=65536 and K=4096: 16× less compute, ~4× tps improvement.
547
+ # The log-sum-exp correction adjusts for the sampling bias.
548
+ # Set HYDRA_SAMPLED_SOFTMAX=0 to disable (full softmax).
549
+ K_neg = int(os.environ.get("HYDRA_SAMPLED_SOFTMAX", "4096"))
550
+ use_sampled = K_neg > 0 and K_neg < V and self.training
551
+
552
+ if use_sampled:
553
+ # Flatten hidden states + targets
554
+ h_flat = x.reshape(-1, x.shape[-1]) # (B*T, d)
555
+ t_flat = targets.reshape(-1) # (B*T,)
556
+ n = h_flat.shape[0]
557
+
558
+ # Sample K negatives uniformly from [0, V)
559
+ neg_ids = torch.randint(0, V, (K_neg,), device=x.device)
560
+ # Gather lm_head weights for target + negatives
561
+ all_ids = torch.cat([t_flat, neg_ids]) # (B*T + K,)
562
+ sampled_w = self.lm_head.weight[all_ids] # (B*T + K, d)
563
+
564
+ # Compute sampled logits: for each position, dot with its
565
+ # target weight and all K negative weights.
566
+ # Target logit: dot product of h[i] with w[target[i]]
567
+ target_w = sampled_w[:n] # (B*T, d)
568
+ neg_w = sampled_w[n:] # (K, d)
569
+ target_logit = (h_flat * target_w).sum(-1) # (B*T,)
570
+ neg_logits = h_flat @ neg_w.t() # (B*T, K)
571
+
572
+ if not _softcap_clamp:
573
+ target_logit = softcap * torch.tanh(target_logit / softcap)
574
+ neg_logits = softcap * torch.tanh(neg_logits / softcap)
575
+
576
+ # Sampled softmax loss: -log(exp(target) / (exp(target) + sum(exp(neg))))
577
+ # With log-sum-exp correction for sampling K of V negatives.
578
+ # Correction: add log(V/K) to negative logits to account for
579
+ # the fact that we're only seeing K of V possible negatives.
580
+ log_correction = torch.tensor(V / K_neg, device=x.device).log()
581
+ all_logits = torch.cat([
582
+ target_logit.unsqueeze(-1), # (B*T, 1)
583
+ neg_logits + log_correction, # (B*T, K)
584
+ ], dim=-1).float() # (B*T, K+1)
585
+
586
+ # CE with target always at index 0
587
+ ce_targets = torch.zeros(n, dtype=torch.long, device=x.device)
588
+ if reduction == 'none':
589
+ return F.cross_entropy(all_logits, ce_targets, reduction='none')
590
+ out = F.cross_entropy(all_logits, ce_targets, reduction='mean',
591
+ label_smoothing=smoothing)
592
+ else:
593
+ # Full softmax path (eval or HYDRA_SAMPLED_SOFTMAX=0)
594
+ chunk_size = int(os.environ.get("HYDRA_CE_CHUNK", "1024"))
595
+ if chunk_size <= 0:
596
+ MAX_LOGITS_BYTES = 256 * 1024 * 1024
597
+ tokens_per_chunk = max(V, MAX_LOGITS_BYTES // (V * 4))
598
+ chunk_size = max(1, tokens_per_chunk // max(1, B))
599
+ chunk_size = min(chunk_size, T)
600
+
601
+ if reduction == 'none':
602
+ loss_parts = []
603
+ for start in range(0, T, chunk_size):
604
+ end = min(start + chunk_size, T)
605
+ chunk_logits = self.lm_head(x[:, start:end, :]).float()
606
+ if _softcap_clamp:
607
+ chunk_logits = torch.clamp(chunk_logits, -softcap, softcap)
608
+ else:
609
+ chunk_logits = softcap * torch.tanh(chunk_logits / softcap)
610
+ chunk_targets = targets[:, start:end].reshape(-1)
611
+ chunk_loss = F.cross_entropy(
612
+ chunk_logits.view(-1, chunk_logits.size(-1)),
613
+ chunk_targets, ignore_index=-1, reduction='none',
614
+ )
615
+ loss_parts.append(chunk_loss)
616
+ return torch.cat(loss_parts)
617
+
618
+ total_loss = 0.0
619
+ total_tokens = 0
620
+ for start in range(0, T, chunk_size):
621
+ end = min(start + chunk_size, T)
622
+ chunk_logits = self.lm_head(x[:, start:end, :]).float()
623
+ if _softcap_clamp:
624
+ chunk_logits = torch.clamp(chunk_logits, -softcap, softcap)
625
+ else:
626
+ chunk_logits = softcap * torch.tanh(chunk_logits / softcap)
627
+ chunk_targets = targets[:, start:end].reshape(-1)
628
+ chunk_loss = F.cross_entropy(
629
+ chunk_logits.view(-1, chunk_logits.size(-1)),
630
+ chunk_targets, ignore_index=-1, reduction='sum',
631
+ label_smoothing=smoothing,
632
+ )
633
+ total_loss = total_loss + chunk_loss
634
+ total_tokens += (chunk_targets != -1).sum()
635
+ out = total_loss / total_tokens
636
+ if _profile:
637
+ _t_end = _ev()
638
+ torch.cuda.synchronize()
639
+ def _ms(a, b): return a.elapsed_time(b)
640
+ print(
641
+ f"[PROFILE B={B} T={T}] "
642
+ f"htm_launch={_ms(_t0, _t_htm_async):.2f} "
643
+ f"wte={_ms(_t_htm_async, _t_wte):.2f} "
644
+ f"htm_await={_ms(_t_wte, _t_htm_await):.2f} "
645
+ f"htm_proj={_ms(_t_htm_await, _t_htm_proj):.2f} "
646
+ f"mamba_mhc_engram={_ms(_t_htm_proj, _t_blocks):.2f} "
647
+ f"merge={_ms(_t_blocks, _t_merge):.2f} "
648
+ f"lm_head_loss={_ms(_t_merge, _t_end):.2f} "
649
+ f"total={_ms(_t0, _t_end):.2f} ms",
650
+ flush=True,
651
+ )
652
+ return out
653
+
654
+ logits = self.lm_head(x).float()
655
+ if _softcap_clamp:
656
+ logits = torch.clamp(logits, -softcap, softcap)
657
+ else:
658
+ logits = softcap * torch.tanh(logits / softcap)
659
+ return logits
overlay/hydra/optimizer.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).
2
+
3
+ Extracted verbatim from train.py (W1 modularization). Semantics unchanged.
4
+
5
+ F1-F15 state preserved:
6
+ - F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
7
+ step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
8
+ fresh. Persistent copies of param storage would be mutated between forward
9
+ passes (via lerp_/sub_ on stacked tensors that share storage with params),
10
+ triggering "modified in-place" errors on grad_accum=2 backwards.
11
+ - F11/F15: `@torch.compile` on `adamw_step_fused` / `muon_step_fused` intact.
12
+ - F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
13
+ dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
14
+ stream-capture deadlock. See .omc/muon_compile_bug.md for the full
15
+ investigation.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import os
21
+
22
+ import torch
23
+
24
+ # HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel.
25
+ _HYDRA_FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
26
+ _HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")
27
+
28
+
29
+ polar_express_coeffs = [
30
+ (8.156554524902461, -22.48329292557795, 15.878769915207462),
31
+ (4.042929935166739, -2.808917465908714, 0.5000178451051316),
32
+ (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
33
+ (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
34
+ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
35
+ ]
36
+
37
+
38
+ def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
39
+ # Per-param AdamW fallback. Fast path is torch._fused_adamw_ (1 CUDA launch
40
+ # for the whole group) driven from MuonAdamW._step_adamw below.
41
+ grad = grad.to(p.dtype) # handle mixed bf16/fp32 from autocast
42
+ p.mul_(1 - lr_t * wd_t)
43
+ exp_avg.lerp_(grad, 1 - beta1_t)
44
+ exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
45
+ bias1 = 1 - beta1_t ** step_t
46
+ bias2 = 1 - beta2_t ** step_t
47
+ denom = (exp_avg_sq / bias2).sqrt() + eps_t
48
+ step_size = lr_t / bias1
49
+ p.add_(exp_avg / denom, alpha=-step_size)
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # F15 muon_step_fused compile strategy.
54
+ #
55
+ # HYDRA_MUON_COMPILE env gate:
56
+ # "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
57
+ # Dynamic=True collapses the per-shape specialization cache so that N
58
+ # Muon param-groups with N distinct shapes trigger 1 compile, not N.
59
+ # mode="default" keeps the inductor codegen but disables cudagraphs,
60
+ # which is what caused the step-17→18 silent deadlock observed under
61
+ # the original dynamic=False configuration: cudagraph stream capture
62
+ # can deadlock against HTM's CUDA kernels running on the default
63
+ # stream, and the failure mode at capture-time is a silent hang
64
+ # (100% GPU util, no log output, process state R).
65
+ # "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
66
+ # Keeps an escape hatch in case a future torch/inductor regression
67
+ # reintroduces a deadlock.
68
+ #
69
+ # Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
70
+ # alias-analysis edge case where inductor sees `g is stacked_grads` and
71
+ # subsequent `stacked_grads.square()` operating on the post-lerp storage.
72
+ # ---------------------------------------------------------------------------
73
+ _MUON_COMPILE = os.environ.get("HYDRA_MUON_COMPILE", "1") == "1"
74
+
75
+ def _maybe_compile(fn):
76
+ if _MUON_COMPILE:
77
+ # mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
78
+ # would enable) to avoid stream-capture deadlocks against HTM's CUDA
79
+ # kernels. dynamic=True minimizes recompile count across param-group
80
+ # shapes.
81
+ return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
82
+ return fn
83
+
84
+ @_maybe_compile
85
+ def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
86
+ momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
87
+ # Cast grads to param dtype AND clone defensively to break any alias
88
+ # between the (freshly-stacked) input and the in-place lerp_ below.
89
+ # Without this, inductor's alias analysis can emit code that reads from
90
+ # post-mutation storage when computing `v_mean = g.square().mean(...)`.
91
+ stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()
92
+ # Nesterov momentum
93
+ momentum = momentum_t.to(stacked_grads.dtype)
94
+ momentum_buffer.lerp_(stacked_grads, 1 - momentum)
95
+ g = stacked_grads.lerp_(momentum_buffer, momentum)
96
+ # Polar express orthogonalization
97
+ X = g.bfloat16()
98
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
99
+ if g.size(-2) > g.size(-1):
100
+ for a, b, c in polar_express_coeffs[:ns_steps]:
101
+ A = X.mT @ X
102
+ B = b * A + c * (A @ A)
103
+ X = a * X + X @ B
104
+ else:
105
+ for a, b, c in polar_express_coeffs[:ns_steps]:
106
+ A = X @ X.mT
107
+ B = b * A + c * (A @ A)
108
+ X = a * X + B @ X
109
+ g = X
110
+ # NorMuon variance reduction
111
+ # Keep beta2 in the state-buffer dtype, not g.dtype, so lerp_ on the
112
+ # float32 second_momentum_buffer doesn't hit a dtype mismatch on h200.
113
+ beta2 = beta2_t.to(second_momentum_buffer.dtype)
114
+ v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
115
+ red_dim_size = g.size(red_dim)
116
+ v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
117
+ v_norm = v_norm_sq.sqrt()
118
+ second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
119
+ step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
120
+ scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
121
+ v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
122
+ final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
123
+ g = g * final_scale.to(g.dtype)
124
+ # Cautious weight decay + parameter update
125
+ lr = lr_t.to(g.dtype)
126
+ wd = wd_t.to(g.dtype)
127
+ mask = (g * stacked_params) >= 0
128
+ stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
129
+
130
+
131
+ class MuonAdamW(torch.optim.Optimizer):
132
+ """Combined optimizer: Muon for 2D matrix params, AdamW for others."""
133
+
134
+ def __init__(self, param_groups):
135
+ super().__init__(param_groups, defaults={})
136
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
137
+ self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
138
+ self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
139
+ self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
140
+ self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
141
+ self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
142
+ self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
143
+ self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
144
+ self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
145
+ self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
146
+ self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
147
+
148
+ def _step_adamw(self, group):
149
+ params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
150
+ for p in group['params']:
151
+ if p.grad is None:
152
+ continue
153
+ state = self.state[p]
154
+ if not state:
155
+ state['step'] = 0
156
+ state['exp_avg'] = torch.zeros_like(p)
157
+ state['exp_avg_sq'] = torch.zeros_like(p)
158
+ if 'step_t' not in state:
159
+ # _fused_adamw_ wants a per-param float step tensor on-device.
160
+ state['step_t'] = torch.tensor(
161
+ float(state['step']), dtype=torch.float32, device=p.device
162
+ )
163
+ state['step'] += 1
164
+ params.append(p)
165
+ grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
166
+ exp_avgs.append(state['exp_avg'])
167
+ exp_avg_sqs.append(state['exp_avg_sq'])
168
+ state_steps.append(state['step_t'])
169
+
170
+ if not params:
171
+ return
172
+
173
+ if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
174
+ # _fused_adamw_ needs uniform (device, dtype) within a call, so
175
+ # group by (device, dtype) — same pattern as PyTorch's own
176
+ # AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
177
+ buckets = {}
178
+ for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
179
+ key = (p.device, p.dtype)
180
+ buckets.setdefault(key, ([], [], [], [], []))
181
+ b_p, b_g, b_ea, b_es, b_st = buckets[key]
182
+ b_p.append(p); b_g.append(g); b_ea.append(ea); b_es.append(es); b_st.append(st)
183
+
184
+ lr_f = float(group['lr'])
185
+ b1_f = float(group['betas'][0])
186
+ b2_f = float(group['betas'][1])
187
+ wd_f = float(group['weight_decay'])
188
+ eps_f = float(group['eps'])
189
+ for (_dev, _dt), (b_p, b_g, b_ea, b_es, b_st) in buckets.items():
190
+ torch._foreach_add_(b_st, 1.0)
191
+ torch._fused_adamw_(
192
+ b_p, b_g, b_ea, b_es,
193
+ [], # max_exp_avg_sqs unused (amsgrad=False)
194
+ b_st,
195
+ amsgrad=False,
196
+ lr=lr_f, beta1=b1_f, beta2=b2_f,
197
+ weight_decay=wd_f, eps=eps_f,
198
+ maximize=False,
199
+ grad_scale=None, found_inf=None,
200
+ )
201
+ return
202
+
203
+ # Fallback per-param path.
204
+ self._adamw_lr_t.fill_(group['lr'])
205
+ self._adamw_beta1_t.fill_(group['betas'][0])
206
+ self._adamw_beta2_t.fill_(group['betas'][1])
207
+ self._adamw_eps_t.fill_(group['eps'])
208
+ self._adamw_wd_t.fill_(group['weight_decay'])
209
+ for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
210
+ self._adamw_step_t.fill_(self.state[p]['step'])
211
+ adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
212
+ self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
213
+ self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
214
+
215
+ def _step_muon(self, group):
216
+ params = [p for p in group['params'] if p.grad is not None]
217
+ if not params:
218
+ return
219
+ p = params[0]
220
+ state = self.state[p]
221
+ num_params = len(params)
222
+ shape, device, dtype = p.shape, p.device, p.dtype
223
+ if "momentum_buffer" not in state:
224
+ state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
225
+ red_dim = -1 if shape[-2] >= shape[-1] else -2
226
+ if "second_momentum_buffer" not in state:
227
+ # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
228
+ full_shape = (num_params, *shape)
229
+ state_shape = list(full_shape)
230
+ state_shape[len(state_shape) + red_dim] = 1 # red_dim is negative
231
+ state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
232
+ # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
233
+ # This was the autograd-safety fix that unblocks grad_accum>=2.
234
+ stacked_grads = torch.stack([p.grad for p in params])
235
+ stacked_params = torch.stack(params)
236
+ self._muon_momentum_t.fill_(group["momentum"])
237
+ self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
238
+ self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
239
+ self._muon_wd_t.fill_(group["weight_decay"])
240
+ muon_step_fused(stacked_grads, stacked_params,
241
+ state["momentum_buffer"], state["second_momentum_buffer"],
242
+ self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
243
+ self._muon_beta2_t, group["ns_steps"], red_dim)
244
+ torch._foreach_copy_(params, list(stacked_params.unbind(0)))
245
+
246
+ @torch.no_grad()
247
+ def step(self):
248
+ for group in self.param_groups:
249
+ if group['kind'] == 'adamw':
250
+ self._step_adamw(group)
251
+ elif group['kind'] == 'muon':
252
+ self._step_muon(group)
overlay/subsystems/htm.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HTM torch wrapper around the pyo3 ``htm_rust`` crate.
3
+
4
+ Exposes ``HTMLayer``, a ``torch.nn.Module`` that batches calls to
5
+ ``htm_rust.HTMRegion.step`` across a ``(B, T, input_bits)`` boolean SDR stream
6
+ and returns ``(B, T, n_columns + 1)`` where the last channel is the anomaly
7
+ score. HTM learning is Hebbian (not gradient), so the wrapper runs under
8
+ ``torch.no_grad()``. Downstream layers carry gradients back to the embedding
9
+ via their own learnable projection from the binary column output.
10
+
11
+ Per-sequence state semantics
12
+ ---------------------------
13
+ Training-time forward passes are independent windows of tokens (re-sampled
14
+ every step), so carrying TM state across calls would mix unrelated contexts.
15
+ This layer calls ``reset()`` on every region at the top of ``forward``; the
16
+ TM learns within-window temporal patterns only. Users that want cross-window
17
+ continuity (e.g. eval over a long document) should instead construct the
18
+ layer and drive ``step_stream`` themselves (not implemented here; the
19
+ single-forward contract is sufficient for the autoresearch loop).
20
+
21
+ Device handling
22
+ ---------------
23
+ ``htm_rust`` runs on CPU. If ``sdr`` lives on CUDA we pay a
24
+ ``sdr.cpu().numpy()`` round-trip per forward. The return tensor is cast back
25
+ to ``sdr.device``. For expected use (batch<=32, T<=2048, bits=16384) this
26
+ copy is small compared to the SP/TM compute.
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import time
32
+ from concurrent.futures import ThreadPoolExecutor
33
+
34
+ import numpy as np
35
+ import torch
36
+ import torch.nn as nn
37
+
38
+ import htm_rust
39
+
40
+ # step_many releases the GIL for the whole pass, so multiple threads can
41
+ # truly run regions in parallel — wall-clock scales with B up to CPU cores.
42
+ _HTM_HAS_STEP_MANY = hasattr(htm_rust.HTMRegion, "step_many")
43
+ # GPU backend: built with `maturin develop --features gpu`. One CUDA region
44
+ # per batch slot, persistent device state for SP synapses. Transparent
45
+ # fallback to CPU when not available.
46
+ _HTM_HAS_GPU = hasattr(htm_rust, "HTMRegionGpu")
47
+ # Zero-copy CUDA path: consumes torch CUDA tensors directly via the
48
+ # __cuda_array_interface__ protocol, skipping the sdr.cpu()/numpy round-trip
49
+ # and the D2H of outputs. Huge win when the input SDR already lives on GPU
50
+ # (which is the train.py hot path — retina is a device buffer).
51
+ _HTM_HAS_CAI = _HTM_HAS_GPU and hasattr(htm_rust.HTMRegionGpu, "step_many_cuda")
52
+ # Fused megakernel path: collapses all T timesteps + SP + TM into a single
53
+ # CUDA launch per forward. Replaces global top-K with per-column threshold
54
+ # inhibition (see htm_rust/docs/GPU_HTM.md §Fused Kernel).
55
+ # Opt-in via env var (default on when available).
56
+ import os as _os_fused
57
+ _HTM_HAS_FUSED = _HTM_HAS_GPU and hasattr(htm_rust.HTMRegionGpu, "step_many_fused_cuda")
58
+ _HTM_USE_FUSED = _HTM_HAS_FUSED and bool(int(_os_fused.environ.get("HYDRA_HTM_FUSED", "1")))
59
+
60
+
61
+ class HTMLayer(nn.Module):
62
+ """Batched torch wrapper around ``htm_rust.HTMRegion``.
63
+
64
+ One independent region per batch slot so temporal memory learns
65
+ sequence-local patterns without cross-batch bleed. Regions grow
66
+ lazily if a larger batch shows up.
67
+
68
+ Output is ``(B, T, n_columns + 1)``: first ``n_columns`` channels are
69
+ the binary active-column mask (float32 0/1) and the last channel is
70
+ the per-timestep anomaly score in [0, 1].
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ input_bits: int = 16384,
76
+ n_columns: int = 2048,
77
+ cells_per_column: int = 32,
78
+ batch_size: int = 1,
79
+ seed: int = 42,
80
+ learn: bool = True,
81
+ reset_each_forward: bool = True,
82
+ use_gpu: bool | None = None,
83
+ ) -> None:
84
+ super().__init__()
85
+ self.input_bits = input_bits
86
+ self.n_columns = n_columns
87
+ self.cells_per_column = cells_per_column
88
+ self.learn = learn
89
+ self.reset_each_forward = reset_each_forward
90
+ self._seed_base = seed
91
+ # Learn gating: HTM learn kernels (tm_punish, tm_learn_reinforce, tm_grow)
92
+ # are 56% of total HTM CUDA time. Gating them to run every N forwards
93
+ # instead of every forward cuts HTM cost ~2x. Hebbian learning still
94
+ # converges since the EMA accumulates over many calls. Env:
95
+ # HYDRA_HTM_LEARN_EVERY=N (default 1 = every forward, 0 = disabled).
96
+ import os as _os
97
+ self._learn_every = max(1, int(_os.environ.get("HYDRA_HTM_LEARN_EVERY", "1")))
98
+ self._forward_counter = 0
99
+ # GPU backend gate. Default: auto-detect — use GPU when the pyo3
100
+ # module was built with --features gpu AND CUDA is actually usable.
101
+ if use_gpu is None:
102
+ use_gpu = _HTM_HAS_GPU and torch.cuda.is_available()
103
+ elif use_gpu and not _HTM_HAS_GPU:
104
+ raise RuntimeError(
105
+ "HTMLayer(use_gpu=True) but htm_rust was not built with "
106
+ "--features gpu. Re-run `maturin develop --features gpu`."
107
+ )
108
+ self._use_gpu = bool(use_gpu)
109
+ cls = htm_rust.HTMRegionGpu if self._use_gpu else htm_rust.HTMRegion
110
+ self._region_cls = cls
111
+ self._regions = [
112
+ cls(input_bits, n_columns, cells_per_column, seed + i)
113
+ for i in range(batch_size)
114
+ ]
115
+ self.register_buffer("_dummy", torch.zeros(1), persistent=False)
116
+ import os as _os
117
+ self._htm_pool = ThreadPoolExecutor(max_workers=min(_os.cpu_count() or 4, 16))
118
+
119
+ def _ensure_regions(self, B: int) -> None:
120
+ while len(self._regions) < B:
121
+ idx = len(self._regions)
122
+ self._regions.append(
123
+ self._region_cls(
124
+ self.input_bits,
125
+ self.n_columns,
126
+ self.cells_per_column,
127
+ self._seed_base + idx,
128
+ )
129
+ )
130
+
131
+ def reset(self) -> None:
132
+ """Clear TM predictive state on every region (keeps SP synapses)."""
133
+ for r in self._regions:
134
+ r.reset()
135
+
136
+ @torch.no_grad()
137
+ def forward(self, sdr: torch.Tensor) -> torch.Tensor:
138
+ B, T, D = sdr.shape
139
+ if D != self.input_bits:
140
+ raise ValueError(f"expected input_bits={self.input_bits}, got {D}")
141
+ self._ensure_regions(B)
142
+ if self.reset_each_forward:
143
+ self.reset()
144
+
145
+ # Learn-gate: run learn kernels only every N forwards (skips 56% of
146
+ # HTM CUDA time on skip-forwards; Hebbian EMA still converges).
147
+ self._forward_counter += 1
148
+ learn = bool(
149
+ self.learn
150
+ and self.training
151
+ and (self._forward_counter % self._learn_every == 0)
152
+ )
153
+
154
+ # Zero-copy CUDA hot path. SDR already lives on GPU (retina buffer),
155
+ # so we skip sdr.cpu()/numpy round-trip AND the output D2H. The Rust
156
+ # kernel writes directly into torch-owned CUDA tensors via CAI.
157
+ # Gives 5-10x tok/s on train.py vs the numpy path below.
158
+ if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
159
+ sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()
160
+ cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
161
+ anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)
162
+ # Pick fused (1 launch) or legacy (12*T launches) path.
163
+ if _HTM_USE_FUSED:
164
+ for b in range(B):
165
+ self._regions[b].step_many_fused_cuda(
166
+ sdr_u8[b].__cuda_array_interface__,
167
+ cols_out[b].__cuda_array_interface__,
168
+ anom_out[b].__cuda_array_interface__,
169
+ learn,
170
+ )
171
+ else:
172
+ for b in range(B):
173
+ self._regions[b].step_many_cuda(
174
+ sdr_u8[b].__cuda_array_interface__,
175
+ cols_out[b].__cuda_array_interface__,
176
+ anom_out[b].__cuda_array_interface__,
177
+ learn,
178
+ )
179
+ # Assemble (B, T, n_cols+1) — keep bf16-friendly float32.
180
+ return torch.cat((cols_out.to(torch.float32), anom_out.unsqueeze(-1)), dim=-1)
181
+
182
+ # Fallback: CPU / numpy path. Kept for CPU-input case and for
183
+ # builds without CAI support.
184
+ sdr_np = sdr.detach().cpu().contiguous().numpy().view(np.bool_)
185
+ out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)
186
+
187
+ def _process_one(b: int) -> None:
188
+ region = self._regions[b]
189
+ if self._use_gpu:
190
+ cols, anom = region.step_many_gpu(sdr_np[b], learn)
191
+ out[b, :, : self.n_columns] = cols
192
+ out[b, :, self.n_columns] = anom
193
+ elif _HTM_HAS_STEP_MANY:
194
+ # Single Rust call: T steps with GIL released for the whole pass.
195
+ cols, anom = region.step_many(sdr_np[b], learn) # cols (T, n_cols), anom (T,)
196
+ out[b, :, : self.n_columns] = cols
197
+ out[b, :, self.n_columns] = anom
198
+ else:
199
+ for t in range(T):
200
+ active_cols, _ac, _pc, anomaly = region.step(sdr_np[b, t], learn)
201
+ out[b, t, : self.n_columns] = active_cols
202
+ out[b, t, self.n_columns] = float(anomaly)
203
+
204
+ if B == 1:
205
+ _process_one(0)
206
+ elif self._use_gpu:
207
+ # GPU regions share the CUDA context; serialise to avoid contention
208
+ # for stream 0. Per-region latency is dominated by kernel compute,
209
+ # not threadable on a single stream cheaply — future work: one
210
+ # CUDA stream per region.
211
+ for b in range(B):
212
+ _process_one(b)
213
+ else:
214
+ # Each thread runs in pure Rust under py.allow_threads, so they
215
+ # parallelise to wall-clock min(B, CPU_cores).
216
+ list(self._htm_pool.map(_process_one, range(B)))
217
+
218
+ return torch.from_numpy(out).to(sdr.device)
219
+
220
+ def forward_async(self, sdr: torch.Tensor):
221
+ """Submit HTM work and return a handle awaitable via ``forward_await``.
222
+
223
+ On the CAI zero-copy path (GPU tensor in, GPU region), the Rust
224
+ CUDA kernels are launched on cudarc's internal stream and control
225
+ returns **immediately** — no device synchronization. The caller's
226
+ next GPU ops (embedding lookup, Mamba forward, etc.) are enqueued
227
+ on PyTorch's default stream and can execute while HTM kernels run
228
+ on the cudarc stream. ``forward_await`` performs the cross-stream
229
+ sync (via ``device_sync``) and assembles the output tensor only
230
+ when the result is actually consumed.
231
+
232
+ For cooperative kernels (``step_many_fused_cuda``) the GPU can only
233
+ run one cooperative launch at a time, so kernel-level overlap with
234
+ default-stream work is limited. The win is **CPU-side launch
235
+ overlap**: instead of the CPU blocking ~10 ms waiting for HTM
236
+ before it can even enqueue wte/mamba, it enqueues everything up
237
+ front and the GPU executes back-to-back without CPU stalls.
238
+
239
+ On the legacy CPU/numpy path, work is dispatched to a thread pool
240
+ as before."""
241
+ B, T, D = sdr.shape
242
+ if D != self.input_bits:
243
+ raise ValueError(f"expected input_bits={self.input_bits}, got {D}")
244
+ self._ensure_regions(B)
245
+ if self.reset_each_forward:
246
+ self.reset()
247
+ learn = bool(self.learn and self.training)
248
+
249
+ if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
250
+ sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()
251
+ cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
252
+ anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)
253
+ # ONE cooperative kernel launch for all B regions. Breaks past
254
+ # the CUDA cooperative-kernel device-level serialization (only
255
+ # one cooperative kernel runs at a time). A single launch with
256
+ # grid.y = B processes all regions concurrently — ~B× speedup.
257
+ # Falls back to sequential dispatch if the batched entry isn't
258
+ # available (older htm_rust wheel).
259
+ if _HTM_USE_FUSED and hasattr(htm_rust, "step_batch_fused_cuda"):
260
+ # Slice self._regions to match B: _ensure_regions may have
261
+ # allocated more regions than the current batch size needs
262
+ # (e.g. factual eval uses smaller batches than training).
263
+ try:
264
+ htm_rust.step_batch_fused_cuda(
265
+ self._regions[:B],
266
+ [sdr_u8[b].__cuda_array_interface__ for b in range(B)],
267
+ [cols_out[b].__cuda_array_interface__ for b in range(B)],
268
+ [anom_out[b].__cuda_array_interface__ for b in range(B)],
269
+ learn,
270
+ )
271
+ except RuntimeError as _e:
272
+ if "COOPERATIVE_LAUNCH_TOO_LARGE" in str(_e):
273
+ # Batch too large for cooperative grid. Fall back to
274
+ # sequential per-region fused launches (each B=1).
275
+ for b in range(B):
276
+ self._regions[b].step_many_fused_cuda(
277
+ sdr_u8[b].__cuda_array_interface__,
278
+ cols_out[b].__cuda_array_interface__,
279
+ anom_out[b].__cuda_array_interface__,
280
+ learn,
281
+ )
282
+ else:
283
+ raise
284
+ elif _HTM_USE_FUSED:
285
+ for b in range(B):
286
+ self._regions[b].step_many_fused_cuda(
287
+ sdr_u8[b].__cuda_array_interface__,
288
+ cols_out[b].__cuda_array_interface__,
289
+ anom_out[b].__cuda_array_interface__,
290
+ learn,
291
+ )
292
+ else:
293
+ for b in range(B):
294
+ self._regions[b].step_many_cuda(
295
+ sdr_u8[b].__cuda_array_interface__,
296
+ cols_out[b].__cuda_array_interface__,
297
+ anom_out[b].__cuda_array_interface__,
298
+ learn,
299
+ )
300
+ # NO sync here — kernels are in-flight on cudarc's stream.
301
+ # forward_await() will sync before the output is consumed.
302
+ return {
303
+ 'cuda_deferred': True,
304
+ 'cols_out': cols_out,
305
+ 'anom_out': anom_out,
306
+ 'region0': self._regions[0],
307
+ }
308
+
309
+ sdr_np = sdr.detach().cpu().contiguous().numpy().view(np.bool_)
310
+ out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)
311
+
312
+ def _process_one(b):
313
+ region = self._regions[b]
314
+ if self._use_gpu:
315
+ cols, anom = region.step_many_gpu(sdr_np[b], learn)
316
+ out[b, :, : self.n_columns] = cols
317
+ out[b, :, self.n_columns] = anom
318
+ elif _HTM_HAS_STEP_MANY:
319
+ cols, anom = region.step_many(sdr_np[b], learn)
320
+ out[b, :, : self.n_columns] = cols
321
+ out[b, :, self.n_columns] = anom
322
+ else:
323
+ for t in range(T):
324
+ active_cols, _ac, _pc, anomaly = region.step(sdr_np[b, t], learn)
325
+ out[b, t, : self.n_columns] = active_cols
326
+ out[b, t, self.n_columns] = float(anomaly)
327
+
328
+ fut = self._htm_pool.submit(lambda: [_process_one(b) for b in range(B)])
329
+ return {'fut': fut, 'out': out, 'device': sdr.device}
330
+
331
+ def forward_await(self, handle) -> torch.Tensor:
332
+ if handle.get('cuda_deferred'):
333
+ # Cross-stream sync: block until cudarc stream finishes HTM
334
+ # kernels so the output tensors are safe to read on the
335
+ # default stream.
336
+ region0 = handle['region0']
337
+ if hasattr(region0, "device_sync"):
338
+ region0.device_sync()
339
+ else:
340
+ torch.cuda.synchronize()
341
+ cols_out = handle['cols_out']
342
+ anom_out = handle['anom_out']
343
+ return torch.cat(
344
+ (cols_out.to(torch.float32), anom_out.unsqueeze(-1)), dim=-1
345
+ )
346
+ if 'cuda_result' in handle:
347
+ return handle['cuda_result']
348
+ handle['fut'].result()
349
+ return torch.from_numpy(handle['out']).to(handle['device'])
350
+
351
+
352
+ if __name__ == "__main__":
353
+ torch.manual_seed(0)
354
+
355
+ # Smoke test: (B=2, T=4, D=16384) random 2%-sparse SDR
356
+ B, T, D = 2, 4, 16384
357
+ n_columns = 2048
358
+ target_active_in = int(D * 0.02) # 327
359
+
360
+ layer = HTMLayer(
361
+ input_bits=D,
362
+ n_columns=n_columns,
363
+ cells_per_column=32,
364
+ batch_size=B,
365
+ seed=42,
366
+ learn=True,
367
+ )
368
+ layer.train()
369
+
370
+ rng = np.random.default_rng(0)
371
+ sdr = np.zeros((B, T, D), dtype=bool)
372
+ for b in range(B):
373
+ for t in range(T):
374
+ idx = rng.choice(D, size=target_active_in, replace=False)
375
+ sdr[b, t, idx] = True
376
+ sdr_t = torch.from_numpy(sdr)
377
+
378
+ t0 = time.perf_counter()
379
+ out = layer(sdr_t)
380
+ dt_first = time.perf_counter() - t0
381
+
382
+ assert out.shape == (B, T, n_columns + 1), f"shape {out.shape}"
383
+ assert out.dtype == torch.float32, f"dtype {out.dtype}"
384
+
385
+ active_cols = out[..., :n_columns]
386
+ anomaly = out[..., n_columns]
387
+
388
+ col_sums = active_cols.sum(dim=-1) # (B, T)
389
+ mean_active = col_sums.float().mean().item()
390
+ expected = n_columns * 0.02 # ≈ 40.96
391
+ assert 20 <= mean_active <= 60, (
392
+ f"active columns per step out of 2% band: {mean_active:.1f} (expected ~{expected:.1f})"
393
+ )
394
+
395
+ # t=0 has no TM prediction → anomaly = 1.0 on every batch slot.
396
+ assert torch.allclose(anomaly[:, 0], torch.ones(B)), f"t=0 anomaly {anomaly[:, 0]}"
397
+
398
+ # Second forward on same (reset) layer: identical shapes, deterministic re-run possible.
399
+ t0 = time.perf_counter()
400
+ out2 = layer(sdr_t)
401
+ dt_second = time.perf_counter() - t0
402
+ assert out2.shape == out.shape
403
+
404
+ # Repeating-sequence anomaly decay check — one region, T=8 repeats of same pattern.
405
+ rep_layer = HTMLayer(
406
+ input_bits=D,
407
+ n_columns=n_columns,
408
+ batch_size=1,
409
+ seed=7,
410
+ learn=True,
411
+ )
412
+ rep_layer.train()
413
+ base = torch.zeros(D, dtype=torch.bool)
414
+ idx = rng.choice(D, size=target_active_in, replace=False)
415
+ base[idx] = True
416
+ rep = base.unsqueeze(0).unsqueeze(0).expand(1, 16, D).clone()
417
+ rep_out = rep_layer(rep)
418
+ rep_anom = rep_out[0, :, n_columns]
419
+ assert rep_anom[0].item() > 0.5, f"anomaly at t=0 should be high, got {rep_anom[0]:.3f}"
420
+ assert rep_anom[-1].item() < rep_anom[0].item(), (
421
+ f"anomaly should decay on repeats: first={rep_anom[0]:.3f} last={rep_anom[-1]:.3f}"
422
+ )
423
+
424
+ print("[OK] shape:", tuple(out.shape))
425
+ print(f"[OK] mean active cols/step: {mean_active:.2f} (target ~{expected:.1f})")
426
+ print(f"[OK] t=0 anomaly = 1.0 on all batch slots")
427
+ print(f"[OK] repeating-sequence anomaly: first={rep_anom[0]:.3f} -> last={rep_anom[-1]:.3f}")
428
+ print(f"[OK] forward wall-clock: first={dt_first*1000:.1f}ms second={dt_second*1000:.1f}ms "
429
+ f"on (B={B}, T={T}, D={D})")
overlay/subsystems/sdr_retina.py ADDED
@@ -0,0 +1,632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Offline Semantic Folding SDR Retina (Cortical.io-grade).
3
+
4
+ Builds a topographic, semantic-folding Sparse Distributed Representation (SDR)
5
+ for every token in the vocabulary, following Webber 2015 ("Semantic Folding Theory").
6
+
7
+ Pipeline:
8
+ 1. Scan the tokenized training corpus (parquet shards at ~/.cache/autoresearch/data).
9
+ We on-the-fly tokenize ~10M tokens from the first few shards.
10
+ 2. For each token, build a context vector = top-K most-associated neighbors
11
+ (±8-token window, PMI ranking).
12
+ 3. Train a 128x128 = 16384-bit Kohonen SOM on those context vectors so that
13
+ semantically related context features land on neighboring lattice cells.
14
+ 4. For each token, compute its folded SDR: union of the lattice cells whose
15
+ BMUs are triggered by its top-K context features. Then per-row quantile
16
+ threshold to exactly 2% active bits (327 / 16384).
17
+ 5. Save to ~/.cache/autoresearch/retina.npz.
18
+
19
+ Entry point:
20
+ uv run python subsystems/sdr_retina.py --build --validate
21
+
22
+ The validation asserts classic Cortical.io-style analogies:
23
+ - overlap("the", "a") > overlap("the", "zebra")
24
+ - overlap("man", "woman") > overlap("man", "rock")
25
+ - overlap("king","queen") > overlap("king", "dinosaur")
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import argparse
31
+ import math
32
+ import os
33
+ import sys
34
+ import time
35
+ from dataclasses import dataclass
36
+
37
+ import numpy as np
38
+ import pyarrow.parquet as pq
39
+ import torch
40
+
41
+ # Make the parent repo importable so we can reuse the Tokenizer
42
+ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
43
+ sys.path.insert(0, REPO_ROOT)
44
+
45
+ from prepare import CACHE_DIR, DATA_DIR, TOKENIZER_DIR, VAL_FILENAME, Tokenizer # noqa: E402
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # Build parameters
50
+ # ---------------------------------------------------------------------------
51
+
52
+ RETINA_PATH = os.path.join(CACHE_DIR, "retina.npz")
53
+
54
+ GRID_H = 128
55
+ GRID_W = 128
56
+ N_BITS = GRID_H * GRID_W # 16384
57
+ TARGET_SPARSITY = 0.02 # 2% (default, Cortical.io-style)
58
+ # Default = int(floor(N_BITS * TARGET_SPARSITY)) = 327, matches Webber/Numenta.
59
+ # Override via HYDRA_SDR_TARGET_ACTIVE env var. The cache key encodes
60
+ # target_active, so changing this triggers automatic retina regeneration.
61
+ TARGET_ACTIVE = int(os.environ.get(
62
+ "HYDRA_SDR_TARGET_ACTIVE",
63
+ str(int(N_BITS * TARGET_SPARSITY)),
64
+ ))
65
+
66
+ CONTEXT_WINDOW = 8 # +/- 8 tokens
67
+ TOP_K_FEATURES = 64 # top-K context features per token
68
+ # SCALES WITH VOCAB — need ~100+ occurrences per token for stable cooccurrence.
69
+ # At V=8k: 10M tokens = 1250/tok avg. At V=65k: 10M tokens = 153/tok avg
70
+ # (borderline); rare tokens seen <30x → noisy retina. Recommended: V*150.
71
+ # HF Hub cache makes this a one-time cost per vocab config anyway.
72
+ TARGET_TRAIN_TOKENS = int(os.environ.get("HYDRA_RETINA_TRAIN_TOKENS", "20000000"))
73
+ MAX_DOCS_PER_SHARD = 200_000 # safety cap per shard
74
+
75
+ # Kohonen SOM
76
+ SOM_EPOCHS = 50
77
+ SOM_SIGMA_START = 32.0
78
+ SOM_SIGMA_END = 1.0
79
+ SOM_ALPHA_START = 0.1
80
+ SOM_ALPHA_END = 0.001
81
+
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # Small helpers
85
+ # ---------------------------------------------------------------------------
86
+
87
+ def _fmt(n):
88
+ if n >= 1_000_000:
89
+ return f"{n/1_000_000:.2f}M"
90
+ if n >= 1_000:
91
+ return f"{n/1_000:.1f}k"
92
+ return str(n)
93
+
94
+
95
+ def _device() -> torch.device:
96
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
97
+
98
+
99
+ def _list_train_shards():
100
+ files = sorted(
101
+ f for f in os.listdir(DATA_DIR)
102
+ if f.endswith(".parquet") and not f.endswith(".tmp")
103
+ )
104
+ train = [os.path.join(DATA_DIR, f) for f in files if f != VAL_FILENAME]
105
+ assert len(train) > 0, f"No training shards at {DATA_DIR}. Run prepare.py first."
106
+ return train
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # Stage 1: stream tokens from parquet shards and collect co-occurrences
111
+ # ---------------------------------------------------------------------------
112
+
113
+ def _iter_tokenized_shards(tokenizer: Tokenizer, target_tokens: int):
114
+ """Yield 1-D int32 numpy arrays of token ids until target_tokens reached.
115
+
116
+ Two paths:
117
+ - HYDRA_USE_NEMOTRON=1: stream docs from Nemotron HF datasets (no shards
118
+ on disk — matches the streaming training path).
119
+ - Default: iterate local parquet shards (legacy prepare.py path).
120
+ """
121
+ tok_encode = tokenizer.enc.encode_ordinary_batch
122
+
123
+ if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
124
+ # Streaming path: reuse prepare_nemotron's weighted stream.
125
+ import prepare_nemotron as _pn
126
+ stream = _pn._WeightedStream(_pn._phase_weights(), seed=0)
127
+ seen = 0
128
+ batch: list[str] = []
129
+ BATCH = 512
130
+ while seen < target_tokens:
131
+ text, _epoch = next(stream)
132
+ if not text:
133
+ continue
134
+ batch.append(text)
135
+ if len(batch) < BATCH:
136
+ continue
137
+ token_lists = tok_encode(batch, num_threads=8)
138
+ batch = []
139
+ for ids in token_lists:
140
+ if not ids:
141
+ continue
142
+ arr = np.asarray(ids, dtype=np.int32)
143
+ yield arr
144
+ seen += arr.size
145
+ if seen >= target_tokens:
146
+ print(f" [nemotron-stream] yielded {_fmt(seen)} tokens, target reached")
147
+ return
148
+ return
149
+
150
+ # Legacy shard path.
151
+ shards = _list_train_shards()
152
+ seen = 0
153
+ for shard_idx, path in enumerate(shards):
154
+ if seen >= target_tokens:
155
+ return
156
+ pf = pq.ParquetFile(path)
157
+ shard_tokens = 0
158
+ for rg_idx in range(pf.num_row_groups):
159
+ rg = pf.read_row_group(rg_idx)
160
+ docs = rg.column("text").to_pylist()
161
+ if len(docs) > MAX_DOCS_PER_SHARD:
162
+ docs = docs[:MAX_DOCS_PER_SHARD]
163
+ # Batch-encode for throughput
164
+ batch_size = 512
165
+ for i in range(0, len(docs), batch_size):
166
+ batch = docs[i:i + batch_size]
167
+ token_lists = tok_encode(batch, num_threads=8)
168
+ for ids in token_lists:
169
+ if not ids:
170
+ continue
171
+ arr = np.asarray(ids, dtype=np.int32)
172
+ yield arr
173
+ shard_tokens += arr.size
174
+ seen += arr.size
175
+ if seen >= target_tokens:
176
+ print(f" shard {shard_idx}: yielded {_fmt(shard_tokens)} tokens "
177
+ f"(total {_fmt(seen)}), target reached")
178
+ return
179
+ print(f" shard {shard_idx}: yielded {_fmt(shard_tokens)} tokens (total {_fmt(seen)})")
180
+
181
+
182
+ def _cooccur_from_doc(ids: np.ndarray, window: int, vocab_size: int,
183
+ counts: np.ndarray, cooc: np.ndarray) -> None:
184
+ """Update unigram counts and cooccurrence counts for one document. Vectorized."""
185
+ n = ids.size
186
+ if n < 2:
187
+ return
188
+ # unigram counts
189
+ np.add.at(counts, ids, 1)
190
+ # For each offset d in 1..window, count pairs (ids[:-d], ids[d:])
191
+ # Both directions are equivalent by symmetry; we add both to keep the
192
+ # matrix symmetric and treat it as undirected context.
193
+ for d in range(1, window + 1):
194
+ left = ids[:-d]
195
+ right = ids[d:]
196
+ # symmetric update
197
+ flat_lr = left.astype(np.int64) * vocab_size + right.astype(np.int64)
198
+ flat_rl = right.astype(np.int64) * vocab_size + left.astype(np.int64)
199
+ # use bincount-style scatter via np.add.at on the flat view
200
+ cooc_flat = cooc.ravel()
201
+ np.add.at(cooc_flat, flat_lr, 1)
202
+ np.add.at(cooc_flat, flat_rl, 1)
203
+
204
+
205
+ def build_cooccurrence(tokenizer: Tokenizer, target_tokens: int, window: int) -> tuple[np.ndarray, np.ndarray, int]:
206
+ """
207
+ Stream tokens and build unigram + cooccurrence counts.
208
+ Returns (counts[V] int64, cooc[V,V] int32, total_tokens int).
209
+ """
210
+ vocab_size = tokenizer.get_vocab_size()
211
+ print(f"[1/4] Building cooccurrence (vocab={vocab_size}, window=+/-{window}, target={_fmt(target_tokens)} tokens)")
212
+ counts = np.zeros(vocab_size, dtype=np.int64)
213
+ # int32 is enough per-cell if we stay <= a few hundred million total tokens; guard with clip at save.
214
+ cooc = np.zeros((vocab_size, vocab_size), dtype=np.int32)
215
+
216
+ total = 0
217
+ n_docs = 0
218
+ t0 = time.time()
219
+ for ids in _iter_tokenized_shards(tokenizer, target_tokens):
220
+ _cooccur_from_doc(ids, window, vocab_size, counts, cooc)
221
+ total += ids.size
222
+ n_docs += 1
223
+ if n_docs % 5000 == 0:
224
+ dt = time.time() - t0
225
+ rate = total / max(dt, 1e-6)
226
+ print(f" docs={_fmt(n_docs)} tokens={_fmt(total)} ({rate/1000:.0f}k tok/s)")
227
+
228
+ dt = time.time() - t0
229
+ print(f"[1/4] done: {_fmt(total)} tokens, {_fmt(n_docs)} docs, {dt:.1f}s, "
230
+ f"cooc_nnz={int((cooc > 0).sum())}")
231
+ return counts, cooc, total
232
+
233
+
234
+ # ---------------------------------------------------------------------------
235
+ # Stage 2: build top-K context features (PMI)
236
+ # ---------------------------------------------------------------------------
237
+
238
+ def compute_pmi_topk(counts: np.ndarray, cooc: np.ndarray, total_tokens: int,
239
+ top_k: int) -> tuple[np.ndarray, np.ndarray]:
240
+ """
241
+ For each token, compute top-K context features by positive PMI.
242
+ Returns:
243
+ topk_idx : int32 [V, K] token ids of the top-K context features
244
+ topk_score : float32 [V, K] PMI scores (0 for padded missing features)
245
+ Missing features are padded with idx=token itself and score=0, so they
246
+ have a well-defined (but uninformative) column.
247
+ """
248
+ V = counts.shape[0]
249
+ print(f"[2/4] Computing PMI top-{top_k} per token (vocab={V})")
250
+
251
+ # window_pairs per occurrence: 2 * window (we added both directions, each offset twice).
252
+ # For the PMI denominator we need a total pair count; using coo.sum() is the clean
253
+ # per-matrix normalizer and avoids any constant confusion.
254
+ pair_total = float(cooc.sum())
255
+ if pair_total <= 0:
256
+ raise RuntimeError("Empty cooccurrence matrix")
257
+
258
+ # Run on GPU if available; this is ~8k x 8k float32 = 256MB each.
259
+ dev = _device()
260
+ cooc_t = torch.from_numpy(cooc.astype(np.float32)).to(dev)
261
+ counts_t = torch.from_numpy(counts.astype(np.float64)).to(dev).clamp_min(1.0)
262
+
263
+ # P(i) = counts[i] / total_tokens
264
+ # P(i, j) = cooc[i, j] / pair_total
265
+ # PMI = log(P(i,j) / (P(i) P(j)))
266
+ # Positive PMI = max(PMI, 0).
267
+ # We'll compute log-PMI in a numerically safe way:
268
+ # log(cooc) + log(total_tokens^2 / pair_total) - log(c_i) - log(c_j)
269
+ # Keep numerator zero where cooc==0 and mask those out.
270
+
271
+ log_const = math.log(total_tokens) + math.log(total_tokens) - math.log(pair_total)
272
+ log_ci = torch.log(counts_t) # [V]
273
+ log_cj = log_ci.clone() # same vector (symmetric vocab)
274
+
275
+ # We'll do it in row blocks to cap memory of intermediate log() tensors.
276
+ topk_idx = np.zeros((V, top_k), dtype=np.int32)
277
+ topk_score = np.zeros((V, top_k), dtype=np.float32)
278
+
279
+ block = 512
280
+ t0 = time.time()
281
+ for start in range(0, V, block):
282
+ end = min(V, start + block)
283
+ rows = cooc_t[start:end] # [b, V] int-as-float
284
+ mask = rows > 0
285
+ # log(rows) where rows>0; else keep -inf then mask out
286
+ log_rows = torch.where(mask, torch.log(rows.clamp_min(1.0)),
287
+ torch.full_like(rows, float("-inf")))
288
+ pmi = log_rows + log_const - log_ci[start:end].unsqueeze(1) - log_cj.unsqueeze(0)
289
+ ppmi = torch.where(mask, torch.clamp(pmi, min=0.0),
290
+ torch.full_like(pmi, float("-inf")))
291
+ # top-K along dim=1
292
+ vals, idx = torch.topk(ppmi, k=top_k, dim=1)
293
+ # Replace any -inf valued slots with score 0 and idx = the token itself
294
+ bad = torch.isneginf(vals)
295
+ if bad.any():
296
+ self_idx = torch.arange(start, end, device=dev).unsqueeze(1).expand_as(idx)
297
+ idx = torch.where(bad, self_idx, idx)
298
+ vals = torch.where(bad, torch.zeros_like(vals), vals)
299
+ topk_idx[start:end] = idx.cpu().numpy().astype(np.int32)
300
+ topk_score[start:end] = vals.cpu().numpy().astype(np.float32)
301
+
302
+ del cooc_t, counts_t
303
+ if dev.type == "cuda":
304
+ torch.cuda.empty_cache()
305
+ print(f"[2/4] done: top-{top_k} PMI features per token in {time.time()-t0:.1f}s")
306
+ return topk_idx, topk_score
307
+
308
+
309
+ # ---------------------------------------------------------------------------
310
+ # Stage 3: Kohonen SOM on the context-vector representation
311
+ # ---------------------------------------------------------------------------
312
+
313
+ def _context_vectors_from_topk(topk_idx: np.ndarray, topk_score: np.ndarray,
314
+ vocab_size: int) -> torch.Tensor:
315
+ """
316
+ Build the dense context matrix X [V, V] where X[i] is the top-K PMI context
317
+ vector for token i, L2-normalized. For V=8192 this is 8k x 8k float32 = 256 MB.
318
+ """
319
+ V = vocab_size
320
+ K = topk_idx.shape[1]
321
+ dev = _device()
322
+ X = torch.zeros((V, V), dtype=torch.float32, device=dev)
323
+ rows = torch.arange(V, device=dev).unsqueeze(1).expand(V, K) # [V,K]
324
+ idx = torch.from_numpy(topk_idx).to(dev).long()
325
+ scores = torch.from_numpy(topk_score).to(dev)
326
+ # Scatter scores into X at positions (rows, idx). If duplicates, keep max.
327
+ X[rows, idx] = torch.maximum(X[rows, idx], scores)
328
+ # L2 normalize so Euclidean ~ cosine
329
+ norm = X.norm(dim=1, keepdim=True).clamp_min(1e-8)
330
+ X = X / norm
331
+ return X
332
+
333
+
334
+ def train_som(X: torch.Tensor, grid_h: int, grid_w: int,
335
+ epochs: int, sigma_start: float, sigma_end: float,
336
+ alpha_start: float, alpha_end: float,
337
+ seed: int = 137) -> torch.Tensor:
338
+ """
339
+ Train a Kohonen SOM with rectangular grid and Gaussian neighborhood.
340
+ X: [V, F] features (L2 normalized). Returns weights W: [grid_h*grid_w, F].
341
+ """
342
+ dev = X.device
343
+ V, F = X.shape
344
+ N = grid_h * grid_w
345
+
346
+ torch.manual_seed(seed)
347
+ # Initialize SOM weights: small random linear combinations of data points
348
+ # (faster convergence than uniform random in the feature space).
349
+ init_pick = torch.randint(0, V, (N,), device=dev)
350
+ W = X[init_pick].clone() # [N, F]
351
+
352
+ # Precompute grid coordinates
353
+ yy, xx = torch.meshgrid(
354
+ torch.arange(grid_h, device=dev, dtype=torch.float32),
355
+ torch.arange(grid_w, device=dev, dtype=torch.float32),
356
+ indexing="ij",
357
+ )
358
+ grid = torch.stack([yy.reshape(-1), xx.reshape(-1)], dim=1) # [N, 2]
359
+
360
+ print(f"[3/4] Training Kohonen SOM: grid={grid_h}x{grid_w}, features={F}, "
361
+ f"epochs={epochs}, sigma {sigma_start}->{sigma_end}, alpha {alpha_start}->{alpha_end}")
362
+ t0 = time.time()
363
+
364
+ # Exponential decay schedules
365
+ def schedule(t_frac):
366
+ sigma = sigma_start * (sigma_end / sigma_start) ** t_frac
367
+ alpha = alpha_start * (alpha_end / alpha_start) ** t_frac
368
+ return sigma, alpha
369
+
370
+ # Batch-mode SOM: process a random permutation each epoch in mini-batches.
371
+ # For each mini-batch, compute BMUs then one vectorized neighborhood update.
372
+ batch_size = 256
373
+
374
+ for epoch in range(epochs):
375
+ t_frac = epoch / max(epochs - 1, 1)
376
+ sigma, alpha = schedule(t_frac)
377
+ two_sigma2 = 2.0 * sigma * sigma
378
+ perm = torch.randperm(V, device=dev)
379
+
380
+ for bstart in range(0, V, batch_size):
381
+ bidx = perm[bstart:bstart + batch_size]
382
+ xb = X[bidx] # [b, F]
383
+ # BMU: argmax of cosine similarity = argmin of squared Euclidean
384
+ # ||x||=||w||=1 for data; W may drift but the formulation remains stable.
385
+ sim = xb @ W.t() # [b, N]
386
+ bmu = sim.argmax(dim=1) # [b]
387
+
388
+ # Neighborhood weights h[b, n] = exp(-|grid[bmu_b] - grid[n]|^2 / (2*sigma^2))
389
+ bmu_coords = grid[bmu] # [b, 2]
390
+ diff = bmu_coords.unsqueeze(1) - grid.unsqueeze(0) # [b, N, 2]
391
+ dist2 = (diff * diff).sum(dim=2) # [b, N]
392
+ h = torch.exp(-dist2 / two_sigma2) # [b, N]
393
+ h = h * alpha # include LR
394
+
395
+ # Vectorized SOM update:
396
+ # W <- W + sum_b h[b] * (x_b - W) / (sum_b h[b])
397
+ # Batched form: numerator = h^T x_b [N, F], denom = h.sum(0) [N]
398
+ numer = h.t() @ xb # [N, F]
399
+ denom = h.sum(dim=0).unsqueeze(1).clamp_min(1e-8) # [N, 1]
400
+ target = numer / denom
401
+ # Update weight: mix toward target with a unit step (h already scaled by alpha).
402
+ # To prevent over-shoot when the same BMU is hit heavily, scale by the
403
+ # mean-field gain min(1, denom). Empirically this behaves like classic SOM.
404
+ gain = torch.clamp(h.sum(dim=0), max=1.0).unsqueeze(1) # [N,1]
405
+ W = (1 - gain) * W + gain * target
406
+
407
+ # Renormalize weights to unit sphere for stability
408
+ W = W / W.norm(dim=1, keepdim=True).clamp_min(1e-8)
409
+
410
+ if (epoch + 1) % max(1, epochs // 10) == 0 or epoch == 0:
411
+ dt = time.time() - t0
412
+ print(f" epoch {epoch+1}/{epochs} sigma={sigma:.2f} alpha={alpha:.4f} elapsed={dt:.1f}s")
413
+
414
+ print(f"[3/4] SOM trained in {time.time()-t0:.1f}s")
415
+ return W
416
+
417
+
418
+ # ---------------------------------------------------------------------------
419
+ # Stage 4: fold context vectors into SDRs
420
+ # ---------------------------------------------------------------------------
421
+
422
+ def fold_sdrs(X: torch.Tensor, W: torch.Tensor, topk_idx: np.ndarray,
423
+ topk_score: np.ndarray, target_active: int) -> np.ndarray:
424
+ """
425
+ For each token, activate the 'cell votes' on the lattice for each of its top-K
426
+ context features, then threshold to exactly target_active bits.
427
+
428
+ Implementation detail: every token in the vocabulary has a SOM BMU given its
429
+ context vector X[i]. We use those BMUs as the feature->cell map. For token t,
430
+ we accumulate votes at BMU(feature) weighted by the PMI score, then pick the
431
+ top target_active cells.
432
+ """
433
+ dev = X.device
434
+ V, F = X.shape
435
+ N = W.shape[0]
436
+ print(f"[4/4] Folding SDRs (V={V}, N={N}, target_active={target_active})")
437
+
438
+ # Per-feature BMU: for each token f as a feature, BMU_f = argmax_n W[n] . X[f]
439
+ # Chunked matmul to bound memory.
440
+ bmu = torch.empty(V, dtype=torch.long, device=dev)
441
+ chunk = 1024
442
+ for s in range(0, V, chunk):
443
+ e = min(V, s + chunk)
444
+ sim = X[s:e] @ W.t() # [b, N]
445
+ bmu[s:e] = sim.argmax(dim=1)
446
+
447
+ # Now build votes tensor [V, N] = sum over k of score[i, k] delta(n = bmu[feat[i, k]])
448
+ K = topk_idx.shape[1]
449
+ feat = torch.from_numpy(topk_idx).to(dev).long() # [V, K]
450
+ sc = torch.from_numpy(topk_score).to(dev) # [V, K]
451
+ feat_bmu = bmu[feat] # [V, K]
452
+
453
+ votes = torch.zeros((V, N), dtype=torch.float32, device=dev)
454
+ votes.scatter_add_(1, feat_bmu, sc)
455
+
456
+ # Tiny numerical nudge: add a local Gaussian kernel around each voted cell so
457
+ # near-neighbors accumulate mass (this is the "folding" smear). Kernel radius 1.
458
+ # Implement as a separable 3x3 blur on the 2D grid view.
459
+ grid_h = int(round(math.sqrt(N)))
460
+ grid_w = grid_h
461
+ assert grid_h * grid_w == N
462
+ votes_2d = votes.view(V, 1, grid_h, grid_w)
463
+ blur = torch.tensor([[[[0.5, 1.0, 0.5],
464
+ [1.0, 2.0, 1.0],
465
+ [0.5, 1.0, 0.5]]]], device=dev, dtype=torch.float32)
466
+ blur = blur / blur.sum()
467
+ votes_2d = torch.nn.functional.conv2d(votes_2d, blur, padding=1)
468
+ votes = votes_2d.view(V, N)
469
+
470
+ # Per-row top-target_active
471
+ _, top_cells = torch.topk(votes, k=target_active, dim=1)
472
+ sdr = torch.zeros((V, N), dtype=torch.bool, device=dev)
473
+ sdr.scatter_(1, top_cells, True)
474
+
475
+ # Sanity check
476
+ row_active = sdr.sum(dim=1)
477
+ assert int(row_active.min()) == target_active, "row active mismatch"
478
+ assert int(row_active.max()) == target_active, "row active mismatch"
479
+
480
+ return sdr.cpu().numpy()
481
+
482
+
483
+ # ---------------------------------------------------------------------------
484
+ # Build orchestration
485
+ # ---------------------------------------------------------------------------
486
+
487
+ @dataclass
488
+ class BuildReport:
489
+ vocab_size: int
490
+ n_bits: int
491
+ train_tokens: int
492
+ wall_time_sec: float
493
+
494
+
495
+ def _retina_cache_repo() -> str:
496
+ return os.environ.get("HYDRA_RETINA_CACHE_REPO", "icarus112/feather-retina-cache")
497
+
498
+
499
+ def _retina_cache_key() -> str:
500
+ """Cache key encodes vocab_size + n_bits + target_active so we don't
501
+ accidentally restore a retina built for a different tokenizer/config."""
502
+ try:
503
+ from prepare import VOCAB_SIZE
504
+ except Exception:
505
+ VOCAB_SIZE = 0
506
+ return f"retina_v{VOCAB_SIZE}_n{N_BITS}_a{TARGET_ACTIVE}.npz"
507
+
508
+
509
+ def _try_hydrate_retina_from_hub() -> bool:
510
+ """Attempt to download a pre-built retina matching our config from HF Hub.
511
+ Returns True if successful — caller should skip the rebuild."""
512
+ token = os.environ.get("HF_TOKEN")
513
+ if not token:
514
+ return False
515
+ cache_key = _retina_cache_key()
516
+ try:
517
+ from huggingface_hub import hf_hub_download
518
+ p = hf_hub_download(
519
+ repo_id=_retina_cache_repo(), repo_type="dataset",
520
+ filename=cache_key, token=token,
521
+ )
522
+ os.makedirs(CACHE_DIR, exist_ok=True)
523
+ import shutil
524
+ shutil.copy(p, RETINA_PATH)
525
+ # Quick verify shape
526
+ with np.load(RETINA_PATH) as npz:
527
+ if int(npz["n_bits"]) == N_BITS and int(npz["target_active"]) == TARGET_ACTIVE:
528
+ print(f"[retina-cache] hydrated {cache_key} from {_retina_cache_repo()} "
529
+ f"(shape={npz['sdr'].shape})", flush=True)
530
+ return True
531
+ os.remove(RETINA_PATH)
532
+ return False
533
+ except Exception as e:
534
+ print(f"[retina-cache] miss: {e}", flush=True)
535
+ return False
536
+
537
+
538
+ def _upload_retina_to_hub() -> None:
539
+ """Upload freshly-built retina.npz to HF Hub for reuse by future jobs."""
540
+ token = os.environ.get("HF_TOKEN")
541
+ if not token:
542
+ return
543
+ cache_key = _retina_cache_key()
544
+ try:
545
+ from huggingface_hub import HfApi, create_repo
546
+ create_repo(_retina_cache_repo(), repo_type="dataset", private=True,
547
+ exist_ok=True, token=token)
548
+ HfApi(token=token).upload_file(
549
+ path_or_fileobj=RETINA_PATH,
550
+ path_in_repo=cache_key,
551
+ repo_id=_retina_cache_repo(), repo_type="dataset",
552
+ commit_message=f"retina build for {cache_key}", token=token,
553
+ )
554
+ print(f"[retina-cache] uploaded {cache_key} to {_retina_cache_repo()}", flush=True)
555
+ except Exception as e:
556
+ print(f"[retina-cache] upload failed: {e}", flush=True)
557
+
558
+
559
+ def build_retina(target_tokens: int = TARGET_TRAIN_TOKENS) -> BuildReport:
560
+ # Try HF Hub-backed cache first — retina build takes 500+ seconds.
561
+ if os.path.exists(RETINA_PATH):
562
+ print(f"[retina-cache] using local {RETINA_PATH}", flush=True)
563
+ with np.load(RETINA_PATH) as npz:
564
+ return BuildReport(
565
+ vocab_size=int(npz["vocab_size"]),
566
+ n_bits=int(npz["n_bits"]),
567
+ train_tokens=int(npz["train_tokens"]),
568
+ wall_time_sec=0.0,
569
+ )
570
+ elif _try_hydrate_retina_from_hub():
571
+ # Local copy now populated; return stub report
572
+ with np.load(RETINA_PATH) as npz:
573
+ return BuildReport(
574
+ vocab_size=int(npz["vocab_size"]),
575
+ n_bits=int(npz["n_bits"]),
576
+ train_tokens=int(npz["train_tokens"]),
577
+ wall_time_sec=0.0,
578
+ )
579
+
580
+ tokenizer = Tokenizer.from_directory(TOKENIZER_DIR)
581
+ vocab_size = tokenizer.get_vocab_size()
582
+
583
+ t0 = time.time()
584
+
585
+ counts, cooc, total_tokens = build_cooccurrence(
586
+ tokenizer, target_tokens=target_tokens, window=CONTEXT_WINDOW,
587
+ )
588
+ topk_idx, topk_score = compute_pmi_topk(
589
+ counts, cooc, total_tokens=total_tokens, top_k=TOP_K_FEATURES,
590
+ )
591
+ # Free the big cooccurrence matrix before GPU-heavy stages
592
+ del cooc
593
+ X = _context_vectors_from_topk(topk_idx, topk_score, vocab_size)
594
+ W = train_som(
595
+ X, grid_h=GRID_H, grid_w=GRID_W,
596
+ epochs=SOM_EPOCHS,
597
+ sigma_start=SOM_SIGMA_START, sigma_end=SOM_SIGMA_END,
598
+ alpha_start=SOM_ALPHA_START, alpha_end=SOM_ALPHA_END,
599
+ )
600
+ sdr = fold_sdrs(X, W, topk_idx, topk_score, target_active=TARGET_ACTIVE)
601
+
602
+ wall = time.time() - t0
603
+
604
+ os.makedirs(CACHE_DIR, exist_ok=True)
605
+ np.savez_compressed(
606
+ RETINA_PATH,
607
+ sdr=sdr,
608
+ vocab_size=np.int64(vocab_size),
609
+ n_bits=np.int64(N_BITS),
610
+ grid_h=np.int64(GRID_H),
611
+ grid_w=np.int64(GRID_W),
612
+ target_active=np.int64(TARGET_ACTIVE),
613
+ context_window=np.int64(CONTEXT_WINDOW),
614
+ top_k_features=np.int64(TOP_K_FEATURES),
615
+ train_tokens=np.int64(total_tokens),
616
+ )
617
+ print(f"[save] wrote {RETINA_PATH} sdr.shape={sdr.shape} "
618
+ f"active_per_row={int(sdr.sum(axis=1).mean())} wall={wall:.1f}s")
619
+
620
+ # Push to HF Hub so subsequent jobs (and parallel retina experiments)
621
+ # skip the 500+ second build entirely.
622
+ _upload_retina_to_hub()
623
+
624
+ return BuildReport(
625
+ vocab_size=vocab_size,
626
+ n_bits=N_BITS,
627
+ train_tokens=total_tokens,
628
+ wall_time_sec=wall,
629
+ )
630
+
631
+
632
+