Spaces:
Runtime error
Runtime error
Update Feather H200 training runtime image
Browse files- Dockerfile +116 -0
- entrypoint.py +227 -0
- mamba_ssm_init.py +69 -0
- overlay/htm_rust/src/gpu/fused.rs +650 -0
- overlay/htm_rust/src/gpu/kernels/htm_fused_step.cu +677 -0
- overlay/htm_rust/src/gpu/kernels/sp_boost_fused.cu +59 -0
- overlay/htm_rust/src/gpu/kernels/sp_duty.cu +45 -0
- overlay/htm_rust/src/gpu/kernels/sp_learn.cu +45 -0
- overlay/htm_rust/src/gpu/kernels/sp_overlap.cu +78 -0
- overlay/htm_rust/src/gpu/kernels/sp_topk.cu +117 -0
- overlay/htm_rust/src/gpu/kernels/tm_activate.cu +66 -0
- overlay/htm_rust/src/gpu/kernels/tm_anomaly.cu +43 -0
- overlay/htm_rust/src/gpu/kernels/tm_grow.cu +155 -0
- overlay/htm_rust/src/gpu/kernels/tm_learn.cu +75 -0
- overlay/htm_rust/src/gpu/kernels/tm_predict.cu +102 -0
- overlay/htm_rust/src/gpu/kernels/tm_punish.cu +64 -0
- overlay/htm_rust/src/gpu/kernels/tm_reset.cu +36 -0
- overlay/htm_rust/src/gpu/mod.rs +549 -0
- overlay/htm_rust/src/gpu/sp_gpu.rs +796 -0
- overlay/htm_rust/src/gpu/tests.rs +643 -0
- overlay/htm_rust/src/gpu/tm_gpu.rs +460 -0
- overlay/hydra/eval.py +210 -0
- overlay/hydra/model.py +659 -0
- overlay/hydra/optimizer.py +252 -0
- overlay/subsystems/htm.py +429 -0
- overlay/subsystems/sdr_retina.py +632 -0
Dockerfile
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
|
| 2 |
+
|
| 3 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 4 |
+
PIP_NO_CACHE_DIR=1 \
|
| 5 |
+
PYTHONUNBUFFERED=1 \
|
| 6 |
+
CARGO_HOME=/root/.cargo \
|
| 7 |
+
RUSTUP_HOME=/root/.rustup \
|
| 8 |
+
PATH=/root/.cargo/bin:${PATH}
|
| 9 |
+
|
| 10 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 11 |
+
git curl ca-certificates build-essential pkg-config libssl-dev && \
|
| 12 |
+
rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
RUN curl https://sh.rustup.rs -sSf | bash -s -- -y --profile minimal --default-toolchain stable
|
| 15 |
+
|
| 16 |
+
RUN pip install --upgrade pip setuptools wheel && \
|
| 17 |
+
pip install \
|
| 18 |
+
maturin \
|
| 19 |
+
huggingface_hub \
|
| 20 |
+
datasets \
|
| 21 |
+
requests \
|
| 22 |
+
pyarrow \
|
| 23 |
+
rustbpe \
|
| 24 |
+
pandas \
|
| 25 |
+
tiktoken \
|
| 26 |
+
pydantic \
|
| 27 |
+
ninja \
|
| 28 |
+
packaging \
|
| 29 |
+
einops
|
| 30 |
+
|
| 31 |
+
# Mamba-3 fused CUDA kernel stack (mandatory — NO fallback allowed).
|
| 32 |
+
#
|
| 33 |
+
# We install PRE-BUILT manylinux wheels from the official state-spaces/mamba
|
| 34 |
+
# and Dao-AILab/causal-conv1d GitHub releases. Compiling mamba_ssm from source
|
| 35 |
+
# on HF Spaces' cpu-basic builder (~16GB RAM) OOMKills even with MAX_JOBS=1 —
|
| 36 |
+
# nvcc on the templated selective-scan/chunk-scan kernels needs 8–12GB per TU.
|
| 37 |
+
#
|
| 38 |
+
# Wheel selection for base image pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel:
|
| 39 |
+
# - Python 3.11 (cp311) — matches PyTorch 2.6.0 image
|
| 40 |
+
# - CUDA 12.x wheels (cu12) — matches host CUDA 12.4
|
| 41 |
+
# - PyTorch 2.6 ABI (torch2.6) — exact torch match
|
| 42 |
+
# - cxx11abiFALSE — standard PyTorch pip build
|
| 43 |
+
#
|
| 44 |
+
# Versions: mamba_ssm 2.3.1 (first stable with Mamba3 class) + causal_conv1d
|
| 45 |
+
# 1.6.1.post4 (matching ABI). Both are CUDA-compiled, no build toolchain needed
|
| 46 |
+
# on the Space builder.
|
| 47 |
+
#
|
| 48 |
+
# Step A: install the published v2.3.1 prebuilt wheel (compiled CUDA ops
|
| 49 |
+
# for selective_scan, layernorm_gated, ssd_*, causal_conv1d, etc).
|
| 50 |
+
RUN pip install \
|
| 51 |
+
'https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \
|
| 52 |
+
'https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
|
| 53 |
+
python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))"
|
| 54 |
+
|
| 55 |
+
#
|
| 56 |
+
# Step B: graft the Mamba3 class + its pure-Triton ops subtree from mamba-ssm
|
| 57 |
+
# main. v2.3.1 is the latest release but Mamba3 landed post-release; the new
|
| 58 |
+
# files under ops/triton/mamba3/ are ALL pure Python @triton.jit kernels with
|
| 59 |
+
# zero compiled-CUDA dependencies (verified: every import in that subtree is
|
| 60 |
+
# triton/torch/python — no .so files, no nvcc). So we install the v2.3.1 wheel
|
| 61 |
+
# (for its compiled ops) and overlay the main-branch Mamba3 sources on top.
|
| 62 |
+
#
|
| 63 |
+
# This avoids the source-build OOM on the cpu-basic HF Space builder and the
|
| 64 |
+
# missing-file error the smoke hit on the last attempt.
|
| 65 |
+
# Download grafted mamba3 module + triton ops subtree
|
| 66 |
+
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
|
| 67 |
+
BASE=https://raw.githubusercontent.com/state-spaces/mamba/main && \
|
| 68 |
+
curl -fsSL "$BASE/mamba_ssm/modules/mamba3.py" -o "$SITE/modules/mamba3.py" && \
|
| 69 |
+
mkdir -p "$SITE/ops/triton/mamba3" && \
|
| 70 |
+
for f in __init__.py angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
|
| 71 |
+
curl -fsSL "$BASE/mamba_ssm/ops/triton/mamba3/$f" -o "$SITE/ops/triton/mamba3/$f"; \
|
| 72 |
+
done
|
| 73 |
+
|
| 74 |
+
# Replace mamba_ssm/__init__.py with a minimal one that only imports Mamba3
|
| 75 |
+
# (pure-Triton, works). The shipped __init__.py eagerly imports
|
| 76 |
+
# selective_scan_cuda.so which has a libtorch C++ ABI mismatch on this base
|
| 77 |
+
# image ("undefined symbol: _ZN3c107WarningC1E..."). Since training only needs
|
| 78 |
+
# Mamba3 (grafted from main), we skip all compiled-CUDA imports.
|
| 79 |
+
COPY mamba_ssm_init.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py
|
| 80 |
+
|
| 81 |
+
# Structural check (no triton init — triton has no GPU on the builder)
|
| 82 |
+
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
|
| 83 |
+
test -f "$SITE/modules/mamba3.py" && \
|
| 84 |
+
test -f "$SITE/ops/triton/mamba3/mamba3_siso_combined.py" && \
|
| 85 |
+
test -s "$SITE/__init__.py" && \
|
| 86 |
+
echo "mamba3 graft + __init__ override verified"
|
| 87 |
+
|
| 88 |
+
# Optional tilelang for MIMO path — pure-python, cheap; SISO Mamba3 works without.
|
| 89 |
+
RUN pip install tilelang || echo "[dockerfile] tilelang optional install failed — continuing"
|
| 90 |
+
|
| 91 |
+
# Triton version decision: FORCE 3.5.1 — the only version with both mamba3
|
| 92 |
+
# APIs (set_allocator + tl.make_tensor_descriptor). torch 2.6's _inductor
|
| 93 |
+
# imports AttrsDescriptor from triton.compiler.compiler which was removed in
|
| 94 |
+
# triton 3.4+, but mamba_ssm/__init__.py shims AttrsDescriptor as a stub
|
| 95 |
+
# before any torch._inductor import path runs, so the incompatibility is
|
| 96 |
+
# neutralized. Build-time assert verifies mamba3's two required APIs.
|
| 97 |
+
RUN pip install --force-reinstall --no-deps 'triton==3.5.1' && \
|
| 98 |
+
python -c "import triton; from triton import language as tl; \
|
| 99 |
+
assert hasattr(triton, 'set_allocator'), 'missing triton.set_allocator'; \
|
| 100 |
+
assert hasattr(tl, 'make_tensor_descriptor'), 'missing tl.make_tensor_descriptor'; \
|
| 101 |
+
print(f'triton={triton.__version__} set_allocator+make_tensor_descriptor OK, AttrsDescriptor shimmed in mamba_ssm/__init__.py')"
|
| 102 |
+
|
| 103 |
+
WORKDIR /workspace
|
| 104 |
+
COPY overlay /workspace/feather
|
| 105 |
+
COPY entrypoint.py /app/entrypoint.py
|
| 106 |
+
WORKDIR /workspace/feather
|
| 107 |
+
|
| 108 |
+
RUN python -m py_compile hydra/training.py prepare.py train.py && \
|
| 109 |
+
bash -n scripts/run_domain_expanded_pretrain.sh
|
| 110 |
+
|
| 111 |
+
RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
|
| 112 |
+
export HTM_CUDA_ARCH=sm_90 && \
|
| 113 |
+
maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml && \
|
| 114 |
+
pip install htm_rust/target/wheels/htm_rust-*.whl
|
| 115 |
+
|
| 116 |
+
CMD ["python", "/app/entrypoint.py"]
|
entrypoint.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from threading import Thread
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# =============================================================================
|
| 15 |
+
# EARLY CUDA FABRIC MANAGER KICK (before ANY CUDA-touching imports)
|
| 16 |
+
# =============================================================================
|
| 17 |
+
# On H200 hosts, cudaGetDeviceCount can return Error 802 "system not yet
|
| 18 |
+
# initialized" on first use, because nvidia-fabricmanager on the host
|
| 19 |
+
# synchronizes with the container's first driver call. Once any NVML/CUDA
|
| 20 |
+
# call succeeds once (even just nvidia-smi), the fabric is up for the rest
|
| 21 |
+
# of the container lifetime.
|
| 22 |
+
#
|
| 23 |
+
# Our previous approach (wait in a subprocess before training) didn't work
|
| 24 |
+
# because the "initialization failed" state persisted across calls in the
|
| 25 |
+
# same container. The real fix: kick the driver exactly once with
|
| 26 |
+
# nvidia-smi, which is what successfully-working baseline containers do
|
| 27 |
+
# implicitly via their first torch.cuda call.
|
| 28 |
+
#
|
| 29 |
+
# Must happen BEFORE `import torch` (because any import that eagerly calls
|
| 30 |
+
# cudaGetDeviceCount will cache the Error 802 state).
|
| 31 |
+
def _early_cuda_kick() -> None:
|
| 32 |
+
deadline = time.time() + 120.0
|
| 33 |
+
attempt = 0
|
| 34 |
+
while time.time() < deadline:
|
| 35 |
+
attempt += 1
|
| 36 |
+
r = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=30)
|
| 37 |
+
if r.returncode == 0 and 'H200' in (r.stdout or '') or 'H100' in (r.stdout or '') \
|
| 38 |
+
or 'A100' in (r.stdout or '') or r.returncode == 0:
|
| 39 |
+
print(f'[boot] nvidia-smi OK on attempt {attempt}', flush=True)
|
| 40 |
+
break
|
| 41 |
+
print(f'[boot] nvidia-smi attempt {attempt} rc={r.returncode} stderr={(r.stderr or "")[:120]}',
|
| 42 |
+
flush=True)
|
| 43 |
+
time.sleep(2)
|
| 44 |
+
# After nvidia-smi, probe torch in a subprocess so any latent error state
|
| 45 |
+
# doesn't leak into the main process's CUDA context.
|
| 46 |
+
probe = 'import torch; import sys; sys.exit(0 if torch.cuda.is_available() else 1)'
|
| 47 |
+
torch_deadline = time.time() + 120.0
|
| 48 |
+
t_attempt = 0
|
| 49 |
+
while time.time() < torch_deadline:
|
| 50 |
+
t_attempt += 1
|
| 51 |
+
r = subprocess.run([sys.executable, '-c', probe], capture_output=True, text=True, timeout=60)
|
| 52 |
+
if r.returncode == 0:
|
| 53 |
+
print(f'[boot] torch.cuda.is_available() = True after {t_attempt} probe(s)', flush=True)
|
| 54 |
+
return
|
| 55 |
+
if t_attempt == 1:
|
| 56 |
+
print(f'[boot] torch cuda probe {t_attempt}: {(r.stderr or "")[:200]}', flush=True)
|
| 57 |
+
time.sleep(2)
|
| 58 |
+
print('[boot] WARNING: torch.cuda never became ready — training will likely fail', flush=True)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
_early_cuda_kick()
|
| 62 |
+
|
| 63 |
+
# Hydrate triton compilation cache from HF Hub before any triton/mamba_ssm import.
|
| 64 |
+
# triton_cache_setup.py is copied next to this file by the job bash command.
|
| 65 |
+
try:
|
| 66 |
+
import triton_cache_setup as _tcs
|
| 67 |
+
_tcs.setup()
|
| 68 |
+
except ImportError:
|
| 69 |
+
print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)
|
| 70 |
+
|
| 71 |
+
from huggingface_hub import HfApi # noqa: E402 (import after cuda kick)
|
| 72 |
+
|
| 73 |
+
REPO_ROOT = Path('/workspace/feather')
|
| 74 |
+
CACHE_ROOT = Path.home() / '.cache' / 'autoresearch'
|
| 75 |
+
LOG_FILE = REPO_ROOT / 'run_domain_expanded.log'
|
| 76 |
+
JOB_ID = os.environ.get('JOB_ID', 'local-job')
|
| 77 |
+
OUTPUT_REPO = os.environ.get('HF_REPO_ID', 'icarus112/feather-pretrain-checkpoints')
|
| 78 |
+
TOKEN = os.environ.get('HF_TOKEN')
|
| 79 |
+
RUNTIME_MODE = os.environ.get('FEATHER_RUNTIME_MODE', 'space')
|
| 80 |
+
APP_PORT = int(os.environ.get('PORT', '7860'))
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class _HealthHandler(BaseHTTPRequestHandler):
|
| 84 |
+
def do_GET(self):
|
| 85 |
+
if self.path in ('/', '/health', '/healthz', '/ready'):
|
| 86 |
+
payload = {
|
| 87 |
+
'status': 'ok',
|
| 88 |
+
'mode': RUNTIME_MODE,
|
| 89 |
+
'job_id': JOB_ID,
|
| 90 |
+
}
|
| 91 |
+
body = json.dumps(payload).encode('utf-8')
|
| 92 |
+
self.send_response(200)
|
| 93 |
+
self.send_header('Content-Type', 'application/json')
|
| 94 |
+
self.send_header('Content-Length', str(len(body)))
|
| 95 |
+
self.end_headers()
|
| 96 |
+
self.wfile.write(body)
|
| 97 |
+
return
|
| 98 |
+
self.send_response(404)
|
| 99 |
+
self.end_headers()
|
| 100 |
+
|
| 101 |
+
def log_message(self, format, *args):
|
| 102 |
+
return
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _start_health_server() -> HTTPServer:
|
| 106 |
+
server = HTTPServer(('0.0.0.0', APP_PORT), _HealthHandler)
|
| 107 |
+
thread = Thread(target=server.serve_forever, daemon=True)
|
| 108 |
+
thread.start()
|
| 109 |
+
print(f'[space] health server listening on 0.0.0.0:{APP_PORT}', flush=True)
|
| 110 |
+
return server
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
|
| 114 |
+
if not path.exists():
|
| 115 |
+
print(f'[upload] skip missing {path}', flush=True)
|
| 116 |
+
return
|
| 117 |
+
api.upload_file(
|
| 118 |
+
path_or_fileobj=str(path),
|
| 119 |
+
path_in_repo=dest,
|
| 120 |
+
repo_id=OUTPUT_REPO,
|
| 121 |
+
repo_type='model',
|
| 122 |
+
)
|
| 123 |
+
print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
|
| 127 |
+
"""Block until CUDA is fully initialized or timeout.
|
| 128 |
+
|
| 129 |
+
On H200 hosts with NVSwitch/fabric manager, nvidia driver setup can race
|
| 130 |
+
with container start. cudaGetDeviceCount can return CUDA_ERROR_SYSTEM_NOT_READY
|
| 131 |
+
(error 802) for the first few seconds, and any import that triggers
|
| 132 |
+
@triton.autotune (e.g. mamba_ssm, torch amp utilities) blows up with
|
| 133 |
+
"0 active drivers" if it happens during that window.
|
| 134 |
+
|
| 135 |
+
We pre-init CUDA in a throwaway Python subprocess (so any error state does
|
| 136 |
+
not leak into the main training process) and retry until torch.cuda
|
| 137 |
+
reports ready.
|
| 138 |
+
"""
|
| 139 |
+
import time as _t
|
| 140 |
+
probe = (
|
| 141 |
+
"import torch; "
|
| 142 |
+
"import sys; "
|
| 143 |
+
"avail = torch.cuda.is_available(); "
|
| 144 |
+
"count = torch.cuda.device_count() if avail else 0; "
|
| 145 |
+
"sys.exit(0 if (avail and count > 0) else 1)"
|
| 146 |
+
)
|
| 147 |
+
deadline = _t.time() + timeout_s
|
| 148 |
+
attempt = 0
|
| 149 |
+
while _t.time() < deadline:
|
| 150 |
+
attempt += 1
|
| 151 |
+
r = subprocess.run(['python', '-c', probe], capture_output=True, text=True)
|
| 152 |
+
if r.returncode == 0:
|
| 153 |
+
print(f'[job] CUDA ready after {attempt} probe(s)', flush=True)
|
| 154 |
+
return
|
| 155 |
+
if attempt == 1:
|
| 156 |
+
print(f'[job] CUDA not ready yet (will retry up to {timeout_s}s): {r.stderr.strip()[:200]}', flush=True)
|
| 157 |
+
_t.sleep(2)
|
| 158 |
+
print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def run_job_mode() -> int:
|
| 162 |
+
os.chdir(REPO_ROOT)
|
| 163 |
+
os.environ.setdefault('HYDRA_TIME_BUDGET', '43200')
|
| 164 |
+
os.environ.setdefault('HYDRA_TARGET_SHARDS', '2048')
|
| 165 |
+
os.environ.setdefault('HYDRA_DOWNLOAD_WORKERS', '16')
|
| 166 |
+
os.environ.setdefault('HYDRA_CKPT_INTERVAL', '1000')
|
| 167 |
+
os.environ.setdefault('HYDRA_RESUME_CKPT', str(CACHE_ROOT / 'latest.pt'))
|
| 168 |
+
|
| 169 |
+
# CUDA readiness was kicked at module import via _early_cuda_kick. Keep
|
| 170 |
+
# the wait as a second safety net — no-op if CUDA already ready.
|
| 171 |
+
_wait_for_cuda_ready()
|
| 172 |
+
|
| 173 |
+
cmd = [
|
| 174 |
+
'bash',
|
| 175 |
+
'./scripts/run_domain_expanded_pretrain.sh',
|
| 176 |
+
'--target-shards', os.environ['HYDRA_TARGET_SHARDS'],
|
| 177 |
+
'--download-workers', os.environ['HYDRA_DOWNLOAD_WORKERS'],
|
| 178 |
+
]
|
| 179 |
+
print('[job] starting Feather domain-expanded pretrain', flush=True)
|
| 180 |
+
print(f'[job] command={cmd}', flush=True)
|
| 181 |
+
proc = subprocess.run(cmd, check=False)
|
| 182 |
+
|
| 183 |
+
# Push triton compilation cache back to HF Hub for next run.
|
| 184 |
+
try:
|
| 185 |
+
import triton_cache_setup as _tcs
|
| 186 |
+
_tcs.teardown()
|
| 187 |
+
except Exception as _tcs_err:
|
| 188 |
+
print(f'[triton_cache] teardown error (non-fatal): {_tcs_err}', flush=True)
|
| 189 |
+
|
| 190 |
+
if TOKEN:
|
| 191 |
+
api = HfApi(token=TOKEN)
|
| 192 |
+
try:
|
| 193 |
+
api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
|
| 194 |
+
except Exception as e:
|
| 195 |
+
print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
|
| 196 |
+
prefix = f'jobs/{JOB_ID}'
|
| 197 |
+
try:
|
| 198 |
+
upload_artifact(api, LOG_FILE, f'{prefix}/run_domain_expanded.log')
|
| 199 |
+
upload_artifact(api, CACHE_ROOT / 'latest.pt', f'{prefix}/latest.pt')
|
| 200 |
+
upload_artifact(api, CACHE_ROOT / 'pretrain_final.pt', f'{prefix}/pretrain_final.pt')
|
| 201 |
+
except Exception as e:
|
| 202 |
+
print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
|
| 203 |
+
else:
|
| 204 |
+
print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)
|
| 205 |
+
|
| 206 |
+
return proc.returncode
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def run_space_mode() -> int:
|
| 210 |
+
server = _start_health_server()
|
| 211 |
+
print('[space] Feather runtime image ready', flush=True)
|
| 212 |
+
try:
|
| 213 |
+
while True:
|
| 214 |
+
time.sleep(3600)
|
| 215 |
+
finally:
|
| 216 |
+
server.shutdown()
|
| 217 |
+
server.server_close()
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def main() -> int:
|
| 221 |
+
if RUNTIME_MODE == 'job':
|
| 222 |
+
return run_job_mode()
|
| 223 |
+
return run_space_mode()
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
if __name__ == '__main__':
|
| 227 |
+
raise SystemExit(main())
|
mamba_ssm_init.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mamba_ssm package init — minimal override to avoid broken selective_scan_cuda.so
|
| 2 |
+
# ABI mismatch with the base image's libtorch.
|
| 3 |
+
#
|
| 4 |
+
# The upstream __init__.py eagerly imports selective_scan_cuda which fails on
|
| 5 |
+
# pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
|
| 6 |
+
# symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
|
| 7 |
+
# all compiled-CUDA imports here and let Mamba3 load directly.
|
| 8 |
+
|
| 9 |
+
__version__ = "2.3.1+feather-graft"
|
| 10 |
+
|
| 11 |
+
# selective_scan_fn / mamba_inner_fn are shimmed to None — they are NOT used
|
| 12 |
+
# by the Feather training path (which is Mamba3-only). If any import path
|
| 13 |
+
# hits this, it will get a clear AttributeError instead of an obscure ImportError.
|
| 14 |
+
selective_scan_fn = None
|
| 15 |
+
mamba_inner_fn = None
|
| 16 |
+
|
| 17 |
+
# --- triton API compatibility shims -----------------------------------------
|
| 18 |
+
# Version matrix is hostile: torch 2.6 pins triton==3.2.0 because torch._inductor
|
| 19 |
+
# imports AttrsDescriptor from triton.compiler.compiler — removed in triton 3.4+.
|
| 20 |
+
# Grafted Mamba3 (from mamba-ssm main) needs triton.set_allocator and
|
| 21 |
+
# tl.make_tensor_descriptor, both added in triton 3.3+. No single triton version
|
| 22 |
+
# satisfies both simultaneously. We run on triton 3.5.1 (latest, has both mamba3
|
| 23 |
+
# APIs) and shim AttrsDescriptor as a stub dataclass for torch._inductor. The
|
| 24 |
+
# stub is never actually invoked at runtime because the codebase does not use
|
| 25 |
+
# torch.compile — but importing torch._inductor.* still requires the symbol to
|
| 26 |
+
# exist at module load time.
|
| 27 |
+
import triton as _triton # noqa: E402
|
| 28 |
+
if not hasattr(_triton, "set_allocator"):
|
| 29 |
+
def _noop_set_allocator(_fn): # pragma: no cover
|
| 30 |
+
return None
|
| 31 |
+
_triton.set_allocator = _noop_set_allocator
|
| 32 |
+
|
| 33 |
+
import triton.compiler.compiler as _tcc # noqa: E402
|
| 34 |
+
if not hasattr(_tcc, "AttrsDescriptor"):
|
| 35 |
+
class _AttrsDescriptorShim:
|
| 36 |
+
"""Stub for torch._inductor compatibility on triton >= 3.4.
|
| 37 |
+
torch._inductor.runtime.hints imports this at module load but the
|
| 38 |
+
constructor is only called inside torch.compile paths. Accept any
|
| 39 |
+
args/kwargs so the import itself succeeds."""
|
| 40 |
+
def __init__(self, *args, **kwargs):
|
| 41 |
+
self.args = args
|
| 42 |
+
self.kwargs = kwargs
|
| 43 |
+
|
| 44 |
+
@classmethod
|
| 45 |
+
def from_hints(cls, *args, **kwargs):
|
| 46 |
+
return cls(*args, **kwargs)
|
| 47 |
+
|
| 48 |
+
_tcc.AttrsDescriptor = _AttrsDescriptorShim
|
| 49 |
+
|
| 50 |
+
# triton_key: removed in triton 3.5, used by torch._inductor.codecache for
|
| 51 |
+
# FxGraphCache key derivation. Return a stable string so caching still works.
|
| 52 |
+
if not hasattr(_tcc, "triton_key"):
|
| 53 |
+
def _triton_key_shim():
|
| 54 |
+
import triton as _t
|
| 55 |
+
return f"triton-{_t.__version__}-shim"
|
| 56 |
+
_tcc.triton_key = _triton_key_shim
|
| 57 |
+
|
| 58 |
+
# Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
|
| 59 |
+
# for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
|
| 60 |
+
# so fall back to eager on any dynamo failure rather than crashing. This is
|
| 61 |
+
# defense-in-depth against further triton API drift.
|
| 62 |
+
try:
|
| 63 |
+
import torch._dynamo # noqa: F401 — triggers dynamo module init
|
| 64 |
+
torch._dynamo.config.suppress_errors = True
|
| 65 |
+
except Exception: # pragma: no cover
|
| 66 |
+
pass
|
| 67 |
+
|
| 68 |
+
# Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
|
| 69 |
+
from mamba_ssm.modules.mamba3 import Mamba3 # noqa: E402
|
overlay/htm_rust/src/gpu/fused.rs
ADDED
|
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! Fused HTM megakernel launcher.
|
| 2 |
+
//!
|
| 3 |
+
//! Collapses the 12-kernel per-timestep pipeline (and the outer T-loop) into
|
| 4 |
+
//! a single kernel launch per forward. See `kernels/htm_fused_step.cu` for
|
| 5 |
+
//! the kernel design and the cross-block coherence strategy (grid barrier
|
| 6 |
+
//! via device counter with all blocks concurrently resident).
|
| 7 |
+
//!
|
| 8 |
+
//! Launch invariant: `grid_dim.x <= concurrent-block capacity`. Host code
|
| 9 |
+
//! probes the device SM count at construction and caps grid_dim.x
|
| 10 |
+
//! accordingly — otherwise the grid barrier deadlocks.
|
| 11 |
+
//!
|
| 12 |
+
//! Semantic change from the top-K pipeline: activation is per-column
|
| 13 |
+
//! threshold-based (local lateral inhibition) instead of global top-K.
|
| 14 |
+
//! A per-column `inhibition_threshold` is tracked and EMA-steered to hit
|
| 15 |
+
//! the sparsity target. This is a real architectural change and is
|
| 16 |
+
//! documented in `docs/GPU_HTM.md`.
|
| 17 |
+
|
| 18 |
+
#![cfg(feature = "gpu")]
|
| 19 |
+
|
| 20 |
+
use std::ffi::CString;
|
| 21 |
+
use std::sync::Arc;
|
| 22 |
+
|
| 23 |
+
use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError,
|
| 24 |
+
LaunchConfig};
|
| 25 |
+
use cudarc::nvrtc::Ptx;
|
| 26 |
+
|
| 27 |
+
use super::sp_gpu::SpatialPoolerGpu;
|
| 28 |
+
use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
|
| 29 |
+
|
| 30 |
+
const PTX_HTM_FUSED: &str =
|
| 31 |
+
include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
|
| 32 |
+
|
| 33 |
+
/// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
|
| 34 |
+
///
|
| 35 |
+
/// NOTE: `barrier_counters` is kept as an ABI-compat dummy (always 0). The
|
| 36 |
+
/// C-side `FusedPtrs` still has the field at the same byte offset; removing
|
| 37 |
+
/// it here would shift all subsequent fields and break the layout. Worker A
|
| 38 |
+
/// will eventually delete the field from both sides once the kernel is
|
| 39 |
+
/// updated; until then we zero it.
|
| 40 |
+
#[repr(C)]
|
| 41 |
+
#[derive(Clone, Copy)]
|
| 42 |
+
pub struct FusedPtrs {
|
| 43 |
+
pub syn_bit: u64,
|
| 44 |
+
pub syn_perm: u64,
|
| 45 |
+
pub boost: u64,
|
| 46 |
+
pub active_duty: u64,
|
| 47 |
+
pub inhibition_threshold: u64,
|
| 48 |
+
pub seg_cell_id: u64,
|
| 49 |
+
pub seg_syn_count: u64,
|
| 50 |
+
pub syn_presyn: u64,
|
| 51 |
+
pub tm_syn_perm: u64,
|
| 52 |
+
pub cell_seg_count: u64,
|
| 53 |
+
pub cell_active_a: u64,
|
| 54 |
+
pub cell_active_b: u64,
|
| 55 |
+
pub cell_winner_a: u64,
|
| 56 |
+
pub cell_winner_b: u64,
|
| 57 |
+
pub inputs: u64,
|
| 58 |
+
pub cols_out: u64,
|
| 59 |
+
pub anom_out: u64,
|
| 60 |
+
/// ABI-compat dummy — always 0. No device memory is allocated for this
|
| 61 |
+
/// field; the cluster barrier replaces the old software DLB barrier.
|
| 62 |
+
pub barrier_counters: u64,
|
| 63 |
+
pub step_scratch: u64,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
unsafe impl DeviceRepr for FusedPtrs {}
|
| 67 |
+
|
| 68 |
+
/// Launch-time config — matches C-side `FusedConfig` 1:1.
|
| 69 |
+
#[repr(C)]
|
| 70 |
+
#[derive(Clone, Copy)]
|
| 71 |
+
pub struct FusedConfig {
|
| 72 |
+
pub input_bits: u32,
|
| 73 |
+
pub n_columns: u32,
|
| 74 |
+
pub synapses_per_col: u32,
|
| 75 |
+
pub conn_thr: f32,
|
| 76 |
+
pub sp_inc: f32,
|
| 77 |
+
pub sp_dec: f32,
|
| 78 |
+
pub sparsity_target: f32,
|
| 79 |
+
pub duty_alpha: f32,
|
| 80 |
+
pub thr_adapt_rate: f32,
|
| 81 |
+
pub cells_per_column: u32,
|
| 82 |
+
pub n_cells: u32,
|
| 83 |
+
pub bits_words: u32,
|
| 84 |
+
pub max_segments_per_cell: u32,
|
| 85 |
+
pub synapses_per_segment: u32,
|
| 86 |
+
pub activation_threshold: u32,
|
| 87 |
+
pub learning_threshold: u32,
|
| 88 |
+
pub max_new_synapses: u32,
|
| 89 |
+
pub conn_thr_i16: i32,
|
| 90 |
+
pub perm_inc_i16: i32,
|
| 91 |
+
pub perm_dec_i16: i32,
|
| 92 |
+
pub predicted_seg_dec_i16: i32,
|
| 93 |
+
pub initial_perm_i16: i32,
|
| 94 |
+
pub t: u32,
|
| 95 |
+
pub learn: u32,
|
| 96 |
+
pub iter_seed: u32,
|
| 97 |
+
pub cooperative_grid_sync: u32,
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
unsafe impl DeviceRepr for FusedConfig {}
|
| 101 |
+
|
| 102 |
+
/// Cluster launch parameters probed at construction time.
|
| 103 |
+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
| 104 |
+
pub(crate) struct ClusterInfo {
|
| 105 |
+
/// Maximum cluster size supported by this device (0 = cluster unsupported).
|
| 106 |
+
pub max_cluster_size: u32,
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
// There is only ONE launch mode: non-cooperative launch with Hopper Thread
|
| 110 |
+
// Block Cluster attribute (`CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`). The old
|
| 111 |
+
// software DLB barrier and the cooperative-launch path are both removed.
|
| 112 |
+
// Cluster barriers replace both.
|
| 113 |
+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
| 114 |
+
pub(crate) struct FusedLaunchPlan {
|
| 115 |
+
pub grid_dim_x: u32,
|
| 116 |
+
pub block_dim_x: u32,
|
| 117 |
+
pub cooperative_grid_limit: u32,
|
| 118 |
+
pub sm_count: u32,
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
fn fused_grid_cap_override() -> Option<u32> {
|
| 122 |
+
std::env::var("HTM_FUSED_GRID_CAP")
|
| 123 |
+
.ok()
|
| 124 |
+
.and_then(|s| s.parse::<u32>().ok())
|
| 125 |
+
.map(|v| v.max(1))
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
pub(crate) fn plan_fused_launch(
|
| 129 |
+
sm_count: u32,
|
| 130 |
+
cooperative_supported: bool,
|
| 131 |
+
cooperative_grid_limit: u32,
|
| 132 |
+
grid_cap_override: Option<u32>,
|
| 133 |
+
) -> Result<FusedLaunchPlan, String> {
|
| 134 |
+
let sm_count = sm_count.max(1);
|
| 135 |
+
let block_dim_x = 1024u32;
|
| 136 |
+
|
| 137 |
+
// Cluster launch path: cooperative launch is not required. Keep the probe
|
| 138 |
+
// result for residency estimation only.
|
| 139 |
+
if !cooperative_supported {
|
| 140 |
+
eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
// Cluster constraint: grid_dim_x must equal the cluster size (16) so that
|
| 144 |
+
// each region maps to exactly one cluster. `HTM_FUSED_GRID_CAP` can lower
|
| 145 |
+
// this for debugging but should not exceed 16 for cluster correctness.
|
| 146 |
+
let default_grid_cap = 16u32;
|
| 147 |
+
let grid_cap = grid_cap_override.unwrap_or(default_grid_cap).min(16);
|
| 148 |
+
let resident_bound = if cooperative_grid_limit > 0 {
|
| 149 |
+
cooperative_grid_limit.max(sm_count * 2)
|
| 150 |
+
} else {
|
| 151 |
+
sm_count * 2
|
| 152 |
+
};
|
| 153 |
+
Ok(FusedLaunchPlan {
|
| 154 |
+
grid_dim_x: resident_bound.min(grid_cap).max(1),
|
| 155 |
+
block_dim_x,
|
| 156 |
+
cooperative_grid_limit: resident_bound,
|
| 157 |
+
sm_count,
|
| 158 |
+
})
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
pub(super) struct RawFusedKernel {
|
| 162 |
+
module: sys::CUmodule,
|
| 163 |
+
pub(super) function: sys::CUfunction,
|
| 164 |
+
pub(super) function_batched: sys::CUfunction,
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
unsafe impl Send for RawFusedKernel {}
|
| 168 |
+
unsafe impl Sync for RawFusedKernel {}
|
| 169 |
+
|
| 170 |
+
impl Drop for RawFusedKernel {
|
| 171 |
+
fn drop(&mut self) {
|
| 172 |
+
unsafe {
|
| 173 |
+
let _ = result::module::unload(self.module);
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
/// Owns fused-path-only device state:
|
| 179 |
+
/// - per-column inhibition threshold (replaces global top-K)
|
| 180 |
+
/// - ping-pong cell_active/cell_winner bitsets
|
| 181 |
+
/// - step_scratch (n_active, n_unpred per timestep)
|
| 182 |
+
/// - cluster launch capability info
|
| 183 |
+
pub struct FusedState {
|
| 184 |
+
dev: Arc<CudaDevice>,
|
| 185 |
+
pub(super) raw_kernel: RawFusedKernel,
|
| 186 |
+
|
| 187 |
+
pub inhibition_threshold: CudaSlice<f32>,
|
| 188 |
+
pub cell_active_bits_a: CudaSlice<u32>,
|
| 189 |
+
pub cell_active_bits_b: CudaSlice<u32>,
|
| 190 |
+
pub cell_winner_bits_a: CudaSlice<u32>,
|
| 191 |
+
pub cell_winner_bits_b: CudaSlice<u32>,
|
| 192 |
+
pub step_scratch: CudaSlice<u32>, // length 6
|
| 193 |
+
|
| 194 |
+
pub grid_dim_x: u32,
|
| 195 |
+
pub block_dim_x: u32,
|
| 196 |
+
pub cooperative_grid_limit: u32,
|
| 197 |
+
pub iter_counter: u32,
|
| 198 |
+
|
| 199 |
+
/// Hopper cluster launch capability (0 = unsupported).
|
| 200 |
+
pub cluster_info: ClusterInfo,
|
| 201 |
+
|
| 202 |
+
// Config mirror (read-only after init).
|
| 203 |
+
#[allow(dead_code)]
|
| 204 |
+
pub initial_threshold: f32,
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
impl FusedState {
|
| 208 |
+
pub fn new(
|
| 209 |
+
dev: Arc<CudaDevice>,
|
| 210 |
+
n_columns: usize,
|
| 211 |
+
cells_per_column: usize,
|
| 212 |
+
initial_threshold: f32,
|
| 213 |
+
) -> Result<Self, DriverError> {
|
| 214 |
+
let n_cells = n_columns * cells_per_column;
|
| 215 |
+
assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
|
| 216 |
+
let bits_words = n_cells / 32;
|
| 217 |
+
|
| 218 |
+
let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
|
| 219 |
+
let init_vec = vec![initial_threshold; n_columns];
|
| 220 |
+
dev.htod_sync_copy_into(&init_vec, &mut inhibition_threshold)?;
|
| 221 |
+
|
| 222 |
+
let cell_active_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
|
| 223 |
+
let cell_active_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
|
| 224 |
+
let cell_winner_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
|
| 225 |
+
let cell_winner_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
|
| 226 |
+
let step_scratch = dev.alloc_zeros::<u32>(6)?;
|
| 227 |
+
|
| 228 |
+
unsafe {
|
| 229 |
+
result::ctx::set_current(*dev.cu_primary_ctx())?;
|
| 230 |
+
}
|
| 231 |
+
if dev.get_func("htm_fused", "htm_fused_step").is_none() {
|
| 232 |
+
dev.load_ptx(
|
| 233 |
+
Ptx::from_src(PTX_HTM_FUSED),
|
| 234 |
+
"htm_fused",
|
| 235 |
+
&["htm_fused_step", "htm_fused_step_batched"],
|
| 236 |
+
)?;
|
| 237 |
+
}
|
| 238 |
+
let ptx = CString::new(PTX_HTM_FUSED).expect("PTX contains no interior nul bytes");
|
| 239 |
+
let module = unsafe { result::module::load_data(ptx.as_ptr().cast()) }?;
|
| 240 |
+
let function = unsafe {
|
| 241 |
+
result::module::get_function(module, CString::new("htm_fused_step").unwrap())
|
| 242 |
+
}?;
|
| 243 |
+
let function_batched = unsafe {
|
| 244 |
+
result::module::get_function(module, CString::new("htm_fused_step_batched").unwrap())
|
| 245 |
+
}?;
|
| 246 |
+
|
| 247 |
+
// Cluster size 16 on Hopper is "non-portable" (> 8 requires opt-in).
|
| 248 |
+
// Must set CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED=1 on
|
| 249 |
+
// every launched kernel function, otherwise cuLaunchKernelEx rejects
|
| 250 |
+
// the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
|
| 251 |
+
unsafe {
|
| 252 |
+
let attr = sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
|
| 253 |
+
// Ignore errors: older CUDA may lack the attribute, in which case
|
| 254 |
+
// only portable sizes (<= 8) work — plan_fused_launch caps at 8.
|
| 255 |
+
let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
|
| 256 |
+
let _ = sys::lib().cuFuncSetAttribute(function_batched, attr, 1);
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
// Probe SM count.
|
| 260 |
+
let sm_count = match dev.attribute(
|
| 261 |
+
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
|
| 262 |
+
) {
|
| 263 |
+
Ok(v) => v as u32,
|
| 264 |
+
Err(_) => 16u32,
|
| 265 |
+
};
|
| 266 |
+
|
| 267 |
+
// T1: Probe Hopper cluster launch capability.
|
| 268 |
+
let max_cluster_size = match dev.attribute(
|
| 269 |
+
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH,
|
| 270 |
+
) {
|
| 271 |
+
Ok(v) if v > 0 => {
|
| 272 |
+
// H200/sm_90a supports up to 16 blocks per cluster.
|
| 273 |
+
// There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
|
| 274 |
+
// Hopper maximum which is 16 (8 SMs × 2 blocks/SM = 16 blocks/cluster).
|
| 275 |
+
16u32
|
| 276 |
+
}
|
| 277 |
+
_ => 0u32,
|
| 278 |
+
};
|
| 279 |
+
eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
|
| 280 |
+
let cluster_info = ClusterInfo { max_cluster_size };
|
| 281 |
+
|
| 282 |
+
let cooperative_supported = matches!(
|
| 283 |
+
dev.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
|
| 284 |
+
Ok(v) if v > 0
|
| 285 |
+
);
|
| 286 |
+
let cooperative_grid_limit = if cooperative_supported {
|
| 287 |
+
let blocks_per_sm = unsafe {
|
| 288 |
+
result::occupancy::max_active_block_per_multiprocessor(function, 1024, 0)
|
| 289 |
+
}
|
| 290 |
+
.ok()
|
| 291 |
+
.map(|v| v.max(0) as u32)
|
| 292 |
+
.unwrap_or(0);
|
| 293 |
+
sm_count.saturating_mul(blocks_per_sm)
|
| 294 |
+
} else {
|
| 295 |
+
0
|
| 296 |
+
};
|
| 297 |
+
let launch_plan = plan_fused_launch(
|
| 298 |
+
sm_count,
|
| 299 |
+
cooperative_supported,
|
| 300 |
+
cooperative_grid_limit,
|
| 301 |
+
fused_grid_cap_override(),
|
| 302 |
+
)
|
| 303 |
+
.map_err(|msg| {
|
| 304 |
+
// Surface as a CUDA-ish error so callers can propagate.
|
| 305 |
+
eprintln!("[htm_rust] FATAL: {msg}");
|
| 306 |
+
DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
|
| 307 |
+
})?;
|
| 308 |
+
|
| 309 |
+
eprintln!(
|
| 310 |
+
"[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}",
|
| 311 |
+
launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit,
|
| 312 |
+
cluster_info.max_cluster_size,
|
| 313 |
+
);
|
| 314 |
+
|
| 315 |
+
Ok(Self {
|
| 316 |
+
dev,
|
| 317 |
+
raw_kernel: RawFusedKernel { module, function, function_batched },
|
| 318 |
+
inhibition_threshold,
|
| 319 |
+
cell_active_bits_a,
|
| 320 |
+
cell_active_bits_b,
|
| 321 |
+
cell_winner_bits_a,
|
| 322 |
+
cell_winner_bits_b,
|
| 323 |
+
step_scratch,
|
| 324 |
+
grid_dim_x: launch_plan.grid_dim_x,
|
| 325 |
+
block_dim_x: launch_plan.block_dim_x,
|
| 326 |
+
cooperative_grid_limit: launch_plan.cooperative_grid_limit,
|
| 327 |
+
iter_counter: 0,
|
| 328 |
+
cluster_info,
|
| 329 |
+
initial_threshold,
|
| 330 |
+
})
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
/// Reset fused state. Called at region.reset().
|
| 334 |
+
pub fn reset(&mut self) -> Result<(), DriverError> {
|
| 335 |
+
self.dev.memset_zeros(&mut self.cell_active_bits_a)?;
|
| 336 |
+
self.dev.memset_zeros(&mut self.cell_active_bits_b)?;
|
| 337 |
+
self.dev.memset_zeros(&mut self.cell_winner_bits_a)?;
|
| 338 |
+
self.dev.memset_zeros(&mut self.cell_winner_bits_b)?;
|
| 339 |
+
self.dev.memset_zeros(&mut self.step_scratch)?;
|
| 340 |
+
// Do NOT reset inhibition_threshold — it's learned state. A hard
|
| 341 |
+
// reset of TM state should NOT forget the sparsity calibration.
|
| 342 |
+
Ok(())
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
/// Launch the fused megakernel. Processes all T timesteps in one kernel.
|
| 347 |
+
///
|
| 348 |
+
/// Uses `cuLaunchKernelEx` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION=(16,1,1)`
|
| 349 |
+
/// when the device supports cluster launch, otherwise falls back to a plain
|
| 350 |
+
/// `launch_kernel`. For single-region launches, grid_dim_x <= 16 ensures the
|
| 351 |
+
/// entire grid fits in one cluster.
|
| 352 |
+
#[allow(clippy::too_many_arguments)]
|
| 353 |
+
pub fn launch_fused(
|
| 354 |
+
sp: &mut SpatialPoolerGpu,
|
| 355 |
+
tm: &mut TemporalMemoryGpu,
|
| 356 |
+
fused: &mut FusedState,
|
| 357 |
+
inputs_flat: &CudaSlice<u8>,
|
| 358 |
+
cols_out: &mut CudaSlice<u8>,
|
| 359 |
+
anom_out: &mut CudaSlice<f32>,
|
| 360 |
+
t: usize,
|
| 361 |
+
input_bits: usize,
|
| 362 |
+
learn: bool,
|
| 363 |
+
) -> Result<(), DriverError> {
|
| 364 |
+
// Reset step_scratch before each launch (safe re-entry).
|
| 365 |
+
sp.dev_ref().memset_zeros(&mut fused.step_scratch)?;
|
| 366 |
+
|
| 367 |
+
fused.iter_counter = fused.iter_counter.wrapping_add(1);
|
| 368 |
+
|
| 369 |
+
let cfg = FusedConfig {
|
| 370 |
+
input_bits: input_bits as u32,
|
| 371 |
+
n_columns: sp.n_columns_accessor() as u32,
|
| 372 |
+
synapses_per_col: sp.synapses_per_col_accessor() as u32,
|
| 373 |
+
conn_thr: sp.conn_thr_accessor(),
|
| 374 |
+
sp_inc: sp.inc_accessor(),
|
| 375 |
+
sp_dec: sp.dec_accessor(),
|
| 376 |
+
sparsity_target: sp.sparsity_accessor(),
|
| 377 |
+
duty_alpha: 1.0f32 / sp.duty_period_accessor().max(1.0),
|
| 378 |
+
thr_adapt_rate: 0.001f32,
|
| 379 |
+
cells_per_column: tm.cells_per_column as u32,
|
| 380 |
+
n_cells: tm.n_cells as u32,
|
| 381 |
+
bits_words: tm.bits_words as u32,
|
| 382 |
+
max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
|
| 383 |
+
synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
|
| 384 |
+
activation_threshold: tm.activation_threshold,
|
| 385 |
+
learning_threshold: tm.learning_threshold,
|
| 386 |
+
max_new_synapses: tm.max_new_synapse_count,
|
| 387 |
+
conn_thr_i16: tm.conn_thr_i16 as i32,
|
| 388 |
+
perm_inc_i16: tm.perm_inc_i16 as i32,
|
| 389 |
+
perm_dec_i16: tm.perm_dec_i16 as i32,
|
| 390 |
+
predicted_seg_dec_i16: tm.predicted_seg_dec_i16 as i32,
|
| 391 |
+
initial_perm_i16: tm.initial_perm_i16 as i32,
|
| 392 |
+
t: t as u32,
|
| 393 |
+
learn: if learn { 1 } else { 0 },
|
| 394 |
+
iter_seed: fused.iter_counter,
|
| 395 |
+
cooperative_grid_sync: 1,
|
| 396 |
+
};
|
| 397 |
+
|
| 398 |
+
let ptrs = FusedPtrs {
|
| 399 |
+
syn_bit: *sp.syn_bit_accessor().device_ptr(),
|
| 400 |
+
syn_perm: *sp.syn_perm_accessor().device_ptr(),
|
| 401 |
+
boost: *sp.boost_accessor().device_ptr(),
|
| 402 |
+
active_duty: *sp.active_duty_accessor().device_ptr(),
|
| 403 |
+
inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
|
| 404 |
+
seg_cell_id: *tm.seg_cell_id_accessor().device_ptr(),
|
| 405 |
+
seg_syn_count: *tm.seg_syn_count_accessor().device_ptr(),
|
| 406 |
+
syn_presyn: *tm.syn_presyn_accessor().device_ptr(),
|
| 407 |
+
tm_syn_perm: *tm.syn_perm_accessor().device_ptr(),
|
| 408 |
+
cell_seg_count: *tm.cell_seg_count_accessor().device_ptr(),
|
| 409 |
+
cell_active_a: *fused.cell_active_bits_a.device_ptr(),
|
| 410 |
+
cell_active_b: *fused.cell_active_bits_b.device_ptr(),
|
| 411 |
+
cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
|
| 412 |
+
cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
|
| 413 |
+
inputs: *inputs_flat.device_ptr(),
|
| 414 |
+
cols_out: *cols_out.device_ptr(),
|
| 415 |
+
anom_out: *anom_out.device_ptr(),
|
| 416 |
+
barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
|
| 417 |
+
step_scratch: *fused.step_scratch.device_ptr(),
|
| 418 |
+
};
|
| 419 |
+
|
| 420 |
+
let grid_x = fused.grid_dim_x;
|
| 421 |
+
let block_x = fused.block_dim_x;
|
| 422 |
+
let cu_stream = *sp.dev_ref().cu_stream();
|
| 423 |
+
let use_cluster = fused.cluster_info.max_cluster_size > 0;
|
| 424 |
+
|
| 425 |
+
unsafe {
|
| 426 |
+
result::ctx::set_current(*sp.dev_ref().cu_primary_ctx())?;
|
| 427 |
+
let mut kernel_params: [*mut std::ffi::c_void; 2] = [
|
| 428 |
+
(&ptrs as *const FusedPtrs).cast_mut().cast(),
|
| 429 |
+
(&cfg as *const FusedConfig).cast_mut().cast(),
|
| 430 |
+
];
|
| 431 |
+
|
| 432 |
+
if use_cluster {
|
| 433 |
+
// T10: Hopper cluster launch with CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION.
|
| 434 |
+
// cluster_dim=(16,1,1) maps the entire single-region grid into one cluster.
|
| 435 |
+
let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
|
| 436 |
+
attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
| 437 |
+
attr.value.clusterDim.x = 16;
|
| 438 |
+
attr.value.clusterDim.y = 1;
|
| 439 |
+
attr.value.clusterDim.z = 1;
|
| 440 |
+
|
| 441 |
+
let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
|
| 442 |
+
launch_cfg.gridDimX = grid_x;
|
| 443 |
+
launch_cfg.gridDimY = 1;
|
| 444 |
+
launch_cfg.gridDimZ = 1;
|
| 445 |
+
launch_cfg.blockDimX = block_x;
|
| 446 |
+
launch_cfg.blockDimY = 1;
|
| 447 |
+
launch_cfg.blockDimZ = 1;
|
| 448 |
+
launch_cfg.sharedMemBytes = 0;
|
| 449 |
+
launch_cfg.hStream = cu_stream;
|
| 450 |
+
launch_cfg.numAttrs = 1;
|
| 451 |
+
launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
|
| 452 |
+
|
| 453 |
+
let ret = sys::lib().cuLaunchKernelEx(
|
| 454 |
+
&launch_cfg as *const sys::CUlaunchConfig,
|
| 455 |
+
fused.raw_kernel.function,
|
| 456 |
+
kernel_params.as_mut_ptr(),
|
| 457 |
+
std::ptr::null_mut(),
|
| 458 |
+
);
|
| 459 |
+
if ret != sys::CUresult::CUDA_SUCCESS {
|
| 460 |
+
return Err(DriverError(ret));
|
| 461 |
+
}
|
| 462 |
+
} else {
|
| 463 |
+
// Fallback for devices that don't support cluster launch.
|
| 464 |
+
result::launch_kernel(
|
| 465 |
+
fused.raw_kernel.function,
|
| 466 |
+
(grid_x, 1, 1),
|
| 467 |
+
(block_x, 1, 1),
|
| 468 |
+
0,
|
| 469 |
+
cu_stream,
|
| 470 |
+
&mut kernel_params,
|
| 471 |
+
)?;
|
| 472 |
+
}
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
Ok(())
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
/// Single batched non-cooperative launch for B regions with DLB sync. Uses the same kernel
|
| 479 |
+
/// body; each block reads its region's FusedPtrs from a device-side array
|
| 480 |
+
/// indexed by blockIdx.y. All regions share the same config (same
|
| 481 |
+
/// input_bits/n_columns/etc.) so we pass one FusedConfig.
|
| 482 |
+
///
|
| 483 |
+
/// This breaks through the CUDA cooperative-kernel device-level
|
| 484 |
+
/// serialization: multiple cooperative launches are serialized regardless
|
| 485 |
+
/// of stream, but one cooperative launch with grid.y=B processes all
|
| 486 |
+
/// regions in a single invocation — ~B× speedup vs B sequential launches.
|
| 487 |
+
#[allow(clippy::too_many_arguments)]
|
| 488 |
+
/// Low-level raw-pointer entry, called by PyO3 binding which holds the
|
| 489 |
+
/// mutable borrows. Safety: each `*mut HTMRegionGpu` must point to a live,
|
| 490 |
+
/// uniquely-borrowed region. All regions must be distinct.
|
| 491 |
+
pub(super) fn launch_fused_batched_raw(
|
| 492 |
+
region_ptrs: &[*mut super::HTMRegionGpu],
|
| 493 |
+
inputs_per_region: &[u64],
|
| 494 |
+
cols_per_region: &[u64],
|
| 495 |
+
anom_per_region: &[u64],
|
| 496 |
+
t: usize,
|
| 497 |
+
input_bits: usize,
|
| 498 |
+
learn: bool,
|
| 499 |
+
) -> Result<(), DriverError> {
|
| 500 |
+
let b = region_ptrs.len();
|
| 501 |
+
assert_eq!(inputs_per_region.len(), b);
|
| 502 |
+
assert_eq!(cols_per_region.len(), b);
|
| 503 |
+
assert_eq!(anom_per_region.len(), b);
|
| 504 |
+
assert!(b >= 1, "need at least one region");
|
| 505 |
+
|
| 506 |
+
// Reset per-region step_scratch before each launch.
|
| 507 |
+
for &rp in region_ptrs.iter() {
|
| 508 |
+
let r = unsafe { &mut *rp };
|
| 509 |
+
let dev = r.sp_gpu.dev_ref().clone();
|
| 510 |
+
dev.memset_zeros(&mut r.fused_state.step_scratch)?;
|
| 511 |
+
r.fused_state.iter_counter = r.fused_state.iter_counter.wrapping_add(1);
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
// Shared config — all regions use identical sp/tm parameters.
|
| 515 |
+
let (grid_x, block_x, function_batched, cu_stream, cu_ctx) = {
|
| 516 |
+
let r0 = unsafe { &*region_ptrs[0] };
|
| 517 |
+
(
|
| 518 |
+
r0.fused_state.grid_dim_x,
|
| 519 |
+
r0.fused_state.block_dim_x,
|
| 520 |
+
r0.fused_state.raw_kernel.function_batched,
|
| 521 |
+
*r0.sp_gpu.dev_ref().cu_stream(),
|
| 522 |
+
*r0.sp_gpu.dev_ref().cu_primary_ctx(),
|
| 523 |
+
)
|
| 524 |
+
};
|
| 525 |
+
|
| 526 |
+
let cfg = {
|
| 527 |
+
let r = unsafe { &*region_ptrs[0] };
|
| 528 |
+
FusedConfig {
|
| 529 |
+
input_bits: input_bits as u32,
|
| 530 |
+
n_columns: r.sp_gpu.n_columns_accessor() as u32,
|
| 531 |
+
synapses_per_col: r.sp_gpu.synapses_per_col_accessor() as u32,
|
| 532 |
+
conn_thr: r.sp_gpu.conn_thr_accessor(),
|
| 533 |
+
sp_inc: r.sp_gpu.inc_accessor(),
|
| 534 |
+
sp_dec: r.sp_gpu.dec_accessor(),
|
| 535 |
+
sparsity_target: r.sp_gpu.sparsity_accessor(),
|
| 536 |
+
duty_alpha: 1.0f32 / r.sp_gpu.duty_period_accessor().max(1.0),
|
| 537 |
+
thr_adapt_rate: 0.001f32,
|
| 538 |
+
cells_per_column: r.tm_gpu.cells_per_column as u32,
|
| 539 |
+
n_cells: r.tm_gpu.n_cells as u32,
|
| 540 |
+
bits_words: r.tm_gpu.bits_words as u32,
|
| 541 |
+
max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
|
| 542 |
+
synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
|
| 543 |
+
activation_threshold: r.tm_gpu.activation_threshold,
|
| 544 |
+
learning_threshold: r.tm_gpu.learning_threshold,
|
| 545 |
+
max_new_synapses: r.tm_gpu.max_new_synapse_count,
|
| 546 |
+
conn_thr_i16: r.tm_gpu.conn_thr_i16 as i32,
|
| 547 |
+
perm_inc_i16: r.tm_gpu.perm_inc_i16 as i32,
|
| 548 |
+
perm_dec_i16: r.tm_gpu.perm_dec_i16 as i32,
|
| 549 |
+
predicted_seg_dec_i16: r.tm_gpu.predicted_seg_dec_i16 as i32,
|
| 550 |
+
initial_perm_i16: r.tm_gpu.initial_perm_i16 as i32,
|
| 551 |
+
t: t as u32,
|
| 552 |
+
learn: if learn { 1 } else { 0 },
|
| 553 |
+
iter_seed: r.fused_state.iter_counter,
|
| 554 |
+
cooperative_grid_sync: 1,
|
| 555 |
+
}
|
| 556 |
+
};
|
| 557 |
+
|
| 558 |
+
// Build B FusedPtrs per-region.
|
| 559 |
+
let ptrs_vec: Vec<FusedPtrs> = (0..b)
|
| 560 |
+
.map(|i| {
|
| 561 |
+
let r = unsafe { &*region_ptrs[i] };
|
| 562 |
+
FusedPtrs {
|
| 563 |
+
syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
|
| 564 |
+
syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
|
| 565 |
+
boost: *r.sp_gpu.boost_accessor().device_ptr(),
|
| 566 |
+
active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
|
| 567 |
+
inhibition_threshold: *r.fused_state.inhibition_threshold.device_ptr(),
|
| 568 |
+
seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
|
| 569 |
+
seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
|
| 570 |
+
syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
|
| 571 |
+
tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
|
| 572 |
+
cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
|
| 573 |
+
cell_active_a: *r.fused_state.cell_active_bits_a.device_ptr(),
|
| 574 |
+
cell_active_b: *r.fused_state.cell_active_bits_b.device_ptr(),
|
| 575 |
+
cell_winner_a: *r.fused_state.cell_winner_bits_a.device_ptr(),
|
| 576 |
+
cell_winner_b: *r.fused_state.cell_winner_bits_b.device_ptr(),
|
| 577 |
+
inputs: inputs_per_region[i],
|
| 578 |
+
cols_out: cols_per_region[i],
|
| 579 |
+
anom_out: anom_per_region[i],
|
| 580 |
+
barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
|
| 581 |
+
step_scratch: *r.fused_state.step_scratch.device_ptr(),
|
| 582 |
+
}
|
| 583 |
+
})
|
| 584 |
+
.collect();
|
| 585 |
+
|
| 586 |
+
// Upload FusedPtrs array to device (B * sizeof(FusedPtrs) bytes).
|
| 587 |
+
// FusedPtrs is repr(C) + DeviceRepr so htod_sync_copy handles it.
|
| 588 |
+
let dev = unsafe { &*region_ptrs[0] }.sp_gpu.dev_ref().clone();
|
| 589 |
+
let ptrs_dev: CudaSlice<FusedPtrs> = dev.htod_sync_copy(&ptrs_vec)?;
|
| 590 |
+
let ptrs_dev_ptr: u64 = *ptrs_dev.device_ptr();
|
| 591 |
+
|
| 592 |
+
// T10: Cluster launch for batched regions.
|
| 593 |
+
// Grid = (grid_x, B, 1) with cluster_dim=(16,1,1): each region (Y slice)
|
| 594 |
+
// occupies exactly one cluster of 16 blocks. All 8 clusters run concurrently
|
| 595 |
+
// on the H200's 132 SMs (8 × 16 = 128 blocks ≤ 132 SMs).
|
| 596 |
+
let use_cluster = {
|
| 597 |
+
let r0 = unsafe { &*region_ptrs[0] };
|
| 598 |
+
r0.fused_state.cluster_info.max_cluster_size > 0
|
| 599 |
+
};
|
| 600 |
+
|
| 601 |
+
unsafe {
|
| 602 |
+
result::ctx::set_current(cu_ctx)?;
|
| 603 |
+
let mut kernel_params: [*mut std::ffi::c_void; 2] = [
|
| 604 |
+
(&ptrs_dev_ptr as *const u64).cast_mut().cast(),
|
| 605 |
+
(&cfg as *const FusedConfig).cast_mut().cast(),
|
| 606 |
+
];
|
| 607 |
+
|
| 608 |
+
if use_cluster {
|
| 609 |
+
let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
|
| 610 |
+
attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
| 611 |
+
attr.value.clusterDim.x = 16;
|
| 612 |
+
attr.value.clusterDim.y = 1;
|
| 613 |
+
attr.value.clusterDim.z = 1;
|
| 614 |
+
|
| 615 |
+
let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
|
| 616 |
+
launch_cfg.gridDimX = grid_x;
|
| 617 |
+
launch_cfg.gridDimY = b as u32;
|
| 618 |
+
launch_cfg.gridDimZ = 1;
|
| 619 |
+
launch_cfg.blockDimX = block_x;
|
| 620 |
+
launch_cfg.blockDimY = 1;
|
| 621 |
+
launch_cfg.blockDimZ = 1;
|
| 622 |
+
launch_cfg.sharedMemBytes = 0;
|
| 623 |
+
launch_cfg.hStream = cu_stream;
|
| 624 |
+
launch_cfg.numAttrs = 1;
|
| 625 |
+
launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
|
| 626 |
+
|
| 627 |
+
let ret = sys::lib().cuLaunchKernelEx(
|
| 628 |
+
&launch_cfg as *const sys::CUlaunchConfig,
|
| 629 |
+
function_batched,
|
| 630 |
+
kernel_params.as_mut_ptr(),
|
| 631 |
+
std::ptr::null_mut(),
|
| 632 |
+
);
|
| 633 |
+
if ret != sys::CUresult::CUDA_SUCCESS {
|
| 634 |
+
return Err(DriverError(ret));
|
| 635 |
+
}
|
| 636 |
+
} else {
|
| 637 |
+
// Fallback: plain non-cooperative launch for non-Hopper devices.
|
| 638 |
+
result::launch_kernel(
|
| 639 |
+
function_batched,
|
| 640 |
+
(grid_x, b as u32, 1),
|
| 641 |
+
(block_x, 1, 1),
|
| 642 |
+
0,
|
| 643 |
+
cu_stream,
|
| 644 |
+
&mut kernel_params,
|
| 645 |
+
)?;
|
| 646 |
+
}
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
Ok(())
|
| 650 |
+
}
|
overlay/htm_rust/src/gpu/kernels/htm_fused_step.cu
ADDED
|
@@ -0,0 +1,677 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Fused HTM megakernel — SP + TM, all T timesteps in a single launch.
|
| 2 |
+
//
|
| 3 |
+
// Design rationale:
|
| 4 |
+
// - Global top-K column selection requires cross-block synchronization at
|
| 5 |
+
// every timestep (grid.sync is unreliable on WSL2/sm_86 without rdc=true).
|
| 6 |
+
// - Replace with per-column threshold activation using local lateral
|
| 7 |
+
// inhibition: column c activates if overlap[c]*boost[c] > threshold[c].
|
| 8 |
+
// Threshold is a per-column running-EMA learned scalar that steers the
|
| 9 |
+
// column's long-run activation rate toward the global sparsity target.
|
| 10 |
+
// - This is biologically grounded (GABAergic local inhibition) and supported
|
| 11 |
+
// by HTM theory (duty-cycle boost already drives this loop; we just
|
| 12 |
+
// change which lever the EMA pulls).
|
| 13 |
+
//
|
| 14 |
+
// Launch shape:
|
| 15 |
+
// grid = min(device SM count, 16) // hard cap — see below
|
| 16 |
+
// block = 1024 threads = 32 warps
|
| 17 |
+
// Each warp of 32 owns a contiguous column slice (n_columns / total_warps).
|
| 18 |
+
//
|
| 19 |
+
// Cross-block coherence:
|
| 20 |
+
// - Ping-pong buffers for cell_active/cell_winner: write _a at even t,
|
| 21 |
+
// read _b; reversed at odd t.
|
| 22 |
+
// - Preferred path: cooperative launch + hardware whole-grid sync.
|
| 23 |
+
// - Fallback path: software 3-slot rotating grid barrier for devices/drivers
|
| 24 |
+
// that cannot do cooperative launch.
|
| 25 |
+
//
|
| 26 |
+
// 2026-04-16: grid_dim reduced from 28 to 16 after deadlock RCA. The previous
|
| 27 |
+
// cap of 28 relied on all blocks being concurrently resident on a 30-SM RTX
|
| 28 |
+
// 3060 Laptop. Under thermal throttling effective residency dropped to ~20-24,
|
| 29 |
+
// leaving scheduled blocks spinning on the software grid barrier waiting for
|
| 30 |
+
// peer blocks that would never run. 16 blocks is below any realistic residency
|
| 31 |
+
// floor and preserves enough warp parallelism (16*32 = 512 warps) to saturate
|
| 32 |
+
// memory bandwidth on the spatial-pooler stage.
|
| 33 |
+
//
|
| 34 |
+
// Kernel signature uses struct-by-value for pointers and config to stay
|
| 35 |
+
// inside cudarc's launch-arg count limit.
|
| 36 |
+
|
| 37 |
+
#include <cooperative_groups.h>
|
| 38 |
+
#include <cooperative_groups/memcpy_async.h>
|
| 39 |
+
|
| 40 |
+
namespace cg = cooperative_groups;
|
| 41 |
+
|
| 42 |
+
// Maximum columns owned per cluster-block in DSMEM.
|
| 43 |
+
// Supports n_columns up to COLS_PER_CLUSTER_BLOCK_MAX * cluster_size.
|
| 44 |
+
// At cluster_size=16: supports up to 256*16=4096 columns.
|
| 45 |
+
// Each array costs 256*4 = 1024 bytes; three arrays = 3072 bytes per SM —
|
| 46 |
+
// well under the 228 KB H200 shared-memory cap.
|
| 47 |
+
#define COLS_PER_CLUSTER_BLOCK_MAX 256u
|
| 48 |
+
|
| 49 |
+
// Maximum input_bits supported by the TMA-multicast staging tile.
|
| 50 |
+
// At 32 KB this covers the production SDR width (16384 bits) with 2× headroom.
|
| 51 |
+
// Total shared per SM: 32768 (tile) + 3072 (DSMEM float arrays) = ~35 KB —
|
| 52 |
+
// well under the 228 KB H200 limit.
|
| 53 |
+
//
|
| 54 |
+
// Expected speedup from TMA multicast input staging (T9/T11):
|
| 55 |
+
// - Without staging: 16 SMs × T × (input_bits GMEM reads per timestep)
|
| 56 |
+
// - With staging: 1 TMA DMA per timestep, shared reads from L1 thereafter
|
| 57 |
+
// - Theoretical DRAM bandwidth reduction: ~16× on input reads
|
| 58 |
+
// - Wall-clock reduction estimate: -20 to -40 ms from reduced input fetch latency
|
| 59 |
+
#define INPUT_BITS_MAX 32768u
|
| 60 |
+
|
| 61 |
+
extern "C" {
|
| 62 |
+
|
| 63 |
+
struct FusedPtrs {
|
| 64 |
+
unsigned long long syn_bit;
|
| 65 |
+
unsigned long long syn_perm;
|
| 66 |
+
unsigned long long boost;
|
| 67 |
+
unsigned long long active_duty;
|
| 68 |
+
unsigned long long inhibition_threshold;
|
| 69 |
+
unsigned long long seg_cell_id;
|
| 70 |
+
unsigned long long seg_syn_count;
|
| 71 |
+
unsigned long long syn_presyn;
|
| 72 |
+
unsigned long long tm_syn_perm;
|
| 73 |
+
unsigned long long cell_seg_count;
|
| 74 |
+
unsigned long long cell_active_a;
|
| 75 |
+
unsigned long long cell_active_b;
|
| 76 |
+
unsigned long long cell_winner_a;
|
| 77 |
+
unsigned long long cell_winner_b;
|
| 78 |
+
unsigned long long inputs;
|
| 79 |
+
unsigned long long cols_out;
|
| 80 |
+
unsigned long long anom_out;
|
| 81 |
+
unsigned long long barrier_counters;
|
| 82 |
+
unsigned long long step_scratch;
|
| 83 |
+
};
|
| 84 |
+
|
| 85 |
+
struct FusedConfig {
|
| 86 |
+
// SP constants
|
| 87 |
+
unsigned int input_bits;
|
| 88 |
+
unsigned int n_columns;
|
| 89 |
+
unsigned int synapses_per_col;
|
| 90 |
+
float conn_thr;
|
| 91 |
+
float sp_inc;
|
| 92 |
+
float sp_dec;
|
| 93 |
+
float sparsity_target;
|
| 94 |
+
float duty_alpha;
|
| 95 |
+
float thr_adapt_rate;
|
| 96 |
+
// TM constants
|
| 97 |
+
unsigned int cells_per_column;
|
| 98 |
+
unsigned int n_cells;
|
| 99 |
+
unsigned int bits_words;
|
| 100 |
+
unsigned int max_segments_per_cell;
|
| 101 |
+
unsigned int synapses_per_segment;
|
| 102 |
+
unsigned int activation_threshold;
|
| 103 |
+
unsigned int learning_threshold;
|
| 104 |
+
unsigned int max_new_synapses;
|
| 105 |
+
int conn_thr_i16;
|
| 106 |
+
int perm_inc_i16;
|
| 107 |
+
int perm_dec_i16;
|
| 108 |
+
int predicted_seg_dec_i16;
|
| 109 |
+
int initial_perm_i16;
|
| 110 |
+
// Loop constants
|
| 111 |
+
unsigned int T;
|
| 112 |
+
unsigned int learn;
|
| 113 |
+
unsigned int iter_seed;
|
| 114 |
+
unsigned int cooperative_grid_sync;
|
| 115 |
+
};
|
| 116 |
+
|
| 117 |
+
// Hardware cluster barrier using Hopper sm_90a cooperative_groups::this_cluster().sync().
|
| 118 |
+
// Replaces the former software Decoupled Look-Back (DLB) atomic-spin barrier.
|
| 119 |
+
//
|
| 120 |
+
// cluster::sync() is a single PTX instruction (barrier.cluster) that resolves
|
| 121 |
+
// in ~10-40 ns inside the cluster, with no device-level serialization.
|
| 122 |
+
// Multiple clusters (one per HTM region) run fully concurrently — bounded
|
| 123 |
+
// only by SM count (8 clusters × 16 SMs = 128 ≤ 132 on H200).
|
| 124 |
+
//
|
| 125 |
+
// The flags / expected / phase / cooperative_grid_sync parameters are kept
|
| 126 |
+
// in the signature for call-site compatibility but are unused.
|
| 127 |
+
__device__ static inline void fused_grid_barrier(cg::grid_group /* grid */,
|
| 128 |
+
unsigned int * /* flags — unused */,
|
| 129 |
+
unsigned int /* expected — unused */,
|
| 130 |
+
unsigned int /* phase — unused */,
|
| 131 |
+
unsigned int /* cooperative_grid_sync — unused */) {
|
| 132 |
+
auto cluster = cg::this_cluster();
|
| 133 |
+
cluster.sync();
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
__device__ static inline unsigned int warp_sum_u32(unsigned int v) {
|
| 137 |
+
for (int off = 16; off > 0; off >>= 1) {
|
| 138 |
+
v += __shfl_down_sync(0xffffffffu, v, off);
|
| 139 |
+
}
|
| 140 |
+
return v;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
// Core kernel body — works for both single-region and batched launches.
|
| 144 |
+
// Single-region: caller passes the one FusedPtrs struct.
|
| 145 |
+
// Batched: each block reads its region's FusedPtrs via blockIdx.y before
|
| 146 |
+
// calling this. State is independent per region (each region owns its own
|
| 147 |
+
// GPU buffers); grid.sync() is the only cross-block primitive and it
|
| 148 |
+
// spans ALL blocks in the grid (harmless over-sync across regions).
|
| 149 |
+
__device__ static inline
|
| 150 |
+
void htm_fused_step_body(const FusedPtrs& P, const FusedConfig& cfg) {
|
| 151 |
+
cg::grid_group grid = cg::this_grid();
|
| 152 |
+
// Cast pointers.
|
| 153 |
+
const unsigned int * __restrict__ syn_bit = (const unsigned int*)P.syn_bit;
|
| 154 |
+
float * __restrict__ syn_perm = (float*)P.syn_perm;
|
| 155 |
+
float * __restrict__ boost = (float*)P.boost;
|
| 156 |
+
float * __restrict__ active_duty = (float*)P.active_duty;
|
| 157 |
+
float * __restrict__ inhibition_threshold = (float*)P.inhibition_threshold;
|
| 158 |
+
unsigned int * __restrict__ seg_cell_id = (unsigned int*)P.seg_cell_id;
|
| 159 |
+
unsigned int * __restrict__ seg_syn_count = (unsigned int*)P.seg_syn_count;
|
| 160 |
+
unsigned int * __restrict__ syn_presyn = (unsigned int*)P.syn_presyn;
|
| 161 |
+
short * __restrict__ tm_syn_perm = (short*)P.tm_syn_perm;
|
| 162 |
+
unsigned int * __restrict__ cell_seg_count = (unsigned int*)P.cell_seg_count;
|
| 163 |
+
unsigned int * __restrict__ cell_active_a = (unsigned int*)P.cell_active_a;
|
| 164 |
+
unsigned int * __restrict__ cell_active_b = (unsigned int*)P.cell_active_b;
|
| 165 |
+
unsigned int * __restrict__ cell_winner_a = (unsigned int*)P.cell_winner_a;
|
| 166 |
+
unsigned int * __restrict__ cell_winner_b = (unsigned int*)P.cell_winner_b;
|
| 167 |
+
const unsigned char * __restrict__ inputs = (const unsigned char*)P.inputs;
|
| 168 |
+
unsigned char * __restrict__ cols_out = (unsigned char*)P.cols_out;
|
| 169 |
+
float * __restrict__ anom_out = (float*)P.anom_out;
|
| 170 |
+
unsigned int * __restrict__ barrier_counters = (unsigned int*)P.barrier_counters;
|
| 171 |
+
unsigned int * __restrict__ step_scratch = (unsigned int*)P.step_scratch;
|
| 172 |
+
|
| 173 |
+
const unsigned int tid = threadIdx.x;
|
| 174 |
+
const unsigned int lane = tid & 31u;
|
| 175 |
+
const unsigned int warp = tid >> 5;
|
| 176 |
+
const unsigned int warps_per_block = blockDim.x >> 5;
|
| 177 |
+
const unsigned int gwarp = blockIdx.x * warps_per_block + warp;
|
| 178 |
+
const unsigned int n_warps = gridDim.x * warps_per_block;
|
| 179 |
+
|
| 180 |
+
const unsigned int n_cols = cfg.n_columns;
|
| 181 |
+
const unsigned int col_lo = (gwarp * n_cols) / n_warps;
|
| 182 |
+
const unsigned int col_hi = ((gwarp + 1) * n_cols) / n_warps;
|
| 183 |
+
|
| 184 |
+
unsigned int phase = 0u;
|
| 185 |
+
|
| 186 |
+
// =========================================================
|
| 187 |
+
// DSMEM: Cluster-distributed shared memory for hot per-column
|
| 188 |
+
// state (inhibition_threshold, boost, active_duty).
|
| 189 |
+
//
|
| 190 |
+
// Each block in the cluster owns a contiguous slice of
|
| 191 |
+
// [my_col_start, my_col_end) columns in its own __shared__
|
| 192 |
+
// arrays. Any block can peer-read another block's slice via
|
| 193 |
+
// cluster.map_shared_rank(ptr, owner_block_rank)[offset].
|
| 194 |
+
//
|
| 195 |
+
// This eliminates 2×n_cols×T GMEM reads per forward call
|
| 196 |
+
// (read + potential re-read of threshold/boost/duty per timestep).
|
| 197 |
+
// =========================================================
|
| 198 |
+
auto cluster = cg::this_cluster();
|
| 199 |
+
const unsigned int cluster_block_rank = cluster.block_rank(); // 0..cluster_size-1
|
| 200 |
+
const unsigned int cluster_sz = cluster.num_blocks(); // == gridDim.x (≤16)
|
| 201 |
+
|
| 202 |
+
// Partition n_cols evenly across cluster blocks.
|
| 203 |
+
// Each block owns cols_per_block columns starting at my_col_start.
|
| 204 |
+
const unsigned int cols_per_block =
|
| 205 |
+
(n_cols + cluster_sz - 1u) / cluster_sz; // ceil div
|
| 206 |
+
const unsigned int my_col_start =
|
| 207 |
+
cluster_block_rank * cols_per_block;
|
| 208 |
+
const unsigned int my_col_end =
|
| 209 |
+
(my_col_start + cols_per_block < n_cols)
|
| 210 |
+
? (my_col_start + cols_per_block) : n_cols; // clamp
|
| 211 |
+
|
| 212 |
+
// Cluster-distributed shared memory arrays.
|
| 213 |
+
// Each block holds at most COLS_PER_CLUSTER_BLOCK_MAX floats per array.
|
| 214 |
+
// Peer blocks address into each other's smem via map_shared_rank.
|
| 215 |
+
__shared__ float s_inhib_thr [COLS_PER_CLUSTER_BLOCK_MAX];
|
| 216 |
+
__shared__ float s_boost [COLS_PER_CLUSTER_BLOCK_MAX];
|
| 217 |
+
__shared__ float s_active_duty[COLS_PER_CLUSTER_BLOCK_MAX];
|
| 218 |
+
|
| 219 |
+
// TMA multicast input staging tile (T9).
|
| 220 |
+
//
|
| 221 |
+
// On Hopper (sm_90a), cg::memcpy_async with cluster scope issues a single
|
| 222 |
+
// TMA DMA that multicasts the source data to all 16 SMs in the cluster
|
| 223 |
+
// simultaneously — replacing ~16 per-block GMEM reads per timestep with a
|
| 224 |
+
// single hardware DMA. After cg::wait(cluster) every SM's s_input_tile
|
| 225 |
+
// is populated identically without any additional DRAM traffic.
|
| 226 |
+
//
|
| 227 |
+
// Fallback: when cfg.input_bits > INPUT_BITS_MAX the tile is bypassed
|
| 228 |
+
// and each thread reads directly from GMEM (original path).
|
| 229 |
+
//
|
| 230 |
+
// Alignment: 16-byte aligned to satisfy TMA descriptor requirements.
|
| 231 |
+
__shared__ __align__(16) unsigned char s_input_tile[INPUT_BITS_MAX];
|
| 232 |
+
|
| 233 |
+
// Initial GMEM → smem load (reads state from previous forward call).
|
| 234 |
+
// Each block loads only its own slice; tid strides across the slice.
|
| 235 |
+
for (unsigned int c = my_col_start + tid; c < my_col_end; c += blockDim.x) {
|
| 236 |
+
const unsigned int off = c - my_col_start;
|
| 237 |
+
s_inhib_thr [off] = inhibition_threshold[c];
|
| 238 |
+
s_boost [off] = boost[c];
|
| 239 |
+
s_active_duty[off] = active_duty[c];
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
// All blocks in the cluster must finish loading before any block
|
| 243 |
+
// starts reading peer smem inside the T-loop.
|
| 244 |
+
cluster.sync();
|
| 245 |
+
|
| 246 |
+
const unsigned int S = cfg.synapses_per_col;
|
| 247 |
+
const unsigned int cpc = cfg.cells_per_column;
|
| 248 |
+
const unsigned int SPS = cfg.synapses_per_segment;
|
| 249 |
+
const unsigned int MSC = cfg.max_segments_per_cell;
|
| 250 |
+
|
| 251 |
+
// Main timestep loop.
|
| 252 |
+
for (unsigned int t = 0u; t < cfg.T; t++) {
|
| 253 |
+
const unsigned int inp_off = t * cfg.input_bits;
|
| 254 |
+
const unsigned int col_base_out = t * n_cols;
|
| 255 |
+
|
| 256 |
+
unsigned int * curr_active = (t & 1u) ? cell_active_b : cell_active_a;
|
| 257 |
+
unsigned int * prev_active = (t & 1u) ? cell_active_a : cell_active_b;
|
| 258 |
+
unsigned int * curr_winner = (t & 1u) ? cell_winner_b : cell_winner_a;
|
| 259 |
+
unsigned int * prev_winner = (t & 1u) ? cell_winner_a : cell_winner_b;
|
| 260 |
+
|
| 261 |
+
// ---- Phase 0: clear curr bitsets for my cell range ----
|
| 262 |
+
const unsigned int my_cell_lo = col_lo * cpc;
|
| 263 |
+
const unsigned int my_cell_hi = col_hi * cpc;
|
| 264 |
+
if (cpc == 32u) {
|
| 265 |
+
// Fast path: one word per column.
|
| 266 |
+
for (unsigned int c = col_lo + lane; c < col_hi; c += 32u) {
|
| 267 |
+
curr_active[c] = 0u;
|
| 268 |
+
curr_winner[c] = 0u;
|
| 269 |
+
}
|
| 270 |
+
} else {
|
| 271 |
+
for (unsigned int cell = my_cell_lo + lane; cell < my_cell_hi; cell += 32u) {
|
| 272 |
+
unsigned int w = cell >> 5;
|
| 273 |
+
unsigned int m = 1u << (cell & 31u);
|
| 274 |
+
atomicAnd(&curr_active[w], ~m);
|
| 275 |
+
atomicAnd(&curr_winner[w], ~m);
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
// Block 0, lane 0, warp 0 resets step-scratch counters.
|
| 280 |
+
if (blockIdx.x == 0u && tid == 0u) {
|
| 281 |
+
step_scratch[0] = 0u;
|
| 282 |
+
step_scratch[1] = 0u;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
// ---- BARRIER 1 ----
|
| 286 |
+
// Fence: make the above clear-bitsets + scratch writes globally
|
| 287 |
+
// visible before peer blocks observe "barrier arrived".
|
| 288 |
+
__threadfence();
|
| 289 |
+
fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
|
| 290 |
+
|
| 291 |
+
// =========================================================
|
| 292 |
+
// T9: TMA MULTICAST INPUT STAGING
|
| 293 |
+
//
|
| 294 |
+
// Issue a single cluster-scope async DMA to broadcast this
|
| 295 |
+
// timestep's input slice into s_input_tile across all 16 SMs
|
| 296 |
+
// in the cluster simultaneously. On Hopper sm_90a,
|
| 297 |
+
// cg::memcpy_async with cluster scope maps to the TMA
|
| 298 |
+
// hardware unit (cp.async.bulk.tensor multicast), reducing
|
| 299 |
+
// DRAM input traffic by ~16× vs each block fetching its own
|
| 300 |
+
// copy from GMEM.
|
| 301 |
+
//
|
| 302 |
+
// The staging is gated on cfg.input_bits <= INPUT_BITS_MAX.
|
| 303 |
+
// If the tile is too small (custom large input_bits), we fall
|
| 304 |
+
// back to per-thread GMEM reads in Stage A (identical to the
|
| 305 |
+
// original path; use_input_tile==false).
|
| 306 |
+
//
|
| 307 |
+
// Ordering: BARRIER 1 completes before we issue the DMA.
|
| 308 |
+
// The DMA completes before Stage A reads s_input_tile.
|
| 309 |
+
// =========================================================
|
| 310 |
+
const bool use_input_tile = (cfg.input_bits <= INPUT_BITS_MAX);
|
| 311 |
+
if (use_input_tile) {
|
| 312 |
+
// Thread-block scope async copy: each SM independently loads
|
| 313 |
+
// its own input tile from GMEM into shared memory.
|
| 314 |
+
//
|
| 315 |
+
// NOTE: CUDA 12.1's cooperative_groups::memcpy_async() rejects
|
| 316 |
+
// cluster_group at compile time (static_assert in async.h:171).
|
| 317 |
+
// True TMA multicast (single DMA for all 16 SMs in the cluster)
|
| 318 |
+
// would require raw PTX cp.async.bulk.tensor with multicast mode,
|
| 319 |
+
// which needs cuTensorMap descriptors on the host side (T11).
|
| 320 |
+
//
|
| 321 |
+
// This per-SM path still gives a meaningful win: it converts
|
| 322 |
+
// the original per-synapse scattered GMEM reads (random access
|
| 323 |
+
// pattern hitting multiple cache lines) into one sequential DMA
|
| 324 |
+
// per SM, improving L2 hit rate and hardware prefetcher
|
| 325 |
+
// effectiveness. The cluster.sync() below ensures all SMs in
|
| 326 |
+
// the cluster have finished loading before any SM enters Stage A.
|
| 327 |
+
auto tb = cg::this_thread_block();
|
| 328 |
+
cg::memcpy_async(tb, s_input_tile,
|
| 329 |
+
inputs + inp_off,
|
| 330 |
+
cfg.input_bits);
|
| 331 |
+
cg::wait(tb);
|
| 332 |
+
// Cluster barrier: all 16 SMs must have loaded their tile
|
| 333 |
+
// before any SM begins reading s_input_tile in Stage A.
|
| 334 |
+
cluster.sync();
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
// =========================================================
|
| 338 |
+
// STAGE A: Spatial Pooler
|
| 339 |
+
//
|
| 340 |
+
// Hot per-column state (boost, inhibition_threshold,
|
| 341 |
+
// active_duty) is served from cluster DSMEM rather than
|
| 342 |
+
// GMEM for each of the T timesteps. GMEM is written on
|
| 343 |
+
// update so state persists across forward calls.
|
| 344 |
+
// =========================================================
|
| 345 |
+
for (unsigned int c = col_lo; c < col_hi; c++) {
|
| 346 |
+
unsigned int base = c * S;
|
| 347 |
+
unsigned int local = 0u;
|
| 348 |
+
for (unsigned int s = lane; s < S; s += 32u) {
|
| 349 |
+
unsigned int b = syn_bit[base + s];
|
| 350 |
+
float p = syn_perm[base + s];
|
| 351 |
+
// T9: read from cluster-broadcast tile when available;
|
| 352 |
+
// fall back to direct GMEM when input_bits > INPUT_BITS_MAX.
|
| 353 |
+
unsigned int inp_byte = use_input_tile
|
| 354 |
+
? (unsigned int)s_input_tile[b]
|
| 355 |
+
: (unsigned int)inputs[inp_off + b];
|
| 356 |
+
unsigned int hit = ((inp_byte != 0u) && (p >= cfg.conn_thr)) ? 1u : 0u;
|
| 357 |
+
local += hit;
|
| 358 |
+
}
|
| 359 |
+
unsigned int overlap = warp_sum_u32(local);
|
| 360 |
+
overlap = __shfl_sync(0xffffffffu, overlap, 0);
|
| 361 |
+
|
| 362 |
+
// Determine which cluster block owns column c and read
|
| 363 |
+
// boost + threshold from that block's shared memory.
|
| 364 |
+
const unsigned int owner_block = c / cols_per_block;
|
| 365 |
+
const unsigned int owner_offset = c - owner_block * cols_per_block;
|
| 366 |
+
|
| 367 |
+
float boost_val = cluster.map_shared_rank(s_boost, owner_block)[owner_offset];
|
| 368 |
+
float thr = cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset];
|
| 369 |
+
|
| 370 |
+
float boosted = (float)overlap * boost_val;
|
| 371 |
+
unsigned int is_active = (boosted > thr) ? 1u : 0u;
|
| 372 |
+
|
| 373 |
+
if (lane == 0) {
|
| 374 |
+
cols_out[col_base_out + c] = (unsigned char)is_active;
|
| 375 |
+
if (is_active) {
|
| 376 |
+
atomicAdd(&step_scratch[0], 1u);
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
// SP learn (Hebbian) on active columns.
|
| 381 |
+
// T9: use tile for input reads here too.
|
| 382 |
+
if (cfg.learn && is_active) {
|
| 383 |
+
for (unsigned int s = lane; s < S; s += 32u) {
|
| 384 |
+
unsigned int b = syn_bit[base + s];
|
| 385 |
+
float p = syn_perm[base + s];
|
| 386 |
+
unsigned int inp_byte = use_input_tile
|
| 387 |
+
? (unsigned int)s_input_tile[b]
|
| 388 |
+
: (unsigned int)inputs[inp_off + b];
|
| 389 |
+
if (inp_byte != 0u) {
|
| 390 |
+
p += cfg.sp_inc;
|
| 391 |
+
if (p > 1.0f) p = 1.0f;
|
| 392 |
+
} else {
|
| 393 |
+
p -= cfg.sp_dec;
|
| 394 |
+
if (p < 0.0f) p = 0.0f;
|
| 395 |
+
}
|
| 396 |
+
syn_perm[base + s] = p;
|
| 397 |
+
}
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
// active_duty EMA + threshold adaptation.
|
| 401 |
+
// Writes go to both peer DSMEM (hot path for next timestep)
|
| 402 |
+
// and GMEM (persistence across forward calls).
|
| 403 |
+
if (lane == 0) {
|
| 404 |
+
float ad = cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset];
|
| 405 |
+
float sample = is_active ? 1.0f : 0.0f;
|
| 406 |
+
ad = (1.0f - cfg.duty_alpha) * ad + cfg.duty_alpha * sample;
|
| 407 |
+
|
| 408 |
+
// Writeback: peer smem (for next timestep read) + GMEM (persistence).
|
| 409 |
+
cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] = ad;
|
| 410 |
+
active_duty[c] = ad;
|
| 411 |
+
|
| 412 |
+
// Threshold steers toward target sparsity.
|
| 413 |
+
float err = ad - cfg.sparsity_target;
|
| 414 |
+
float new_thr = thr + cfg.thr_adapt_rate * err * 100.0f;
|
| 415 |
+
if (new_thr < 0.1f) new_thr = 0.1f;
|
| 416 |
+
if (new_thr > 1000.0f) new_thr = 1000.0f;
|
| 417 |
+
|
| 418 |
+
// Writeback: peer smem (for next timestep read) + GMEM (persistence).
|
| 419 |
+
cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] = new_thr;
|
| 420 |
+
inhibition_threshold[c] = new_thr;
|
| 421 |
+
}
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
// ---- DSMEM WRITEBACK SYNC: peer-smem writes must be visible cluster-wide ----
|
| 425 |
+
//
|
| 426 |
+
// DATA FLOW PROOF (T-loop iteration invariant):
|
| 427 |
+
//
|
| 428 |
+
// WRITE SITES (lane==0 inside Stage A per-col loop):
|
| 429 |
+
// Line 328: cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] = ad
|
| 430 |
+
// Line 338: cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] = new_thr
|
| 431 |
+
//
|
| 432 |
+
// READ SITES (Stage A of the NEXT timestep t+1):
|
| 433 |
+
// Line 290: cluster.map_shared_rank(s_boost, owner_block)[owner_offset] (read)
|
| 434 |
+
// Line 291: cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] (read)
|
| 435 |
+
// Line 323: cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] (read)
|
| 436 |
+
//
|
| 437 |
+
// PARTITION MISMATCH (root cause of T8 staleness):
|
| 438 |
+
// cols_per_block = ceil(n_cols / cluster_sz) [smem partition]
|
| 439 |
+
// col_lo/col_hi = floor(gwarp*n_cols/n_warps) [gwarp work partition]
|
| 440 |
+
// These are NOT identical — up to 1 column can spill across partition boundaries.
|
| 441 |
+
// Example: n_cols=1000, cluster_sz=16 → cols_per_block=63, block 1 col_lo=62
|
| 442 |
+
// → block 1 processes column 62 but column 62 belongs to block 0's smem slice.
|
| 443 |
+
// → block 1 issues a PEER WRITE to block 0's s_inhib_thr / s_active_duty.
|
| 444 |
+
//
|
| 445 |
+
// RACE WITHOUT SYNC:
|
| 446 |
+
// Blocks run Stage A concurrently. Block 1 writes block 0's smem at column 62.
|
| 447 |
+
// Block 0 may simultaneously READ s_inhib_thr[62] for its own column 62 in
|
| 448 |
+
// Stage A of the same timestep → concurrent peer write + local read → undefined.
|
| 449 |
+
// Additionally, without cluster.sync() after all peer writes complete, block 0's
|
| 450 |
+
// t+1 Stage A reads might observe t-1 values still cached in its smem.
|
| 451 |
+
//
|
| 452 |
+
// FIX: cluster.sync() here, AFTER Stage A's per-column loop, ensures:
|
| 453 |
+
// 1. All peer smem writes from this timestep are globally visible to all blocks.
|
| 454 |
+
// 2. No block can enter Stage B (or start t+1 Stage A) with stale smem values.
|
| 455 |
+
// 3. GMEM writes (lines 329, 339) are already committed to L2; __threadfence()
|
| 456 |
+
// below ensures they are visible to all SMs before the cluster barrier.
|
| 457 |
+
//
|
| 458 |
+
// ORDERING: write → cluster.sync() here → __threadfence() → cluster.sync() in
|
| 459 |
+
// fused_grid_barrier → next-timestep reads. Both visibility guarantees
|
| 460 |
+
// are now satisfied.
|
| 461 |
+
cluster.sync();
|
| 462 |
+
|
| 463 |
+
// ---- BARRIER 2: SP active_mask must be visible before TM reads ----
|
| 464 |
+
// Fence: flush cols_out + active_duty + inhibition_threshold + step_scratch
|
| 465 |
+
// writes to global memory before peers advance past this barrier.
|
| 466 |
+
__threadfence();
|
| 467 |
+
fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
|
| 468 |
+
|
| 469 |
+
// =========================================================
|
| 470 |
+
// STAGE B: Temporal Memory
|
| 471 |
+
// =========================================================
|
| 472 |
+
for (unsigned int c = col_lo; c < col_hi; c++) {
|
| 473 |
+
unsigned int col_active = cols_out[col_base_out + c];
|
| 474 |
+
if (col_active == 0u) continue;
|
| 475 |
+
|
| 476 |
+
unsigned int base_cell = c * cpc;
|
| 477 |
+
unsigned int any_predicted = 0u;
|
| 478 |
+
unsigned int best_seg_id_for_grow = 0xFFFFFFFFu;
|
| 479 |
+
unsigned int best_pot_count = 0u;
|
| 480 |
+
|
| 481 |
+
for (unsigned int k = 0u; k < cpc; k++) {
|
| 482 |
+
unsigned int cell = base_cell + k;
|
| 483 |
+
unsigned int n_segs_here = cell_seg_count[cell];
|
| 484 |
+
if (n_segs_here > MSC) n_segs_here = MSC;
|
| 485 |
+
if (n_segs_here == 0u) continue;
|
| 486 |
+
|
| 487 |
+
unsigned int seg_base_id = cell * MSC;
|
| 488 |
+
unsigned int cell_is_predictive = 0u;
|
| 489 |
+
|
| 490 |
+
for (unsigned int ls = 0u; ls < n_segs_here; ls++) {
|
| 491 |
+
unsigned int seg = seg_base_id + ls;
|
| 492 |
+
unsigned int n_syn = seg_syn_count[seg];
|
| 493 |
+
if (n_syn == 0u) continue;
|
| 494 |
+
unsigned int syn_base = seg * SPS;
|
| 495 |
+
|
| 496 |
+
unsigned int l_conn = 0u;
|
| 497 |
+
unsigned int l_pot = 0u;
|
| 498 |
+
for (unsigned int s = lane; s < n_syn; s += 32u) {
|
| 499 |
+
unsigned int presyn = syn_presyn[syn_base + s];
|
| 500 |
+
unsigned int w = prev_active[presyn >> 5];
|
| 501 |
+
unsigned int bit = (w >> (presyn & 31u)) & 1u;
|
| 502 |
+
if (bit) {
|
| 503 |
+
l_pot += 1u;
|
| 504 |
+
int p = (int)tm_syn_perm[syn_base + s];
|
| 505 |
+
if (p >= cfg.conn_thr_i16) l_conn += 1u;
|
| 506 |
+
}
|
| 507 |
+
}
|
| 508 |
+
unsigned int tot_conn = warp_sum_u32(l_conn);
|
| 509 |
+
unsigned int tot_pot = warp_sum_u32(l_pot);
|
| 510 |
+
tot_conn = __shfl_sync(0xffffffffu, tot_conn, 0);
|
| 511 |
+
tot_pot = __shfl_sync(0xffffffffu, tot_pot, 0);
|
| 512 |
+
|
| 513 |
+
if (tot_conn >= cfg.activation_threshold) cell_is_predictive = 1u;
|
| 514 |
+
if (tot_pot >= cfg.learning_threshold && tot_pot > best_pot_count) {
|
| 515 |
+
best_pot_count = tot_pot;
|
| 516 |
+
best_seg_id_for_grow = seg;
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
// Reinforce predicted-and-correct segment.
|
| 520 |
+
if (cfg.learn && tot_conn >= cfg.activation_threshold) {
|
| 521 |
+
for (unsigned int s = lane; s < n_syn; s += 32u) {
|
| 522 |
+
unsigned int presyn = syn_presyn[syn_base + s];
|
| 523 |
+
unsigned int w = prev_active[presyn >> 5];
|
| 524 |
+
unsigned int bit = (w >> (presyn & 31u)) & 1u;
|
| 525 |
+
int p = (int)tm_syn_perm[syn_base + s];
|
| 526 |
+
if (bit) {
|
| 527 |
+
int np = p + cfg.perm_inc_i16;
|
| 528 |
+
if (np > 32767) np = 32767;
|
| 529 |
+
tm_syn_perm[syn_base + s] = (short)np;
|
| 530 |
+
} else {
|
| 531 |
+
int np = p - cfg.perm_dec_i16;
|
| 532 |
+
if (np < 0) np = 0;
|
| 533 |
+
tm_syn_perm[syn_base + s] = (short)np;
|
| 534 |
+
}
|
| 535 |
+
}
|
| 536 |
+
}
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
if (cell_is_predictive) {
|
| 540 |
+
any_predicted = 1u;
|
| 541 |
+
if (lane == 0) {
|
| 542 |
+
unsigned int w = cell >> 5;
|
| 543 |
+
unsigned int m = 1u << (cell & 31u);
|
| 544 |
+
atomicOr(&curr_active[w], m);
|
| 545 |
+
atomicOr(&curr_winner[w], m);
|
| 546 |
+
}
|
| 547 |
+
}
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
// BURST if no predicted.
|
| 551 |
+
if (!any_predicted) {
|
| 552 |
+
if (lane == 0) {
|
| 553 |
+
for (unsigned int k = 0u; k < cpc; k++) {
|
| 554 |
+
unsigned int cell = base_cell + k;
|
| 555 |
+
unsigned int w = cell >> 5;
|
| 556 |
+
unsigned int m = 1u << (cell & 31u);
|
| 557 |
+
atomicOr(&curr_active[w], m);
|
| 558 |
+
}
|
| 559 |
+
unsigned int win = base_cell;
|
| 560 |
+
unsigned int ww = win >> 5;
|
| 561 |
+
unsigned int wm = 1u << (win & 31u);
|
| 562 |
+
atomicOr(&curr_winner[ww], wm);
|
| 563 |
+
atomicAdd(&step_scratch[1], 1u);
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
if (cfg.learn) {
|
| 567 |
+
unsigned int target_seg;
|
| 568 |
+
unsigned int existing_syn;
|
| 569 |
+
if (best_seg_id_for_grow != 0xFFFFFFFFu) {
|
| 570 |
+
// Reuse best matching segment.
|
| 571 |
+
target_seg = best_seg_id_for_grow;
|
| 572 |
+
existing_syn = seg_syn_count[target_seg];
|
| 573 |
+
target_seg = __shfl_sync(0xffffffffu, target_seg, 0);
|
| 574 |
+
existing_syn = __shfl_sync(0xffffffffu, existing_syn, 0);
|
| 575 |
+
|
| 576 |
+
// Reinforce its existing synapses.
|
| 577 |
+
unsigned int syn_base = target_seg * SPS;
|
| 578 |
+
for (unsigned int s = lane; s < existing_syn; s += 32u) {
|
| 579 |
+
unsigned int presyn = syn_presyn[syn_base + s];
|
| 580 |
+
unsigned int w = prev_active[presyn >> 5];
|
| 581 |
+
unsigned int bit = (w >> (presyn & 31u)) & 1u;
|
| 582 |
+
int p = (int)tm_syn_perm[syn_base + s];
|
| 583 |
+
if (bit) {
|
| 584 |
+
int np = p + cfg.perm_inc_i16;
|
| 585 |
+
if (np > 32767) np = 32767;
|
| 586 |
+
tm_syn_perm[syn_base + s] = (short)np;
|
| 587 |
+
} else {
|
| 588 |
+
int np = p - cfg.perm_dec_i16;
|
| 589 |
+
if (np < 0) np = 0;
|
| 590 |
+
tm_syn_perm[syn_base + s] = (short)np;
|
| 591 |
+
}
|
| 592 |
+
}
|
| 593 |
+
} else {
|
| 594 |
+
// Allocate new segment on winner cell (cell 0 of col).
|
| 595 |
+
unsigned int new_seg = 0u;
|
| 596 |
+
if (lane == 0) {
|
| 597 |
+
unsigned int winner_cell = base_cell;
|
| 598 |
+
unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
|
| 599 |
+
if (slot >= MSC) slot = slot % MSC;
|
| 600 |
+
new_seg = winner_cell * MSC + slot;
|
| 601 |
+
seg_cell_id[new_seg] = winner_cell;
|
| 602 |
+
seg_syn_count[new_seg] = 0u;
|
| 603 |
+
}
|
| 604 |
+
target_seg = __shfl_sync(0xffffffffu, new_seg, 0);
|
| 605 |
+
existing_syn = 0u;
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
// Grow synapses to prev_winner cells — lane 0 serialized.
|
| 609 |
+
unsigned int room = (SPS > existing_syn) ? (SPS - existing_syn) : 0u;
|
| 610 |
+
unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;
|
| 611 |
+
if (lane == 0 && max_grow > 0u) {
|
| 612 |
+
unsigned int syn_base = target_seg * SPS;
|
| 613 |
+
unsigned int grown = 0u;
|
| 614 |
+
unsigned int start_off = (c * 2654435761u + cfg.iter_seed + t) % cfg.bits_words;
|
| 615 |
+
for (unsigned int w_off = 0u;
|
| 616 |
+
w_off < cfg.bits_words && grown < max_grow;
|
| 617 |
+
w_off++) {
|
| 618 |
+
unsigned int widx = (start_off + w_off) % cfg.bits_words;
|
| 619 |
+
unsigned int word = prev_winner[widx];
|
| 620 |
+
while (word != 0u && grown < max_grow) {
|
| 621 |
+
unsigned int bit_pos = __ffs(word) - 1u;
|
| 622 |
+
word &= ~(1u << bit_pos);
|
| 623 |
+
unsigned int cell_id = widx * 32u + bit_pos;
|
| 624 |
+
if (cell_id >= cfg.n_cells) continue;
|
| 625 |
+
bool exists = false;
|
| 626 |
+
for (unsigned int es = 0u; es < existing_syn + grown; es++) {
|
| 627 |
+
if (syn_presyn[syn_base + es] == cell_id) { exists = true; break; }
|
| 628 |
+
}
|
| 629 |
+
if (exists) continue;
|
| 630 |
+
unsigned int write_idx = existing_syn + grown;
|
| 631 |
+
if (write_idx >= SPS) break;
|
| 632 |
+
syn_presyn[syn_base + write_idx] = cell_id;
|
| 633 |
+
tm_syn_perm[syn_base + write_idx] = (short)cfg.initial_perm_i16;
|
| 634 |
+
grown++;
|
| 635 |
+
}
|
| 636 |
+
}
|
| 637 |
+
if (grown > 0u) {
|
| 638 |
+
seg_syn_count[target_seg] = existing_syn + grown;
|
| 639 |
+
}
|
| 640 |
+
}
|
| 641 |
+
}
|
| 642 |
+
}
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
// ---- BARRIER 3: TM writes complete before anomaly + next-step read ----
|
| 646 |
+
// Fence: flush curr_active/curr_winner bitsets + tm_syn_perm +
|
| 647 |
+
// seg_syn_count + syn_presyn before peers advance and consume them as
|
| 648 |
+
// prev_active/prev_winner at t+1.
|
| 649 |
+
__threadfence();
|
| 650 |
+
fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
|
| 651 |
+
|
| 652 |
+
// Write anomaly for step t.
|
| 653 |
+
if (blockIdx.x == 0u && tid == 0u) {
|
| 654 |
+
unsigned int total = step_scratch[0];
|
| 655 |
+
unsigned int bad = step_scratch[1];
|
| 656 |
+
float anom = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
|
| 657 |
+
anom_out[t] = anom;
|
| 658 |
+
}
|
| 659 |
+
}
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
// Single-region kernel (legacy call site).
|
| 663 |
+
__global__
|
| 664 |
+
void htm_fused_step(FusedPtrs P, FusedConfig cfg) {
|
| 665 |
+
htm_fused_step_body(P, cfg);
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
// Batched kernel: one cooperative launch for B regions. grid.y = B,
|
| 669 |
+
// grid.x = per-region block count. Each block reads its region's
|
| 670 |
+
// FusedPtrs from the device array via blockIdx.y.
|
| 671 |
+
__global__
|
| 672 |
+
void htm_fused_step_batched(const FusedPtrs* __restrict__ P_arr, FusedConfig cfg) {
|
| 673 |
+
const FusedPtrs P = P_arr[blockIdx.y];
|
| 674 |
+
htm_fused_step_body(P, cfg);
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
} // extern "C"
|
overlay/htm_rust/src/gpu/kernels/sp_boost_fused.cu
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Fused mean-reduction + boost-update kernel.
|
| 2 |
+
//
|
| 3 |
+
// Inputs:
|
| 4 |
+
// active_duty[n] (f32)
|
| 5 |
+
// boost_strength (f32)
|
| 6 |
+
//
|
| 7 |
+
// Output:
|
| 8 |
+
// boost[n] (f32) = expf(-boost_strength * (active_duty[c] - mean))
|
| 9 |
+
//
|
| 10 |
+
// Launch: single block (1024 threads), shared mem for reduction. At n=2048
|
| 11 |
+
// each thread handles 2 elements.
|
| 12 |
+
|
| 13 |
+
extern "C" __global__
|
| 14 |
+
void sp_boost_from_duty(
|
| 15 |
+
const float * __restrict__ active_duty, // (n,)
|
| 16 |
+
float * __restrict__ boost, // (n,) in-place out
|
| 17 |
+
float boost_strength,
|
| 18 |
+
unsigned int n
|
| 19 |
+
) {
|
| 20 |
+
extern __shared__ float smem_raw[];
|
| 21 |
+
float * smem = smem_raw;
|
| 22 |
+
const unsigned int tid = threadIdx.x;
|
| 23 |
+
const unsigned int bsz = blockDim.x;
|
| 24 |
+
|
| 25 |
+
// Phase 1: parallel sum of active_duty into smem[0..32] (warp-level).
|
| 26 |
+
float local_sum = 0.0f;
|
| 27 |
+
for (unsigned int i = tid; i < n; i += bsz) {
|
| 28 |
+
local_sum += active_duty[i];
|
| 29 |
+
}
|
| 30 |
+
// Warp reduction.
|
| 31 |
+
for (int off = 16; off > 0; off >>= 1) {
|
| 32 |
+
local_sum += __shfl_down_sync(0xffffffff, local_sum, off);
|
| 33 |
+
}
|
| 34 |
+
unsigned int lane = tid & 31;
|
| 35 |
+
unsigned int warp = tid >> 5;
|
| 36 |
+
if (lane == 0) smem[warp] = local_sum;
|
| 37 |
+
__syncthreads();
|
| 38 |
+
|
| 39 |
+
// Warp 0 reduces warp-sums.
|
| 40 |
+
__shared__ float mean_s;
|
| 41 |
+
if (warp == 0) {
|
| 42 |
+
unsigned int nwarps = (bsz + 31) / 32;
|
| 43 |
+
float v = (lane < nwarps) ? smem[lane] : 0.0f;
|
| 44 |
+
for (int off = 16; off > 0; off >>= 1) {
|
| 45 |
+
v += __shfl_down_sync(0xffffffff, v, off);
|
| 46 |
+
}
|
| 47 |
+
if (tid == 0) {
|
| 48 |
+
mean_s = v / (float)n;
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
__syncthreads();
|
| 52 |
+
|
| 53 |
+
// Phase 2: boost[c] = expf(-strength * (active_duty[c] - mean)).
|
| 54 |
+
float mean = mean_s;
|
| 55 |
+
for (unsigned int i = tid; i < n; i += bsz) {
|
| 56 |
+
float d = active_duty[i] - mean;
|
| 57 |
+
boost[i] = expf(-boost_strength * d);
|
| 58 |
+
}
|
| 59 |
+
}
|
overlay/htm_rust/src/gpu/kernels/sp_duty.cu
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Duty cycle + boost update kernel.
|
| 2 |
+
//
|
| 3 |
+
// For each column c (one thread each):
|
| 4 |
+
// active_sample = active_mask[c] ? 1 : 0
|
| 5 |
+
// overlap_sample = raw_overlap[c] >= stim_thr ? 1 : 0
|
| 6 |
+
// active_duty[c] = (1-alpha) * active_duty[c] + alpha * active_sample
|
| 7 |
+
// overlap_duty[c] = (1-alpha) * overlap_duty[c] + alpha * overlap_sample
|
| 8 |
+
//
|
| 9 |
+
// Then, if learn:
|
| 10 |
+
// boost[c] = exp(-boost_strength * (active_duty[c] - mean_duty))
|
| 11 |
+
// mean_duty is computed on the host (one reduction) and passed in.
|
| 12 |
+
|
| 13 |
+
extern "C" __global__
|
| 14 |
+
void sp_duty_update(
|
| 15 |
+
const unsigned char * __restrict__ active_mask, // (n_columns,)
|
| 16 |
+
const unsigned int * __restrict__ raw_overlap, // (n_columns,)
|
| 17 |
+
float * __restrict__ active_duty, // (n_columns,) in-place
|
| 18 |
+
float * __restrict__ overlap_duty, // (n_columns,) in-place
|
| 19 |
+
float * __restrict__ boost, // (n_columns,) in-place
|
| 20 |
+
float alpha,
|
| 21 |
+
float stim_thr,
|
| 22 |
+
float boost_strength, // 0 to skip boost
|
| 23 |
+
float mean_duty,
|
| 24 |
+
unsigned int learn_flag, // 0 or 1
|
| 25 |
+
unsigned int n_columns
|
| 26 |
+
) {
|
| 27 |
+
unsigned int c = blockIdx.x * blockDim.x + threadIdx.x;
|
| 28 |
+
if (c >= n_columns) return;
|
| 29 |
+
|
| 30 |
+
float ad = active_duty[c];
|
| 31 |
+
float od = overlap_duty[c];
|
| 32 |
+
|
| 33 |
+
float a_sample = (active_mask[c] != 0) ? 1.0f : 0.0f;
|
| 34 |
+
float o_sample = ((float)raw_overlap[c] >= stim_thr) ? 1.0f : 0.0f;
|
| 35 |
+
|
| 36 |
+
ad = (1.0f - alpha) * ad + alpha * a_sample;
|
| 37 |
+
od = (1.0f - alpha) * od + alpha * o_sample;
|
| 38 |
+
|
| 39 |
+
active_duty[c] = ad;
|
| 40 |
+
overlap_duty[c] = od;
|
| 41 |
+
|
| 42 |
+
if (learn_flag && boost_strength > 0.0f) {
|
| 43 |
+
boost[c] = expf(-boost_strength * (ad - mean_duty));
|
| 44 |
+
}
|
| 45 |
+
}
|
overlay/htm_rust/src/gpu/kernels/sp_learn.cu
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SP Hebbian learning kernel.
|
| 2 |
+
//
|
| 3 |
+
// For each active (winner) column c, for each of its synapses s:
|
| 4 |
+
// if input[bit[c][s]] active: perm += inc
|
| 5 |
+
// else: perm -= dec
|
| 6 |
+
// Clamp to [0, 1].
|
| 7 |
+
//
|
| 8 |
+
// Launch: one block per column (2048 blocks), but we predicate on
|
| 9 |
+
// active_mask[c] to avoid launching k-specific blocks.
|
| 10 |
+
//
|
| 11 |
+
// This matches the CPU reference line-for-line:
|
| 12 |
+
// src/sp.rs lines 157-169.
|
| 13 |
+
|
| 14 |
+
extern "C" __global__
|
| 15 |
+
void sp_learn(
|
| 16 |
+
const unsigned char * __restrict__ active_mask, // (n_columns,) 0/1
|
| 17 |
+
const unsigned char * __restrict__ inp, // (input_bits,)
|
| 18 |
+
const unsigned int * __restrict__ syn_bit, // (n_columns * S,)
|
| 19 |
+
float * __restrict__ syn_perm, // (n_columns * S,) in-place
|
| 20 |
+
float inc,
|
| 21 |
+
float dec,
|
| 22 |
+
unsigned int synapses_per_col,
|
| 23 |
+
unsigned int n_columns
|
| 24 |
+
) {
|
| 25 |
+
const unsigned int c = blockIdx.x;
|
| 26 |
+
if (c >= n_columns) return;
|
| 27 |
+
if (active_mask[c] == 0) return;
|
| 28 |
+
|
| 29 |
+
const unsigned int base = c * synapses_per_col;
|
| 30 |
+
const unsigned int tid = threadIdx.x;
|
| 31 |
+
const unsigned int bsz = blockDim.x;
|
| 32 |
+
|
| 33 |
+
for (unsigned int s = tid; s < synapses_per_col; s += bsz) {
|
| 34 |
+
unsigned int b = syn_bit[base + s];
|
| 35 |
+
float p = syn_perm[base + s];
|
| 36 |
+
if (inp[b] != 0) {
|
| 37 |
+
p += inc;
|
| 38 |
+
if (p > 1.0f) p = 1.0f;
|
| 39 |
+
} else {
|
| 40 |
+
p -= dec;
|
| 41 |
+
if (p < 0.0f) p = 0.0f;
|
| 42 |
+
}
|
| 43 |
+
syn_perm[base + s] = p;
|
| 44 |
+
}
|
| 45 |
+
}
|
overlay/htm_rust/src/gpu/kernels/sp_overlap.cu
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SP overlap kernel.
|
| 2 |
+
//
|
| 3 |
+
// For each column c (one CUDA block), compute:
|
| 4 |
+
// overlap[c] = sum over its synapse list of {inp[bit[c][s]] && perm[c][s] >= conn_thr}
|
| 5 |
+
// boosted[c] = overlap[c] * boost[c]
|
| 6 |
+
// raw_overlap[c] = overlap[c] (also returned so host can drive duty cycle)
|
| 7 |
+
//
|
| 8 |
+
// Memory layout (flat, column-major with per-column stride = synapses_per_col):
|
| 9 |
+
// syn_bit[c * S + s] : u32 index into input SDR
|
| 10 |
+
// syn_perm[c * S + s] : f32 permanence in [0, 1]
|
| 11 |
+
// boost[c] : f32
|
| 12 |
+
// inp[b] : u8 0/1
|
| 13 |
+
// Output:
|
| 14 |
+
// raw[c] : u32
|
| 15 |
+
// boosted[c] : f32
|
| 16 |
+
//
|
| 17 |
+
// Launch:
|
| 18 |
+
// grid = n_columns
|
| 19 |
+
// block = 128 (or 256) — one warp-sweep across synapses; many warps give
|
| 20 |
+
// parallel reduction across S (typically S=40).
|
| 21 |
+
//
|
| 22 |
+
// At S=40 this is completely latency-bound; we coalesce reads and do a
|
| 23 |
+
// warp-shuffle reduction. For clarity we use a simple block-wide shared-mem
|
| 24 |
+
// reduction which is sufficient for S <= 1024 and has zero correctness risk.
|
| 25 |
+
|
| 26 |
+
extern "C" __global__
|
| 27 |
+
void sp_overlap(
|
| 28 |
+
const unsigned char * __restrict__ inp, // (input_bits,)
|
| 29 |
+
const unsigned int * __restrict__ syn_bit, // (n_columns * S,)
|
| 30 |
+
const float * __restrict__ syn_perm,// (n_columns * S,)
|
| 31 |
+
const float * __restrict__ boost, // (n_columns,)
|
| 32 |
+
float conn_thr,
|
| 33 |
+
unsigned int synapses_per_col, // S
|
| 34 |
+
unsigned int n_columns,
|
| 35 |
+
unsigned int * __restrict__ raw_out, // (n_columns,)
|
| 36 |
+
float * __restrict__ boosted_out // (n_columns,)
|
| 37 |
+
) {
|
| 38 |
+
const unsigned int c = blockIdx.x;
|
| 39 |
+
if (c >= n_columns) return;
|
| 40 |
+
|
| 41 |
+
const unsigned int base = c * synapses_per_col;
|
| 42 |
+
const unsigned int tid = threadIdx.x;
|
| 43 |
+
const unsigned int bsz = blockDim.x;
|
| 44 |
+
|
| 45 |
+
// Per-thread partial count.
|
| 46 |
+
unsigned int local = 0;
|
| 47 |
+
for (unsigned int s = tid; s < synapses_per_col; s += bsz) {
|
| 48 |
+
unsigned int b = syn_bit[base + s];
|
| 49 |
+
float p = syn_perm[base + s];
|
| 50 |
+
// Branchless: only counts when input active AND perm connected.
|
| 51 |
+
// Using (inp != 0) to tolerate u8 layout.
|
| 52 |
+
unsigned int hit = ((inp[b] != 0) && (p >= conn_thr)) ? 1u : 0u;
|
| 53 |
+
local += hit;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
// Block-wide reduction in shared memory.
|
| 57 |
+
__shared__ unsigned int smem[32];
|
| 58 |
+
|
| 59 |
+
// Warp-level reduction via shuffle.
|
| 60 |
+
unsigned int lane = tid & 31;
|
| 61 |
+
unsigned int warp = tid >> 5;
|
| 62 |
+
for (int off = 16; off > 0; off >>= 1) {
|
| 63 |
+
local += __shfl_down_sync(0xffffffff, local, off);
|
| 64 |
+
}
|
| 65 |
+
if (lane == 0) smem[warp] = local;
|
| 66 |
+
__syncthreads();
|
| 67 |
+
|
| 68 |
+
if (warp == 0) {
|
| 69 |
+
unsigned int v = (tid < (bsz + 31) / 32) ? smem[lane] : 0;
|
| 70 |
+
for (int off = 16; off > 0; off >>= 1) {
|
| 71 |
+
v += __shfl_down_sync(0xffffffff, v, off);
|
| 72 |
+
}
|
| 73 |
+
if (tid == 0) {
|
| 74 |
+
raw_out[c] = v;
|
| 75 |
+
boosted_out[c] = (float)v * boost[c];
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
}
|
overlay/htm_rust/src/gpu/kernels/sp_topk.cu
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Top-K column selection.
|
| 2 |
+
//
|
| 3 |
+
// Inputs:
|
| 4 |
+
// boosted[n_columns] : f32 score
|
| 5 |
+
// Output:
|
| 6 |
+
// active_mask[n_columns] : u8 0/1, exactly k ones
|
| 7 |
+
//
|
| 8 |
+
// Tie-breaking: when scores are equal, the LOWER column index wins (matches
|
| 9 |
+
// CPU reference `select_nth_unstable_by` with secondary index comparator).
|
| 10 |
+
//
|
| 11 |
+
// Strategy: a single-block implementation. n_columns is typically 2048, which
|
| 12 |
+
// fits comfortably in shared memory. We use a bitonic top-k via per-thread
|
| 13 |
+
// radix-select of the (score, -index) key. At k≈41 of n=2048 the simplest
|
| 14 |
+
// correct approach is a thresholding pass:
|
| 15 |
+
//
|
| 16 |
+
// 1. Radix-like bucket pass to find the k-th largest score.
|
| 17 |
+
// 2. Mark winners = strictly-greater-than-threshold AND ties until count hits k.
|
| 18 |
+
//
|
| 19 |
+
// For strict index-ordered tie-break we materialise a 64-bit key:
|
| 20 |
+
// key = (float_to_sortable_u32(score) << 32) | (0xffffffff - index)
|
| 21 |
+
// Larger key = (higher score) OR (same score, smaller index).
|
| 22 |
+
//
|
| 23 |
+
// Then we find the k-th largest 64-bit key via radix-select and mark all
|
| 24 |
+
// columns whose key >= threshold. This is O(n_cols * log k) and well under
|
| 25 |
+
// 100 μs for n=2048, k=41 on sm_86.
|
| 26 |
+
//
|
| 27 |
+
// For simplicity and correctness this kernel uses a single-block parallel
|
| 28 |
+
// selection sort variant (find max → mark → zero → repeat, k iterations).
|
| 29 |
+
// At k=41 this is 41 passes of 2048 threads = ~2048*41 = 84K ops, trivially
|
| 30 |
+
// fast.
|
| 31 |
+
|
| 32 |
+
extern "C" __global__
|
| 33 |
+
void sp_topk_select(
|
| 34 |
+
const float * __restrict__ scores, // (n_columns,)
|
| 35 |
+
unsigned int n_columns,
|
| 36 |
+
unsigned int k,
|
| 37 |
+
unsigned char * __restrict__ active_out // (n_columns,)
|
| 38 |
+
) {
|
| 39 |
+
extern __shared__ float smem[];
|
| 40 |
+
// Layout: smem[0..n] = working scores (we'll mark selected entries as -inf)
|
| 41 |
+
// smem[n..n+32*2] = reduction scratch (score + index, per warp)
|
| 42 |
+
float * work = smem;
|
| 43 |
+
const unsigned int tid = threadIdx.x;
|
| 44 |
+
const unsigned int bsz = blockDim.x;
|
| 45 |
+
|
| 46 |
+
// Load scores into shared; also init active_out = 0.
|
| 47 |
+
for (unsigned int i = tid; i < n_columns; i += bsz) {
|
| 48 |
+
work[i] = scores[i];
|
| 49 |
+
active_out[i] = 0;
|
| 50 |
+
}
|
| 51 |
+
__syncthreads();
|
| 52 |
+
|
| 53 |
+
__shared__ int winner_idx;
|
| 54 |
+
__shared__ float winner_score;
|
| 55 |
+
|
| 56 |
+
for (unsigned int iter = 0; iter < k; ++iter) {
|
| 57 |
+
// Find (argmax score, lowest index for ties).
|
| 58 |
+
float best_s = -INFINITY;
|
| 59 |
+
int best_i = n_columns; // sentinel larger than any index
|
| 60 |
+
|
| 61 |
+
for (unsigned int i = tid; i < n_columns; i += bsz) {
|
| 62 |
+
float s = work[i];
|
| 63 |
+
if (s > best_s || (s == best_s && (int)i < best_i)) {
|
| 64 |
+
best_s = s;
|
| 65 |
+
best_i = (int)i;
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
// Warp reduction. We reduce pairs (score, idx) keeping (max score, min idx on tie).
|
| 70 |
+
unsigned int mask = 0xffffffff;
|
| 71 |
+
for (int off = 16; off > 0; off >>= 1) {
|
| 72 |
+
float os = __shfl_down_sync(mask, best_s, off);
|
| 73 |
+
int oi = __shfl_down_sync(mask, best_i, off);
|
| 74 |
+
if (os > best_s || (os == best_s && oi < best_i)) {
|
| 75 |
+
best_s = os;
|
| 76 |
+
best_i = oi;
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
// Warp 0 collects lane 0 values from other warps via shared mem.
|
| 80 |
+
__shared__ float warp_s[32];
|
| 81 |
+
__shared__ int warp_i[32];
|
| 82 |
+
unsigned int lane = tid & 31;
|
| 83 |
+
unsigned int warp = tid >> 5;
|
| 84 |
+
if (lane == 0) {
|
| 85 |
+
warp_s[warp] = best_s;
|
| 86 |
+
warp_i[warp] = best_i;
|
| 87 |
+
}
|
| 88 |
+
__syncthreads();
|
| 89 |
+
|
| 90 |
+
if (warp == 0) {
|
| 91 |
+
unsigned int nwarps = (bsz + 31) / 32;
|
| 92 |
+
float s = (lane < nwarps) ? warp_s[lane] : -INFINITY;
|
| 93 |
+
int i = (lane < nwarps) ? warp_i[lane] : (int)n_columns;
|
| 94 |
+
for (int off = 16; off > 0; off >>= 1) {
|
| 95 |
+
float os = __shfl_down_sync(mask, s, off);
|
| 96 |
+
int oi = __shfl_down_sync(mask, i, off);
|
| 97 |
+
if (os > s || (os == s && oi < i)) {
|
| 98 |
+
s = os;
|
| 99 |
+
i = oi;
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
if (tid == 0) {
|
| 103 |
+
winner_score = s;
|
| 104 |
+
winner_idx = i;
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
__syncthreads();
|
| 108 |
+
|
| 109 |
+
if (tid == 0) {
|
| 110 |
+
if (winner_idx < (int)n_columns) {
|
| 111 |
+
active_out[winner_idx] = 1;
|
| 112 |
+
work[winner_idx] = -INFINITY;
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
__syncthreads();
|
| 116 |
+
}
|
| 117 |
+
}
|
overlay/htm_rust/src/gpu/kernels/tm_activate.cu
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// TM activate kernel. See tm_predict.cu for TmConfig.
|
| 2 |
+
|
| 3 |
+
struct TmConfig {
|
| 4 |
+
unsigned int activation_threshold;
|
| 5 |
+
unsigned int learning_threshold;
|
| 6 |
+
unsigned int cells_per_column;
|
| 7 |
+
unsigned int synapses_per_segment;
|
| 8 |
+
unsigned int n_segments;
|
| 9 |
+
unsigned int n_cells;
|
| 10 |
+
unsigned int max_segments_per_cell;
|
| 11 |
+
unsigned int max_new_synapses;
|
| 12 |
+
int conn_thr_i16;
|
| 13 |
+
int perm_inc_i16;
|
| 14 |
+
int perm_dec_i16;
|
| 15 |
+
int predicted_seg_dec_i16;
|
| 16 |
+
int initial_perm_i16;
|
| 17 |
+
unsigned int iter_seed;
|
| 18 |
+
unsigned int n_cols;
|
| 19 |
+
unsigned int bits_words;
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
extern "C" __global__
|
| 23 |
+
void tm_activate(
|
| 24 |
+
const unsigned char * __restrict__ sp_active_mask,
|
| 25 |
+
const unsigned char * __restrict__ col_predicted,
|
| 26 |
+
const unsigned int * __restrict__ cell_predictive_bits,
|
| 27 |
+
unsigned int * __restrict__ cell_active_bits,
|
| 28 |
+
unsigned int * __restrict__ cell_winner_bits,
|
| 29 |
+
unsigned int * __restrict__ unpredicted_count,
|
| 30 |
+
unsigned int * __restrict__ burst_cols_flat,
|
| 31 |
+
unsigned int * __restrict__ burst_cols_count,
|
| 32 |
+
TmConfig cfg
|
| 33 |
+
) {
|
| 34 |
+
unsigned int col = blockIdx.x * blockDim.x + threadIdx.x;
|
| 35 |
+
if (col >= cfg.n_cols) return;
|
| 36 |
+
if (sp_active_mask[col] == 0) return;
|
| 37 |
+
|
| 38 |
+
unsigned int base_cell = col * cfg.cells_per_column;
|
| 39 |
+
|
| 40 |
+
if (col_predicted[col]) {
|
| 41 |
+
for (unsigned int k = 0; k < cfg.cells_per_column; k++) {
|
| 42 |
+
unsigned int cell = base_cell + k;
|
| 43 |
+
unsigned int word_idx = cell >> 5;
|
| 44 |
+
unsigned int bit_mask = 1u << (cell & 31u);
|
| 45 |
+
unsigned int pred_word = cell_predictive_bits[word_idx];
|
| 46 |
+
if (pred_word & bit_mask) {
|
| 47 |
+
atomicOr(&cell_active_bits[word_idx], bit_mask);
|
| 48 |
+
atomicOr(&cell_winner_bits[word_idx], bit_mask);
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
} else {
|
| 52 |
+
atomicAdd(unpredicted_count, 1u);
|
| 53 |
+
for (unsigned int k = 0; k < cfg.cells_per_column; k++) {
|
| 54 |
+
unsigned int cell = base_cell + k;
|
| 55 |
+
unsigned int word_idx = cell >> 5;
|
| 56 |
+
unsigned int bit_mask = 1u << (cell & 31u);
|
| 57 |
+
atomicOr(&cell_active_bits[word_idx], bit_mask);
|
| 58 |
+
}
|
| 59 |
+
unsigned int winner = base_cell;
|
| 60 |
+
unsigned int word_idx = winner >> 5;
|
| 61 |
+
unsigned int bit_mask = 1u << (winner & 31u);
|
| 62 |
+
atomicOr(&cell_winner_bits[word_idx], bit_mask);
|
| 63 |
+
unsigned int slot = atomicAdd(burst_cols_count, 1u);
|
| 64 |
+
burst_cols_flat[slot] = col;
|
| 65 |
+
}
|
| 66 |
+
}
|
overlay/htm_rust/src/gpu/kernels/tm_anomaly.cu
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// TM anomaly kernel.
|
| 2 |
+
//
|
| 3 |
+
// Computes:
|
| 4 |
+
// n_active = sum of sp_active_mask
|
| 5 |
+
// anomaly = unpredicted_count / n_active (if n_active > 0)
|
| 6 |
+
// = 0 (else)
|
| 7 |
+
//
|
| 8 |
+
// Launch: single block, 256 threads.
|
| 9 |
+
|
| 10 |
+
extern "C" __global__
|
| 11 |
+
void tm_anomaly(
|
| 12 |
+
const unsigned char * __restrict__ sp_active_mask,
|
| 13 |
+
const unsigned int * __restrict__ unpredicted_count,
|
| 14 |
+
float * __restrict__ anomaly_out, // (1,) or (t_slot,)
|
| 15 |
+
unsigned int t_slot,
|
| 16 |
+
unsigned int n_cols
|
| 17 |
+
) {
|
| 18 |
+
const unsigned int tid = threadIdx.x;
|
| 19 |
+
__shared__ unsigned int n_active_s;
|
| 20 |
+
|
| 21 |
+
if (tid == 0) n_active_s = 0u;
|
| 22 |
+
__syncthreads();
|
| 23 |
+
|
| 24 |
+
unsigned int local = 0u;
|
| 25 |
+
for (unsigned int i = tid; i < n_cols; i += blockDim.x) {
|
| 26 |
+
if (sp_active_mask[i]) local += 1u;
|
| 27 |
+
}
|
| 28 |
+
// Warp reduce.
|
| 29 |
+
for (int off = 16; off > 0; off >>= 1) {
|
| 30 |
+
local += __shfl_down_sync(0xffffffffu, local, off);
|
| 31 |
+
}
|
| 32 |
+
if ((tid & 31u) == 0) {
|
| 33 |
+
atomicAdd(&n_active_s, local);
|
| 34 |
+
}
|
| 35 |
+
__syncthreads();
|
| 36 |
+
|
| 37 |
+
if (tid == 0) {
|
| 38 |
+
unsigned int total = n_active_s;
|
| 39 |
+
unsigned int bad = unpredicted_count[0];
|
| 40 |
+
float anom = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
|
| 41 |
+
anomaly_out[t_slot] = anom;
|
| 42 |
+
}
|
| 43 |
+
}
|
overlay/htm_rust/src/gpu/kernels/tm_grow.cu
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// TM grow+reinforce kernel.
|
| 2 |
+
//
|
| 3 |
+
// For each bursting column:
|
| 4 |
+
// If col_best_match[col] is non-zero (i.e. at least one matching segment
|
| 5 |
+
// with num_active_potential >= learning_threshold exists on cells in this col):
|
| 6 |
+
// Target = that matching segment.
|
| 7 |
+
// Reinforce its existing synapses: +inc if presyn in prev_active, -dec otherwise.
|
| 8 |
+
// Grow up to (max_new - current_syn_count) additional synapses to prev_winners.
|
| 9 |
+
// Else:
|
| 10 |
+
// Allocate a fresh segment slot on winner cell (cell 0 of col).
|
| 11 |
+
// Grow up to max_new synapses to prev_winners (no reinforce needed — new seg).
|
| 12 |
+
//
|
| 13 |
+
// This mirrors the CPU TM burst logic.
|
| 14 |
+
|
| 15 |
+
struct TmConfig {
|
| 16 |
+
unsigned int activation_threshold;
|
| 17 |
+
unsigned int learning_threshold;
|
| 18 |
+
unsigned int cells_per_column;
|
| 19 |
+
unsigned int synapses_per_segment;
|
| 20 |
+
unsigned int n_segments;
|
| 21 |
+
unsigned int n_cells;
|
| 22 |
+
unsigned int max_segments_per_cell;
|
| 23 |
+
unsigned int max_new_synapses;
|
| 24 |
+
int conn_thr_i16;
|
| 25 |
+
int perm_inc_i16;
|
| 26 |
+
int perm_dec_i16;
|
| 27 |
+
int predicted_seg_dec_i16;
|
| 28 |
+
int initial_perm_i16;
|
| 29 |
+
unsigned int iter_seed;
|
| 30 |
+
unsigned int n_cols;
|
| 31 |
+
unsigned int bits_words;
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
extern "C" __global__
|
| 35 |
+
void tm_grow(
|
| 36 |
+
unsigned int * __restrict__ seg_cell_id,
|
| 37 |
+
unsigned int * __restrict__ seg_syn_count,
|
| 38 |
+
unsigned int * __restrict__ syn_presyn,
|
| 39 |
+
short * __restrict__ syn_perm,
|
| 40 |
+
unsigned int * __restrict__ cell_seg_count,
|
| 41 |
+
const unsigned int * __restrict__ burst_cols_flat,
|
| 42 |
+
const unsigned int * __restrict__ burst_cols_count,
|
| 43 |
+
const unsigned int * __restrict__ prev_winner_bits,
|
| 44 |
+
const unsigned int * __restrict__ prev_active_bits,
|
| 45 |
+
const unsigned int * __restrict__ col_best_match,
|
| 46 |
+
TmConfig cfg
|
| 47 |
+
) {
|
| 48 |
+
const unsigned int b = blockIdx.x;
|
| 49 |
+
const unsigned int n_burst_cols = burst_cols_count[0];
|
| 50 |
+
if (b >= n_burst_cols) return;
|
| 51 |
+
const unsigned int tid = threadIdx.x;
|
| 52 |
+
|
| 53 |
+
const unsigned int col = burst_cols_flat[b];
|
| 54 |
+
|
| 55 |
+
__shared__ unsigned int shared_seg_id;
|
| 56 |
+
__shared__ unsigned int shared_existing_syn_count;
|
| 57 |
+
__shared__ unsigned int shared_grown;
|
| 58 |
+
__shared__ unsigned int shared_is_new;
|
| 59 |
+
__shared__ unsigned int shared_start_offset;
|
| 60 |
+
|
| 61 |
+
if (tid == 0) {
|
| 62 |
+
unsigned int match_key = col_best_match[col];
|
| 63 |
+
if (match_key != 0u) {
|
| 64 |
+
// Reuse matching segment.
|
| 65 |
+
unsigned int seg_id = match_key & 0x1FFFFFu;
|
| 66 |
+
shared_seg_id = seg_id;
|
| 67 |
+
shared_existing_syn_count = seg_syn_count[seg_id];
|
| 68 |
+
shared_is_new = 0u;
|
| 69 |
+
} else {
|
| 70 |
+
// Allocate new segment on winner cell (cell 0 of col).
|
| 71 |
+
unsigned int winner_cell = col * cfg.cells_per_column;
|
| 72 |
+
unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
|
| 73 |
+
if (slot >= cfg.max_segments_per_cell) {
|
| 74 |
+
slot = slot % cfg.max_segments_per_cell;
|
| 75 |
+
}
|
| 76 |
+
unsigned int seg_id = winner_cell * cfg.max_segments_per_cell + slot;
|
| 77 |
+
seg_cell_id[seg_id] = winner_cell;
|
| 78 |
+
seg_syn_count[seg_id] = 0;
|
| 79 |
+
shared_seg_id = seg_id;
|
| 80 |
+
shared_existing_syn_count = 0u;
|
| 81 |
+
shared_is_new = 1u;
|
| 82 |
+
}
|
| 83 |
+
shared_grown = 0u;
|
| 84 |
+
shared_start_offset = (b * 2654435761u + cfg.iter_seed) % cfg.bits_words;
|
| 85 |
+
}
|
| 86 |
+
__syncthreads();
|
| 87 |
+
|
| 88 |
+
const unsigned int seg_id = shared_seg_id;
|
| 89 |
+
const unsigned int seg_base = seg_id * cfg.synapses_per_segment;
|
| 90 |
+
const unsigned int existing_syn = shared_existing_syn_count;
|
| 91 |
+
const unsigned int is_new = shared_is_new;
|
| 92 |
+
const unsigned int start = shared_start_offset;
|
| 93 |
+
|
| 94 |
+
// PHASE 1: If reusing, reinforce existing synapses.
|
| 95 |
+
if (!is_new) {
|
| 96 |
+
for (unsigned int s = tid; s < existing_syn; s += 32u) {
|
| 97 |
+
unsigned int presyn = syn_presyn[seg_base + s];
|
| 98 |
+
unsigned int word = prev_active_bits[presyn >> 5];
|
| 99 |
+
unsigned int bit = (word >> (presyn & 31u)) & 1u;
|
| 100 |
+
int p = (int)syn_perm[seg_base + s];
|
| 101 |
+
if (bit) {
|
| 102 |
+
int np = p + cfg.perm_inc_i16;
|
| 103 |
+
if (np > 32767) np = 32767;
|
| 104 |
+
syn_perm[seg_base + s] = (short)np;
|
| 105 |
+
} else {
|
| 106 |
+
int np = p - cfg.perm_dec_i16;
|
| 107 |
+
if (np < 0) np = 0;
|
| 108 |
+
syn_perm[seg_base + s] = (short)np;
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
__syncthreads();
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
// PHASE 2: Grow up to `max_new_synapses` (or room) synapses to prev_winners
|
| 115 |
+
// that aren't already presynaptic to this segment.
|
| 116 |
+
const unsigned int room = (cfg.synapses_per_segment > existing_syn)
|
| 117 |
+
? (cfg.synapses_per_segment - existing_syn) : 0u;
|
| 118 |
+
const unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;
|
| 119 |
+
|
| 120 |
+
for (unsigned int w_off = 0; w_off < cfg.bits_words; w_off += 32u) {
|
| 121 |
+
if (shared_grown >= max_grow) break;
|
| 122 |
+
unsigned int widx = (start + w_off + tid) % cfg.bits_words;
|
| 123 |
+
unsigned int word = prev_winner_bits[widx];
|
| 124 |
+
while (word != 0u) {
|
| 125 |
+
if (shared_grown >= max_grow) break;
|
| 126 |
+
unsigned int bit_pos = __ffs(word) - 1u;
|
| 127 |
+
word &= ~(1u << bit_pos);
|
| 128 |
+
unsigned int cell = widx * 32u + bit_pos;
|
| 129 |
+
if (cell >= cfg.n_cells) continue;
|
| 130 |
+
|
| 131 |
+
// Skip if already presynaptic (O(existing_syn) scan; usually small).
|
| 132 |
+
bool exists = false;
|
| 133 |
+
for (unsigned int s = 0; s < existing_syn; s++) {
|
| 134 |
+
if (syn_presyn[seg_base + s] == cell) { exists = true; break; }
|
| 135 |
+
}
|
| 136 |
+
if (exists) continue;
|
| 137 |
+
|
| 138 |
+
unsigned int slot = atomicAdd(&shared_grown, 1u);
|
| 139 |
+
if (slot >= max_grow) break;
|
| 140 |
+
unsigned int write_idx = existing_syn + slot;
|
| 141 |
+
if (write_idx >= cfg.synapses_per_segment) break;
|
| 142 |
+
syn_presyn[seg_base + write_idx] = cell;
|
| 143 |
+
syn_perm[seg_base + write_idx] = (short)cfg.initial_perm_i16;
|
| 144 |
+
}
|
| 145 |
+
}
|
| 146 |
+
__syncthreads();
|
| 147 |
+
|
| 148 |
+
if (tid == 0) {
|
| 149 |
+
unsigned int grown = shared_grown;
|
| 150 |
+
if (grown > max_grow) grown = max_grow;
|
| 151 |
+
unsigned int new_count = existing_syn + grown;
|
| 152 |
+
if (new_count > cfg.synapses_per_segment) new_count = cfg.synapses_per_segment;
|
| 153 |
+
seg_syn_count[seg_id] = new_count;
|
| 154 |
+
}
|
| 155 |
+
}
|
overlay/htm_rust/src/gpu/kernels/tm_learn.cu
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// TM learn (reinforce correctly predicted segments) — cell-grouped launch.
|
| 2 |
+
//
|
| 3 |
+
// Grid: n_cells.
|
| 4 |
+
// For each cell in a predicted, SP-active column: iterate its segments.
|
| 5 |
+
// For each segment with num_active_connected >= activation_threshold,
|
| 6 |
+
// reinforce its synapses against prev_active_bits.
|
| 7 |
+
|
| 8 |
+
struct TmConfig {
|
| 9 |
+
unsigned int activation_threshold;
|
| 10 |
+
unsigned int learning_threshold;
|
| 11 |
+
unsigned int cells_per_column;
|
| 12 |
+
unsigned int synapses_per_segment;
|
| 13 |
+
unsigned int n_segments;
|
| 14 |
+
unsigned int n_cells;
|
| 15 |
+
unsigned int max_segments_per_cell;
|
| 16 |
+
unsigned int max_new_synapses;
|
| 17 |
+
int conn_thr_i16;
|
| 18 |
+
int perm_inc_i16;
|
| 19 |
+
int perm_dec_i16;
|
| 20 |
+
int predicted_seg_dec_i16;
|
| 21 |
+
int initial_perm_i16;
|
| 22 |
+
unsigned int iter_seed;
|
| 23 |
+
unsigned int n_cols;
|
| 24 |
+
unsigned int bits_words;
|
| 25 |
+
};
|
| 26 |
+
|
| 27 |
+
extern "C" __global__
|
| 28 |
+
void tm_learn_reinforce(
|
| 29 |
+
const unsigned int * __restrict__ seg_cell_id,
|
| 30 |
+
const unsigned int * __restrict__ seg_syn_count,
|
| 31 |
+
const unsigned int * __restrict__ syn_presyn,
|
| 32 |
+
short * __restrict__ syn_perm,
|
| 33 |
+
const unsigned int * __restrict__ seg_num_active_connected,
|
| 34 |
+
const unsigned int * __restrict__ prev_active_bits,
|
| 35 |
+
const unsigned char * __restrict__ sp_active_mask,
|
| 36 |
+
const unsigned char * __restrict__ col_predicted,
|
| 37 |
+
const unsigned int * __restrict__ cell_seg_count,
|
| 38 |
+
TmConfig cfg
|
| 39 |
+
) {
|
| 40 |
+
const unsigned int cell = blockIdx.x;
|
| 41 |
+
if (cell >= cfg.n_cells) return;
|
| 42 |
+
const unsigned int col = cell / cfg.cells_per_column;
|
| 43 |
+
if (sp_active_mask[col] == 0) return;
|
| 44 |
+
if (col_predicted[col] == 0) return;
|
| 45 |
+
|
| 46 |
+
const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
|
| 47 |
+
if (n_segs_here == 0) return;
|
| 48 |
+
|
| 49 |
+
const unsigned int tid = threadIdx.x;
|
| 50 |
+
const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;
|
| 51 |
+
|
| 52 |
+
for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
|
| 53 |
+
const unsigned int seg = seg_base_id + local_seg;
|
| 54 |
+
if (seg_num_active_connected[seg] < cfg.activation_threshold) continue;
|
| 55 |
+
const unsigned int n_syn = seg_syn_count[seg];
|
| 56 |
+
if (n_syn == 0) continue;
|
| 57 |
+
const unsigned int syn_base = seg * cfg.synapses_per_segment;
|
| 58 |
+
|
| 59 |
+
for (unsigned int s = tid; s < n_syn; s += 32u) {
|
| 60 |
+
unsigned int presyn = syn_presyn[syn_base + s];
|
| 61 |
+
unsigned int word = prev_active_bits[presyn >> 5];
|
| 62 |
+
unsigned int bit = (word >> (presyn & 31u)) & 1u;
|
| 63 |
+
int p = (int)syn_perm[syn_base + s];
|
| 64 |
+
if (bit) {
|
| 65 |
+
int np = p + cfg.perm_inc_i16;
|
| 66 |
+
if (np > 32767) np = 32767;
|
| 67 |
+
syn_perm[syn_base + s] = (short)np;
|
| 68 |
+
} else {
|
| 69 |
+
int np = p - cfg.perm_dec_i16;
|
| 70 |
+
if (np < 0) np = 0;
|
| 71 |
+
syn_perm[syn_base + s] = (short)np;
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
}
|
overlay/htm_rust/src/gpu/kernels/tm_predict.cu
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// TM predict kernel — cell-grouped launch.
|
| 2 |
+
//
|
| 3 |
+
// Grid: n_cells blocks (one per cell).
|
| 4 |
+
// Block: 32 threads (one warp).
|
| 5 |
+
//
|
| 6 |
+
// Each block iterates the segments owned by its cell (count in cell_seg_count[cell]).
|
| 7 |
+
// For each live segment, counts active connected/potential synapses against
|
| 8 |
+
// prev_active_bits. Updates per-segment counters, cell_predictive bit, and
|
| 9 |
+
// col_predicted flag.
|
| 10 |
+
|
| 11 |
+
struct TmConfig {
|
| 12 |
+
unsigned int activation_threshold;
|
| 13 |
+
unsigned int learning_threshold;
|
| 14 |
+
unsigned int cells_per_column;
|
| 15 |
+
unsigned int synapses_per_segment;
|
| 16 |
+
unsigned int n_segments;
|
| 17 |
+
unsigned int n_cells;
|
| 18 |
+
unsigned int max_segments_per_cell;
|
| 19 |
+
unsigned int max_new_synapses;
|
| 20 |
+
int conn_thr_i16;
|
| 21 |
+
int perm_inc_i16;
|
| 22 |
+
int perm_dec_i16;
|
| 23 |
+
int predicted_seg_dec_i16;
|
| 24 |
+
int initial_perm_i16;
|
| 25 |
+
unsigned int iter_seed;
|
| 26 |
+
unsigned int n_cols;
|
| 27 |
+
unsigned int bits_words;
|
| 28 |
+
};
|
| 29 |
+
|
| 30 |
+
extern "C" __global__
|
| 31 |
+
void tm_predict(
|
| 32 |
+
const unsigned int * __restrict__ seg_cell_id,
|
| 33 |
+
const unsigned int * __restrict__ seg_syn_count,
|
| 34 |
+
const unsigned int * __restrict__ syn_presyn,
|
| 35 |
+
const short * __restrict__ syn_perm,
|
| 36 |
+
const unsigned int * __restrict__ cell_active_bits,
|
| 37 |
+
unsigned int * __restrict__ cell_predictive_bits,
|
| 38 |
+
unsigned char * __restrict__ col_predicted,
|
| 39 |
+
unsigned int * __restrict__ seg_num_active_connected,
|
| 40 |
+
unsigned int * __restrict__ seg_num_active_potential,
|
| 41 |
+
unsigned int * __restrict__ col_best_match,
|
| 42 |
+
const unsigned int * __restrict__ cell_seg_count,
|
| 43 |
+
TmConfig cfg
|
| 44 |
+
) {
|
| 45 |
+
const unsigned int cell = blockIdx.x;
|
| 46 |
+
if (cell >= cfg.n_cells) return;
|
| 47 |
+
|
| 48 |
+
const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
|
| 49 |
+
if (n_segs_here == 0) return;
|
| 50 |
+
|
| 51 |
+
const unsigned int tid = threadIdx.x;
|
| 52 |
+
const unsigned int col = cell / cfg.cells_per_column;
|
| 53 |
+
const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;
|
| 54 |
+
|
| 55 |
+
for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
|
| 56 |
+
const unsigned int seg = seg_base_id + local_seg;
|
| 57 |
+
const unsigned int n_syn = seg_syn_count[seg];
|
| 58 |
+
if (n_syn == 0) {
|
| 59 |
+
if (tid == 0) {
|
| 60 |
+
seg_num_active_connected[seg] = 0;
|
| 61 |
+
seg_num_active_potential[seg] = 0;
|
| 62 |
+
}
|
| 63 |
+
continue;
|
| 64 |
+
}
|
| 65 |
+
const unsigned int syn_base = seg * cfg.synapses_per_segment;
|
| 66 |
+
|
| 67 |
+
unsigned int local_conn = 0;
|
| 68 |
+
unsigned int local_pot = 0;
|
| 69 |
+
for (unsigned int s = tid; s < n_syn; s += 32u) {
|
| 70 |
+
unsigned int presyn = syn_presyn[syn_base + s];
|
| 71 |
+
unsigned int word = cell_active_bits[presyn >> 5];
|
| 72 |
+
unsigned int bit = (word >> (presyn & 31u)) & 1u;
|
| 73 |
+
if (bit) {
|
| 74 |
+
local_pot += 1u;
|
| 75 |
+
int p = (int)syn_perm[syn_base + s];
|
| 76 |
+
if (p >= cfg.conn_thr_i16) {
|
| 77 |
+
local_conn += 1u;
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
for (int off = 16; off > 0; off >>= 1) {
|
| 82 |
+
local_conn += __shfl_down_sync(0xffffffffu, local_conn, off);
|
| 83 |
+
local_pot += __shfl_down_sync(0xffffffffu, local_pot, off);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
if (tid == 0) {
|
| 87 |
+
seg_num_active_connected[seg] = local_conn;
|
| 88 |
+
seg_num_active_potential[seg] = local_pot;
|
| 89 |
+
if (local_conn >= cfg.activation_threshold) {
|
| 90 |
+
unsigned int word_idx = cell >> 5;
|
| 91 |
+
unsigned int bit_mask = 1u << (cell & 31u);
|
| 92 |
+
atomicOr(&cell_predictive_bits[word_idx], bit_mask);
|
| 93 |
+
col_predicted[col] = 1;
|
| 94 |
+
}
|
| 95 |
+
if (local_pot >= cfg.learning_threshold) {
|
| 96 |
+
unsigned int pot_c = local_pot > 2047u ? 2047u : local_pot;
|
| 97 |
+
unsigned int key = (pot_c << 21) | (seg & 0x1FFFFFu);
|
| 98 |
+
atomicMax(&col_best_match[col], key);
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
}
|
overlay/htm_rust/src/gpu/kernels/tm_punish.cu
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// TM punish — cell-grouped launch.
|
| 2 |
+
|
| 3 |
+
struct TmConfig {
|
| 4 |
+
unsigned int activation_threshold;
|
| 5 |
+
unsigned int learning_threshold;
|
| 6 |
+
unsigned int cells_per_column;
|
| 7 |
+
unsigned int synapses_per_segment;
|
| 8 |
+
unsigned int n_segments;
|
| 9 |
+
unsigned int n_cells;
|
| 10 |
+
unsigned int max_segments_per_cell;
|
| 11 |
+
unsigned int max_new_synapses;
|
| 12 |
+
int conn_thr_i16;
|
| 13 |
+
int perm_inc_i16;
|
| 14 |
+
int perm_dec_i16;
|
| 15 |
+
int predicted_seg_dec_i16;
|
| 16 |
+
int initial_perm_i16;
|
| 17 |
+
unsigned int iter_seed;
|
| 18 |
+
unsigned int n_cols;
|
| 19 |
+
unsigned int bits_words;
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
extern "C" __global__
|
| 23 |
+
void tm_punish(
|
| 24 |
+
const unsigned int * __restrict__ seg_cell_id,
|
| 25 |
+
const unsigned int * __restrict__ seg_syn_count,
|
| 26 |
+
const unsigned int * __restrict__ syn_presyn,
|
| 27 |
+
short * __restrict__ syn_perm,
|
| 28 |
+
const unsigned int * __restrict__ seg_num_active_potential,
|
| 29 |
+
const unsigned int * __restrict__ prev_active_bits,
|
| 30 |
+
const unsigned char * __restrict__ sp_active_mask,
|
| 31 |
+
const unsigned int * __restrict__ cell_seg_count,
|
| 32 |
+
TmConfig cfg
|
| 33 |
+
) {
|
| 34 |
+
const unsigned int cell = blockIdx.x;
|
| 35 |
+
if (cell >= cfg.n_cells) return;
|
| 36 |
+
const unsigned int col = cell / cfg.cells_per_column;
|
| 37 |
+
if (sp_active_mask[col] != 0) return; // skip: col became active
|
| 38 |
+
|
| 39 |
+
const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
|
| 40 |
+
if (n_segs_here == 0) return;
|
| 41 |
+
|
| 42 |
+
const unsigned int tid = threadIdx.x;
|
| 43 |
+
const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;
|
| 44 |
+
|
| 45 |
+
for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
|
| 46 |
+
const unsigned int seg = seg_base_id + local_seg;
|
| 47 |
+
if (seg_num_active_potential[seg] < cfg.learning_threshold) continue;
|
| 48 |
+
const unsigned int n_syn = seg_syn_count[seg];
|
| 49 |
+
if (n_syn == 0) continue;
|
| 50 |
+
const unsigned int syn_base = seg * cfg.synapses_per_segment;
|
| 51 |
+
|
| 52 |
+
for (unsigned int s = tid; s < n_syn; s += 32u) {
|
| 53 |
+
unsigned int presyn = syn_presyn[syn_base + s];
|
| 54 |
+
unsigned int word = prev_active_bits[presyn >> 5];
|
| 55 |
+
unsigned int bit = (word >> (presyn & 31u)) & 1u;
|
| 56 |
+
if (bit) {
|
| 57 |
+
int p = (int)syn_perm[syn_base + s];
|
| 58 |
+
int np = p - cfg.predicted_seg_dec_i16;
|
| 59 |
+
if (np < 0) np = 0;
|
| 60 |
+
syn_perm[syn_base + s] = (short)np;
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
}
|
overlay/htm_rust/src/gpu/kernels/tm_reset.cu
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// TM reset-per-step kernel.
|
| 2 |
+
|
| 3 |
+
extern "C" __global__
|
| 4 |
+
void tm_reset_step(
|
| 5 |
+
unsigned int * __restrict__ cell_active_bits,
|
| 6 |
+
unsigned int * __restrict__ cell_winner_bits,
|
| 7 |
+
unsigned int * __restrict__ cell_predictive_bits,
|
| 8 |
+
unsigned int * __restrict__ prev_active_bits,
|
| 9 |
+
unsigned int * __restrict__ prev_winner_bits,
|
| 10 |
+
unsigned char * __restrict__ col_predicted,
|
| 11 |
+
unsigned int * __restrict__ unpredicted_count,
|
| 12 |
+
unsigned int * __restrict__ burst_cols_count,
|
| 13 |
+
unsigned int * __restrict__ col_best_match,
|
| 14 |
+
unsigned int bits_words,
|
| 15 |
+
unsigned int n_cols
|
| 16 |
+
) {
|
| 17 |
+
unsigned int tid_global = blockIdx.x * blockDim.x + threadIdx.x;
|
| 18 |
+
|
| 19 |
+
if (tid_global < bits_words) {
|
| 20 |
+
prev_active_bits[tid_global] = cell_active_bits[tid_global];
|
| 21 |
+
prev_winner_bits[tid_global] = cell_winner_bits[tid_global];
|
| 22 |
+
cell_active_bits[tid_global] = 0u;
|
| 23 |
+
cell_winner_bits[tid_global] = 0u;
|
| 24 |
+
cell_predictive_bits[tid_global] = 0u;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
if (tid_global < n_cols) {
|
| 28 |
+
col_predicted[tid_global] = 0;
|
| 29 |
+
col_best_match[tid_global] = 0u;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
if (tid_global == 0) {
|
| 33 |
+
unpredicted_count[0] = 0u;
|
| 34 |
+
burst_cols_count[0] = 0u;
|
| 35 |
+
}
|
| 36 |
+
}
|
overlay/htm_rust/src/gpu/mod.rs
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! GPU backend for HTM.
|
| 2 |
+
//!
|
| 3 |
+
//! Full-GPU pipeline (SP + TM). Per-step state lives entirely on device; the
|
| 4 |
+
//! batch API (`step_many_gpu`) uploads T steps of input once, runs T iterations
|
| 5 |
+
//! of the full HTM pipeline on GPU, and copies (T, n_cols) u8 + (T,) f32 back
|
| 6 |
+
//! to the host in one shot.
|
| 7 |
+
//!
|
| 8 |
+
//! TM parity with the CPU reference is approximate:
|
| 9 |
+
//! - Segment growth: winner = cell 0 of bursting column (CPU picks
|
| 10 |
+
//! least-used-cell with RNG tiebreak). This is a pragmatic simplification
|
| 11 |
+
//! for GPU atomicity; learning dynamics are preserved.
|
| 12 |
+
//! - Permanences stored as i16 (scaled 0..32767). Rounding differs from
|
| 13 |
+
//! f32 by <= 1 ULP of the scale factor (≈ 3e-5) — inside any meaningful
|
| 14 |
+
//! HTM learning quantum.
|
| 15 |
+
|
| 16 |
+
#![cfg(feature = "gpu")]
|
| 17 |
+
|
| 18 |
+
pub mod sp_gpu;
|
| 19 |
+
pub mod tm_gpu;
|
| 20 |
+
pub mod fused;
|
| 21 |
+
|
| 22 |
+
#[cfg(test)]
|
| 23 |
+
mod tests;
|
| 24 |
+
|
| 25 |
+
use std::mem::ManuallyDrop;
|
| 26 |
+
|
| 27 |
+
use pyo3::prelude::*;
|
| 28 |
+
use pyo3::types::{PyDict, PyTuple};
|
| 29 |
+
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};
|
| 30 |
+
|
| 31 |
+
use crate::region::HTMRegionCore;
|
| 32 |
+
use crate::sp::SpatialPoolerConfig;
|
| 33 |
+
use sp_gpu::SpatialPoolerGpu;
|
| 34 |
+
use tm_gpu::TemporalMemoryGpu;
|
| 35 |
+
use fused::FusedState;
|
| 36 |
+
|
| 37 |
+
/// Extract (device_ptr, shape, typestr) from a `__cuda_array_interface__` dict.
|
| 38 |
+
/// Returns Err if the dict is malformed. Used by `step_many_cuda` to wrap
|
| 39 |
+
/// torch-owned CUDA allocations zero-copy.
|
| 40 |
+
fn cai_parse(cai: &Bound<'_, PyDict>) -> PyResult<(u64, Vec<usize>, String)> {
|
| 41 |
+
// `data` is a (ptr: int, readonly: bool) tuple.
|
| 42 |
+
let data_obj = cai.get_item("data")?
|
| 43 |
+
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'data'"))?;
|
| 44 |
+
let data_tup: Bound<'_, PyTuple> = data_obj.downcast_into()
|
| 45 |
+
.map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'data' must be a tuple"))?;
|
| 46 |
+
let ptr: u64 = data_tup.get_item(0)?.extract()?;
|
| 47 |
+
|
| 48 |
+
// `shape` is a tuple of ints.
|
| 49 |
+
let shape_obj = cai.get_item("shape")?
|
| 50 |
+
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'shape'"))?;
|
| 51 |
+
let shape_tup: Bound<'_, PyTuple> = shape_obj.downcast_into()
|
| 52 |
+
.map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'shape' must be a tuple"))?;
|
| 53 |
+
let shape: Vec<usize> = (0..shape_tup.len())
|
| 54 |
+
.map(|i| shape_tup.get_item(i).and_then(|v| v.extract::<usize>()))
|
| 55 |
+
.collect::<PyResult<Vec<_>>>()?;
|
| 56 |
+
|
| 57 |
+
// `typestr` (e.g. "|u1", "<f4").
|
| 58 |
+
let typestr_obj = cai.get_item("typestr")?
|
| 59 |
+
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'typestr'"))?;
|
| 60 |
+
let typestr: String = typestr_obj.extract()?;
|
| 61 |
+
|
| 62 |
+
// Reject non-contiguous tensors — we don't handle strides.
|
| 63 |
+
if let Some(strides) = cai.get_item("strides")? {
|
| 64 |
+
if !strides.is_none() {
|
| 65 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 66 |
+
"CAI 'strides' must be None (tensor must be contiguous)",
|
| 67 |
+
));
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
Ok((ptr, shape, typestr))
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
/// Python-exposed GPU HTM region. Drop-in replacement for `HTMRegion`.
|
| 75 |
+
#[pyclass(module = "htm_rust")]
|
| 76 |
+
pub struct HTMRegionGpu {
|
| 77 |
+
pub(super) sp_gpu: SpatialPoolerGpu,
|
| 78 |
+
pub(super) tm_gpu: TemporalMemoryGpu,
|
| 79 |
+
pub(super) fused_state: FusedState,
|
| 80 |
+
pub(super) n_columns: usize,
|
| 81 |
+
pub(super) input_bits: usize,
|
| 82 |
+
pub(super) cells_per_column: usize,
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
#[pymethods]
|
| 86 |
+
impl HTMRegionGpu {
|
| 87 |
+
#[new]
|
| 88 |
+
#[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
|
| 89 |
+
fn new(
|
| 90 |
+
input_bits: usize,
|
| 91 |
+
n_columns: usize,
|
| 92 |
+
cells_per_column: usize,
|
| 93 |
+
seed: u64,
|
| 94 |
+
) -> PyResult<Self> {
|
| 95 |
+
if input_bits == 0 || n_columns == 0 || cells_per_column == 0 {
|
| 96 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 97 |
+
"input_bits, n_columns, cells_per_column must all be > 0",
|
| 98 |
+
));
|
| 99 |
+
}
|
| 100 |
+
// CPU reference for deterministic SP init.
|
| 101 |
+
let cpu_ref = HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed);
|
| 102 |
+
let sp_cfg: &SpatialPoolerConfig = &cpu_ref.sp.cfg;
|
| 103 |
+
let sp_gpu = SpatialPoolerGpu::from_cpu(&cpu_ref.sp).map_err(|e| {
|
| 104 |
+
pyo3::exceptions::PyRuntimeError::new_err(format!(
|
| 105 |
+
"GPU SP init failed: {e:?}. Config: input_bits={}, n_columns={}",
|
| 106 |
+
sp_cfg.input_bits, sp_cfg.n_columns,
|
| 107 |
+
))
|
| 108 |
+
})?;
|
| 109 |
+
let dev = sp_gpu.dev_ref().clone();
|
| 110 |
+
let tm_gpu = TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_column).map_err(|e| {
|
| 111 |
+
pyo3::exceptions::PyRuntimeError::new_err(format!(
|
| 112 |
+
"GPU TM init failed: {e:?}",
|
| 113 |
+
))
|
| 114 |
+
})?;
|
| 115 |
+
let initial_threshold = sp_gpu.initial_threshold_estimate();
|
| 116 |
+
let fused_state = FusedState::new(dev, n_columns, cells_per_column, initial_threshold)
|
| 117 |
+
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!(
|
| 118 |
+
"GPU fused state init failed: {e:?}",
|
| 119 |
+
)))?;
|
| 120 |
+
Ok(Self {
|
| 121 |
+
sp_gpu,
|
| 122 |
+
tm_gpu,
|
| 123 |
+
fused_state,
|
| 124 |
+
n_columns,
|
| 125 |
+
input_bits,
|
| 126 |
+
cells_per_column,
|
| 127 |
+
})
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
#[getter] fn input_bits(&self) -> usize { self.input_bits }
|
| 131 |
+
#[getter] fn n_columns(&self) -> usize { self.n_columns }
|
| 132 |
+
#[getter] fn cells_per_column(&self) -> usize { self.cells_per_column }
|
| 133 |
+
|
| 134 |
+
/// Process T timesteps in one call on GPU. Per-step state (SP + TM) stays
|
| 135 |
+
/// on device; only the final (T, n_cols) mask and (T,) anomaly are copied
|
| 136 |
+
/// to the host at the end.
|
| 137 |
+
#[pyo3(signature = (inputs, learn=true))]
|
| 138 |
+
fn step_many_gpu<'py>(
|
| 139 |
+
&mut self,
|
| 140 |
+
py: Python<'py>,
|
| 141 |
+
inputs: PyReadonlyArray2<'py, bool>,
|
| 142 |
+
learn: bool,
|
| 143 |
+
) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
|
| 144 |
+
let shape = inputs.shape();
|
| 145 |
+
if shape.len() != 2 {
|
| 146 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 147 |
+
"inputs must be 2-D (T, input_bits)",
|
| 148 |
+
));
|
| 149 |
+
}
|
| 150 |
+
let t = shape[0];
|
| 151 |
+
let bits = shape[1];
|
| 152 |
+
if bits != self.input_bits {
|
| 153 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 154 |
+
"inputs last dim {bits} != expected input_bits {}",
|
| 155 |
+
self.input_bits,
|
| 156 |
+
)));
|
| 157 |
+
}
|
| 158 |
+
let slice = inputs.as_slice()?;
|
| 159 |
+
let n_cols = self.n_columns;
|
| 160 |
+
let input_vec: Vec<bool> = slice.to_vec();
|
| 161 |
+
|
| 162 |
+
let result = py.allow_threads(|| -> Result<(Vec<u8>, Vec<f32>), String> {
|
| 163 |
+
// 1. Upload T*input_bits bytes (32 MB at T=2048, bits=16384).
|
| 164 |
+
let sdr_u8_all: Vec<u8> = input_vec.iter().map(|&b| b as u8).collect();
|
| 165 |
+
let inputs_dev = self
|
| 166 |
+
.sp_gpu
|
| 167 |
+
.dev_ref()
|
| 168 |
+
.htod_sync_copy(&sdr_u8_all)
|
| 169 |
+
.map_err(|e| format!("H2D inputs: {e:?}"))?;
|
| 170 |
+
|
| 171 |
+
// 2. Allocate output buffers on device.
|
| 172 |
+
let mut cols_dev = self.sp_gpu.dev_ref()
|
| 173 |
+
.alloc_zeros::<u8>(t * n_cols)
|
| 174 |
+
.map_err(|e| format!("alloc cols: {e:?}"))?;
|
| 175 |
+
let mut anom_dev = self.sp_gpu.dev_ref()
|
| 176 |
+
.alloc_zeros::<f32>(t)
|
| 177 |
+
.map_err(|e| format!("alloc anom: {e:?}"))?;
|
| 178 |
+
|
| 179 |
+
// 3. Run T steps of SP + TM on GPU with NO per-step host sync.
|
| 180 |
+
self.sp_gpu.step_batch_with_tm(
|
| 181 |
+
&inputs_dev,
|
| 182 |
+
t,
|
| 183 |
+
self.input_bits,
|
| 184 |
+
learn,
|
| 185 |
+
&mut cols_dev,
|
| 186 |
+
&mut anom_dev,
|
| 187 |
+
&mut self.tm_gpu,
|
| 188 |
+
).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
|
| 189 |
+
|
| 190 |
+
// 4. ONE D2H for the whole run (T * n_cols bytes + T floats).
|
| 191 |
+
let cols_host: Vec<u8> = self.sp_gpu.dev_ref()
|
| 192 |
+
.dtoh_sync_copy(&cols_dev)
|
| 193 |
+
.map_err(|e| format!("D2H cols: {e:?}"))?;
|
| 194 |
+
let anom_host: Vec<f32> = self.sp_gpu.dev_ref()
|
| 195 |
+
.dtoh_sync_copy(&anom_dev)
|
| 196 |
+
.map_err(|e| format!("D2H anom: {e:?}"))?;
|
| 197 |
+
|
| 198 |
+
Ok((cols_host, anom_host))
|
| 199 |
+
});
|
| 200 |
+
|
| 201 |
+
let (cols_u8, anom) = result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
|
| 202 |
+
|
| 203 |
+
let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
|
| 204 |
+
let cols_arr = numpy::PyArray1::from_vec_bound(py, cols_f32)
|
| 205 |
+
.reshape([t, n_cols])
|
| 206 |
+
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
|
| 207 |
+
let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
|
| 208 |
+
Ok((cols_arr, anom_arr))
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
/// Zero-copy CUDA path: accept torch tensors via __cuda_array_interface__,
|
| 212 |
+
/// write outputs directly into caller-allocated torch tensors. Skips the
|
| 213 |
+
/// host round-trip that `step_many_gpu` pays on every call (sdr.cpu() +
|
| 214 |
+
/// two D2H copies at the end). This is the hot path for `train.py`.
|
| 215 |
+
///
|
| 216 |
+
/// Contract:
|
| 217 |
+
/// sdr_cai.shape == (T, input_bits), dtype u8 (0/1 mask)
|
| 218 |
+
/// cols_cai.shape == (T, n_columns), dtype u8 (written)
|
| 219 |
+
/// anom_cai.shape == (T,), dtype f32 (written)
|
| 220 |
+
/// All three tensors must live on the SAME CUDA device as this region.
|
| 221 |
+
///
|
| 222 |
+
/// The torch tensors still own their memory — this method only wraps
|
| 223 |
+
/// them as borrowed CudaSlice views (via ManuallyDrop) so cudarc's Drop
|
| 224 |
+
/// impl can't free pytorch's allocator.
|
| 225 |
+
#[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
|
| 226 |
+
fn step_many_cuda(
|
| 227 |
+
&mut self,
|
| 228 |
+
py: Python<'_>,
|
| 229 |
+
sdr_cai: &Bound<'_, PyDict>,
|
| 230 |
+
cols_cai: &Bound<'_, PyDict>,
|
| 231 |
+
anom_cai: &Bound<'_, PyDict>,
|
| 232 |
+
learn: bool,
|
| 233 |
+
) -> PyResult<()> {
|
| 234 |
+
let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
|
| 235 |
+
let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
|
| 236 |
+
let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
|
| 237 |
+
|
| 238 |
+
// typestr sanity. numpy u1 is what torch.uint8 exports.
|
| 239 |
+
if sdr_type != "|u1" {
|
| 240 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 241 |
+
"sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
|
| 242 |
+
)));
|
| 243 |
+
}
|
| 244 |
+
if cols_type != "|u1" {
|
| 245 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 246 |
+
"cols_cai typestr must be '|u1' (uint8), got {cols_type}",
|
| 247 |
+
)));
|
| 248 |
+
}
|
| 249 |
+
if anom_type != "<f4" && anom_type != "=f4" {
|
| 250 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 251 |
+
"anom_cai typestr must be '<f4' (float32), got {anom_type}",
|
| 252 |
+
)));
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
// Shape validation.
|
| 256 |
+
if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
|
| 257 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 258 |
+
"sdr_cai shape {sdr_shape:?} != (T, {})",
|
| 259 |
+
self.input_bits,
|
| 260 |
+
)));
|
| 261 |
+
}
|
| 262 |
+
let t = sdr_shape[0];
|
| 263 |
+
if cols_shape != [t, self.n_columns] {
|
| 264 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 265 |
+
"cols_cai shape {cols_shape:?} != ({t}, {})",
|
| 266 |
+
self.n_columns,
|
| 267 |
+
)));
|
| 268 |
+
}
|
| 269 |
+
if anom_shape != [t] {
|
| 270 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 271 |
+
"anom_cai shape {anom_shape:?} != ({t},)",
|
| 272 |
+
)));
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
let dev = self.sp_gpu.dev_ref().clone();
|
| 276 |
+
let n_cols = self.n_columns;
|
| 277 |
+
let input_bits = self.input_bits;
|
| 278 |
+
|
| 279 |
+
let result = py.allow_threads(|| -> Result<(), String> {
|
| 280 |
+
// SAFETY:
|
| 281 |
+
// - ptrs came from torch CUDA tensors validated non-null by the
|
| 282 |
+
// __cuda_array_interface__ contract.
|
| 283 |
+
// - lens computed from validated shapes.
|
| 284 |
+
// - We wrap the returned CudaSlice in ManuallyDrop so cudarc's
|
| 285 |
+
// Drop (which calls cuMemFree) never runs against torch memory.
|
| 286 |
+
// The underlying allocation is owned+freed by torch.
|
| 287 |
+
// - The slices are used only for the duration of this call;
|
| 288 |
+
// torch guarantees the backing tensors are live across it
|
| 289 |
+
// (Python holds refs on the wrapping tensors).
|
| 290 |
+
let inputs_dev = ManuallyDrop::new(unsafe {
|
| 291 |
+
dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
|
| 292 |
+
});
|
| 293 |
+
let mut cols_dev = ManuallyDrop::new(unsafe {
|
| 294 |
+
dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
|
| 295 |
+
});
|
| 296 |
+
let mut anom_dev = ManuallyDrop::new(unsafe {
|
| 297 |
+
dev.upgrade_device_ptr::<f32>(anom_ptr, t)
|
| 298 |
+
});
|
| 299 |
+
|
| 300 |
+
self.sp_gpu.step_batch_with_tm(
|
| 301 |
+
&inputs_dev,
|
| 302 |
+
t,
|
| 303 |
+
input_bits,
|
| 304 |
+
learn,
|
| 305 |
+
&mut cols_dev,
|
| 306 |
+
&mut anom_dev,
|
| 307 |
+
&mut self.tm_gpu,
|
| 308 |
+
).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
|
| 309 |
+
|
| 310 |
+
// Synchronize: kernel writes must be visible to the next torch
|
| 311 |
+
// op that reads cols/anom. Pytorch's default stream is stream 0,
|
| 312 |
+
// and cudarc launches on its own stream — a full device sync
|
| 313 |
+
// is the simplest correct barrier. (Could narrow to a stream
|
| 314 |
+
// wait event in PR 2.)
|
| 315 |
+
// No dev.synchronize() here: caller must explicitly sync via the
|
| 316 |
+
// `device_sync()` method (or PyTorch auto-syncs when the output
|
| 317 |
+
// tensor is next consumed). Removing the per-launch barrier lets
|
| 318 |
+
// subsequent GPU work (mamba3 fwd, etc.) overlap in time.
|
| 319 |
+
Ok(())
|
| 320 |
+
});
|
| 321 |
+
|
| 322 |
+
result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
|
| 323 |
+
Ok(())
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
/// Clear TM state on the GPU.
|
| 327 |
+
fn reset(&mut self) -> PyResult<()> {
|
| 328 |
+
self.tm_gpu.reset().map_err(|e| {
|
| 329 |
+
pyo3::exceptions::PyRuntimeError::new_err(format!("GPU TM reset: {e:?}"))
|
| 330 |
+
})?;
|
| 331 |
+
self.fused_state.reset().map_err(|e| {
|
| 332 |
+
pyo3::exceptions::PyRuntimeError::new_err(format!("GPU fused reset: {e:?}"))
|
| 333 |
+
})
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
/// FUSED MEGAKERNEL PATH: single CUDA launch for the entire T-step
|
| 337 |
+
/// forward (SP + TM all in one). Accepts torch CUDA tensors via
|
| 338 |
+
/// `__cuda_array_interface__` (zero-copy). Writes active-column mask +
|
| 339 |
+
/// anomaly directly into caller-allocated torch tensors.
|
| 340 |
+
///
|
| 341 |
+
/// Semantics diverge from `step_many_cuda` in one important way: column
|
| 342 |
+
/// activation uses per-column threshold inhibition instead of global
|
| 343 |
+
/// top-K. The threshold is EMA-adapted per column toward the sparsity
|
| 344 |
+
/// target. See `docs/GPU_HTM.md` §Fused Kernel.
|
| 345 |
+
#[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
|
| 346 |
+
fn step_many_fused_cuda(
|
| 347 |
+
&mut self,
|
| 348 |
+
py: Python<'_>,
|
| 349 |
+
sdr_cai: &Bound<'_, PyDict>,
|
| 350 |
+
cols_cai: &Bound<'_, PyDict>,
|
| 351 |
+
anom_cai: &Bound<'_, PyDict>,
|
| 352 |
+
learn: bool,
|
| 353 |
+
) -> PyResult<()> {
|
| 354 |
+
let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
|
| 355 |
+
let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
|
| 356 |
+
let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
|
| 357 |
+
|
| 358 |
+
if sdr_type != "|u1" {
|
| 359 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 360 |
+
"sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
|
| 361 |
+
)));
|
| 362 |
+
}
|
| 363 |
+
if cols_type != "|u1" {
|
| 364 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 365 |
+
"cols_cai typestr must be '|u1' (uint8), got {cols_type}",
|
| 366 |
+
)));
|
| 367 |
+
}
|
| 368 |
+
if anom_type != "<f4" && anom_type != "=f4" {
|
| 369 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 370 |
+
"anom_cai typestr must be '<f4' (float32), got {anom_type}",
|
| 371 |
+
)));
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
|
| 375 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 376 |
+
"sdr_cai shape {sdr_shape:?} != (T, {})",
|
| 377 |
+
self.input_bits,
|
| 378 |
+
)));
|
| 379 |
+
}
|
| 380 |
+
let t = sdr_shape[0];
|
| 381 |
+
if cols_shape != [t, self.n_columns] {
|
| 382 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 383 |
+
"cols_cai shape {cols_shape:?} != ({t}, {})",
|
| 384 |
+
self.n_columns,
|
| 385 |
+
)));
|
| 386 |
+
}
|
| 387 |
+
if anom_shape != [t] {
|
| 388 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 389 |
+
"anom_cai shape {anom_shape:?} != ({t},)",
|
| 390 |
+
)));
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
let dev = self.sp_gpu.dev_ref().clone();
|
| 394 |
+
let n_cols = self.n_columns;
|
| 395 |
+
let input_bits = self.input_bits;
|
| 396 |
+
|
| 397 |
+
let result = py.allow_threads(|| -> Result<(), String> {
|
| 398 |
+
let inputs_dev = ManuallyDrop::new(unsafe {
|
| 399 |
+
dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
|
| 400 |
+
});
|
| 401 |
+
let mut cols_dev = ManuallyDrop::new(unsafe {
|
| 402 |
+
dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
|
| 403 |
+
});
|
| 404 |
+
let mut anom_dev = ManuallyDrop::new(unsafe {
|
| 405 |
+
dev.upgrade_device_ptr::<f32>(anom_ptr, t)
|
| 406 |
+
});
|
| 407 |
+
|
| 408 |
+
fused::launch_fused(
|
| 409 |
+
&mut self.sp_gpu,
|
| 410 |
+
&mut self.tm_gpu,
|
| 411 |
+
&mut self.fused_state,
|
| 412 |
+
&inputs_dev,
|
| 413 |
+
&mut cols_dev,
|
| 414 |
+
&mut anom_dev,
|
| 415 |
+
t,
|
| 416 |
+
input_bits,
|
| 417 |
+
learn,
|
| 418 |
+
).map_err(|e| format!("launch_fused: {e:?}"))?;
|
| 419 |
+
|
| 420 |
+
// No dev.synchronize() here: caller must explicitly sync via the
|
| 421 |
+
// `device_sync()` method (or PyTorch auto-syncs when the output
|
| 422 |
+
// tensor is next consumed). Removing the per-launch barrier lets
|
| 423 |
+
// subsequent GPU work (mamba3 fwd, etc.) overlap in time.
|
| 424 |
+
Ok(())
|
| 425 |
+
});
|
| 426 |
+
|
| 427 |
+
result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
|
| 428 |
+
Ok(())
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
/// Explicit device synchronization — the caller must invoke this after
|
| 432 |
+
/// all batched `step_many_*_cuda` calls complete, before reading the
|
| 433 |
+
/// output tensors from a different CUDA stream. Equivalent to the old
|
| 434 |
+
/// per-call `dev.synchronize()` that was removed for overlap.
|
| 435 |
+
fn device_sync(&self) -> PyResult<()> {
|
| 436 |
+
let dev = self.sp_gpu.dev_ref();
|
| 437 |
+
dev.synchronize()
|
| 438 |
+
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("sync: {e:?}")))?;
|
| 439 |
+
Ok(())
|
| 440 |
+
}
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
/// Batch B regions into ONE cooperative kernel launch. Breaks through the
|
| 444 |
+
/// CUDA cooperative-kernel device-level serialization: a single cooperative
|
| 445 |
+
/// launch with grid.y=B processes all regions concurrently — ~B× speedup
|
| 446 |
+
/// over B sequential launches.
|
| 447 |
+
///
|
| 448 |
+
/// All regions must have the same config (input_bits, n_columns,
|
| 449 |
+
/// cells_per_column). Each region keeps its independent GPU state.
|
| 450 |
+
/// Does NOT sync; caller must invoke `device_sync()` on any region
|
| 451 |
+
/// afterwards (or rely on a downstream torch op to auto-sync).
|
| 452 |
+
#[pyfunction]
|
| 453 |
+
#[pyo3(signature = (regions, sdr_cais, cols_cais, anom_cais, learn=true))]
|
| 454 |
+
fn step_batch_fused_cuda(
|
| 455 |
+
py: Python<'_>,
|
| 456 |
+
regions: Vec<Py<HTMRegionGpu>>,
|
| 457 |
+
sdr_cais: Vec<Bound<'_, PyDict>>,
|
| 458 |
+
cols_cais: Vec<Bound<'_, PyDict>>,
|
| 459 |
+
anom_cais: Vec<Bound<'_, PyDict>>,
|
| 460 |
+
learn: bool,
|
| 461 |
+
) -> PyResult<()> {
|
| 462 |
+
let b = regions.len();
|
| 463 |
+
if b == 0 {
|
| 464 |
+
return Err(pyo3::exceptions::PyValueError::new_err("regions is empty"));
|
| 465 |
+
}
|
| 466 |
+
if sdr_cais.len() != b || cols_cais.len() != b || anom_cais.len() != b {
|
| 467 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 468 |
+
"sdr_cais / cols_cais / anom_cais length must match regions",
|
| 469 |
+
));
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
// Parse all CAI dicts; collect device pointers. Validate shapes/dtypes.
|
| 473 |
+
let mut sdr_ptrs = Vec::with_capacity(b);
|
| 474 |
+
let mut cols_ptrs = Vec::with_capacity(b);
|
| 475 |
+
let mut anom_ptrs = Vec::with_capacity(b);
|
| 476 |
+
let (input_bits, n_columns, t) = {
|
| 477 |
+
let r0 = regions[0].bind(py).borrow();
|
| 478 |
+
(r0.input_bits, r0.n_columns, {
|
| 479 |
+
let (_p, sh, _ty) = cai_parse(&sdr_cais[0])?;
|
| 480 |
+
if sh.len() != 2 {
|
| 481 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 482 |
+
format!("sdr_cai must be 2-D (T, input_bits), got {sh:?}"),
|
| 483 |
+
));
|
| 484 |
+
}
|
| 485 |
+
sh[0]
|
| 486 |
+
})
|
| 487 |
+
};
|
| 488 |
+
|
| 489 |
+
for i in 0..b {
|
| 490 |
+
let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_cais[i])?;
|
| 491 |
+
let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_cais[i])?;
|
| 492 |
+
let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_cais[i])?;
|
| 493 |
+
if sdr_type != "|u1" || cols_type != "|u1" {
|
| 494 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 495 |
+
"sdr/cols typestr must be '|u1' (uint8)",
|
| 496 |
+
));
|
| 497 |
+
}
|
| 498 |
+
if anom_type != "<f4" && anom_type != "=f4" {
|
| 499 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 500 |
+
"anom typestr must be '<f4' (float32)",
|
| 501 |
+
));
|
| 502 |
+
}
|
| 503 |
+
if sdr_shape != [t, input_bits] {
|
| 504 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 505 |
+
"sdr[{i}] shape {sdr_shape:?} != ({t}, {input_bits})"
|
| 506 |
+
)));
|
| 507 |
+
}
|
| 508 |
+
if cols_shape != [t, n_columns] {
|
| 509 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 510 |
+
"cols[{i}] shape {cols_shape:?} != ({t}, {n_columns})"
|
| 511 |
+
)));
|
| 512 |
+
}
|
| 513 |
+
if anom_shape != [t] {
|
| 514 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 515 |
+
"anom[{i}] shape {anom_shape:?} != ({t},)"
|
| 516 |
+
)));
|
| 517 |
+
}
|
| 518 |
+
sdr_ptrs.push(sdr_ptr);
|
| 519 |
+
cols_ptrs.push(cols_ptr);
|
| 520 |
+
anom_ptrs.push(anom_ptr);
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
// Exclusively borrow each region. PyRefMut guarantees uniqueness.
|
| 524 |
+
let mut region_refs: Vec<pyo3::PyRefMut<HTMRegionGpu>> =
|
| 525 |
+
regions.iter().map(|p| p.bind(py).borrow_mut()).collect();
|
| 526 |
+
// Collect raw mutable pointers — each PyRefMut exclusively borrows its
|
| 527 |
+
// region for the lifetime of this call, so pointers stay valid and
|
| 528 |
+
// unique. launch_fused_batched_raw only dereferences one region at a
|
| 529 |
+
// time, not constructing an aliased slice.
|
| 530 |
+
let raw_ptrs: Vec<*mut HTMRegionGpu> = region_refs
|
| 531 |
+
.iter_mut()
|
| 532 |
+
.map(|r| &mut **r as *mut HTMRegionGpu)
|
| 533 |
+
.collect();
|
| 534 |
+
|
| 535 |
+
// No allow_threads: raw pointers aren't Send. The launch is GPU-queued
|
| 536 |
+
// and sync'd downstream; holding the GIL for the duration is cheap.
|
| 537 |
+
fused::launch_fused_batched_raw(
|
| 538 |
+
&raw_ptrs, &sdr_ptrs, &cols_ptrs, &anom_ptrs,
|
| 539 |
+
t, input_bits, learn,
|
| 540 |
+
)
|
| 541 |
+
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("launch_fused_batched: {e:?}")))?;
|
| 542 |
+
Ok(())
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
| 546 |
+
m.add_class::<HTMRegionGpu>()?;
|
| 547 |
+
m.add_function(pyo3::wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
|
| 548 |
+
Ok(())
|
| 549 |
+
}
|
overlay/htm_rust/src/gpu/sp_gpu.rs
ADDED
|
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! GPU implementation of the Spatial Pooler.
|
| 2 |
+
//!
|
| 3 |
+
//! One `SpatialPoolerGpu` owns a set of persistent device buffers + 4 PTX
|
| 4 |
+
//! kernels. `compute(input, learn)` performs one SP step and returns the
|
| 5 |
+
//! sorted active-column indices (host `Vec<u32>`) — this is what the CPU
|
| 6 |
+
//! TemporalMemory consumes.
|
| 7 |
+
//!
|
| 8 |
+
//! Persistent state on device (per region):
|
| 9 |
+
//! syn_bit : u32 [n_columns × S] (constant after init)
|
| 10 |
+
//! syn_perm : f32 [n_columns × S] (updated by sp_learn)
|
| 11 |
+
//! boost : f32 [n_columns]
|
| 12 |
+
//! active_duty : f32 [n_columns]
|
| 13 |
+
//! overlap_duty: f32 [n_columns]
|
| 14 |
+
//!
|
| 15 |
+
//! Per-step transient state:
|
| 16 |
+
//! inp_dev : u8 [input_bits] (H2D copy each step)
|
| 17 |
+
//! raw : u32 [n_columns]
|
| 18 |
+
//! boosted : f32 [n_columns]
|
| 19 |
+
//! active_mask : u8 [n_columns] (topk output, D2H at the end)
|
| 20 |
+
|
| 21 |
+
use std::sync::Arc;
|
| 22 |
+
|
| 23 |
+
use cudarc::driver::{CudaDevice, CudaSlice, DeviceSlice, DriverError, LaunchAsync, LaunchConfig};
|
| 24 |
+
use cudarc::nvrtc::Ptx;
|
| 25 |
+
|
| 26 |
+
use crate::sp::SpatialPooler;
|
| 27 |
+
|
| 28 |
+
// Embed PTX at compile time. OUT_DIR is set by build.rs.
|
| 29 |
+
const PTX_SP_OVERLAP: &str =
|
| 30 |
+
include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_overlap.ptx"));
|
| 31 |
+
const PTX_SP_TOPK: &str =
|
| 32 |
+
include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_topk.ptx"));
|
| 33 |
+
const PTX_SP_LEARN: &str =
|
| 34 |
+
include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_learn.ptx"));
|
| 35 |
+
const PTX_SP_DUTY: &str =
|
| 36 |
+
include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_duty.ptx"));
|
| 37 |
+
const PTX_SP_BOOST_FUSED: &str =
|
| 38 |
+
include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_boost_fused.ptx"));
|
| 39 |
+
|
| 40 |
+
pub struct SpatialPoolerGpu {
|
| 41 |
+
dev: Arc<CudaDevice>,
|
| 42 |
+
|
| 43 |
+
// Config mirror (we don't touch CPU SpatialPooler after init).
|
| 44 |
+
input_bits: usize,
|
| 45 |
+
n_columns: usize,
|
| 46 |
+
synapses_per_col: usize,
|
| 47 |
+
conn_thr: f32,
|
| 48 |
+
inc: f32,
|
| 49 |
+
dec: f32,
|
| 50 |
+
sparsity: f32,
|
| 51 |
+
duty_period: f32,
|
| 52 |
+
boost_strength: f32,
|
| 53 |
+
|
| 54 |
+
// Persistent device state.
|
| 55 |
+
syn_bit: CudaSlice<u32>,
|
| 56 |
+
syn_perm: CudaSlice<f32>,
|
| 57 |
+
boost: CudaSlice<f32>,
|
| 58 |
+
active_duty: CudaSlice<f32>,
|
| 59 |
+
overlap_duty: CudaSlice<f32>,
|
| 60 |
+
|
| 61 |
+
// Transient scratch (reused each step).
|
| 62 |
+
inp_dev: CudaSlice<u8>,
|
| 63 |
+
raw: CudaSlice<u32>,
|
| 64 |
+
boosted: CudaSlice<f32>,
|
| 65 |
+
active_mask: CudaSlice<u8>,
|
| 66 |
+
|
| 67 |
+
// Reusable host buffer for D2H of active_mask.
|
| 68 |
+
host_mask: Vec<u8>,
|
| 69 |
+
|
| 70 |
+
/// Strict bit-parity with CPU reference. Enabled for tests.
|
| 71 |
+
/// Forces host-side boost/exp computation and the overlap-duty bump check
|
| 72 |
+
/// every step. Default false for max throughput.
|
| 73 |
+
strict_parity: bool,
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
impl SpatialPoolerGpu {
|
| 77 |
+
/// Copy CPU SpatialPooler state onto the device. This preserves the
|
| 78 |
+
/// exact seeded proximal synapse layout + initial permanences, so the
|
| 79 |
+
/// GPU SP is a bit-identical parallel implementation of the CPU SP.
|
| 80 |
+
pub fn from_cpu(cpu: &SpatialPooler) -> Result<Self, DriverError> {
|
| 81 |
+
let dev = CudaDevice::new(0)?;
|
| 82 |
+
let cfg = &cpu.cfg;
|
| 83 |
+
let n = cfg.n_columns;
|
| 84 |
+
let s = cfg.potential_synapses;
|
| 85 |
+
|
| 86 |
+
// Flatten proximal dendrites into column-major arrays.
|
| 87 |
+
let mut syn_bit_h: Vec<u32> = Vec::with_capacity(n * s);
|
| 88 |
+
let mut syn_perm_h: Vec<f32> = Vec::with_capacity(n * s);
|
| 89 |
+
for col in &cpu.columns {
|
| 90 |
+
debug_assert_eq!(col.inputs.len(), s);
|
| 91 |
+
debug_assert_eq!(col.perms.len(), s);
|
| 92 |
+
syn_bit_h.extend_from_slice(&col.inputs);
|
| 93 |
+
syn_perm_h.extend_from_slice(&col.perms);
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
let syn_bit = dev.htod_sync_copy(&syn_bit_h)?;
|
| 97 |
+
let syn_perm = dev.htod_sync_copy(&syn_perm_h)?;
|
| 98 |
+
let boost = dev.htod_sync_copy(&cpu.boost)?;
|
| 99 |
+
let active_duty = dev.htod_sync_copy(&cpu.active_duty_cycle)?;
|
| 100 |
+
let overlap_duty = dev.htod_sync_copy(&cpu.overlap_duty_cycle)?;
|
| 101 |
+
|
| 102 |
+
let inp_dev: CudaSlice<u8> = dev.alloc_zeros(cfg.input_bits)?;
|
| 103 |
+
let raw: CudaSlice<u32> = dev.alloc_zeros(n)?;
|
| 104 |
+
let boosted: CudaSlice<f32> = dev.alloc_zeros(n)?;
|
| 105 |
+
let active_mask: CudaSlice<u8> = dev.alloc_zeros(n)?;
|
| 106 |
+
|
| 107 |
+
// Load PTX modules. Each .ptx is a module containing one `extern "C"`
|
| 108 |
+
// function; we tag them by unique module names so multiple SP instances
|
| 109 |
+
// don't collide (cudarc uses the (module, func) pair).
|
| 110 |
+
// Actually: CudaDevice::load_ptx stores under the given module name
|
| 111 |
+
// globally on the device, so we use a deterministic naming scheme.
|
| 112 |
+
let modules = [
|
| 113 |
+
("htm_sp_overlap", PTX_SP_OVERLAP, "sp_overlap"),
|
| 114 |
+
("htm_sp_topk", PTX_SP_TOPK, "sp_topk_select"),
|
| 115 |
+
("htm_sp_learn", PTX_SP_LEARN, "sp_learn"),
|
| 116 |
+
("htm_sp_duty", PTX_SP_DUTY, "sp_duty_update"),
|
| 117 |
+
("htm_sp_boost_fused", PTX_SP_BOOST_FUSED, "sp_boost_from_duty"),
|
| 118 |
+
];
|
| 119 |
+
for (modname, ptx, fnname) in modules {
|
| 120 |
+
// load_ptx is NOT idempotent — calling twice errors. For multi-region
|
| 121 |
+
// support we check-then-load.
|
| 122 |
+
if dev.get_func(modname, fnname).is_none() {
|
| 123 |
+
dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
Ok(Self {
|
| 128 |
+
dev,
|
| 129 |
+
input_bits: cfg.input_bits,
|
| 130 |
+
n_columns: n,
|
| 131 |
+
synapses_per_col: s,
|
| 132 |
+
conn_thr: cfg.connected_threshold,
|
| 133 |
+
inc: cfg.syn_perm_active_inc,
|
| 134 |
+
dec: cfg.syn_perm_inactive_dec,
|
| 135 |
+
sparsity: cfg.sparsity,
|
| 136 |
+
duty_period: cfg.duty_cycle_period,
|
| 137 |
+
boost_strength: cfg.boost_strength,
|
| 138 |
+
syn_bit,
|
| 139 |
+
syn_perm,
|
| 140 |
+
boost,
|
| 141 |
+
active_duty,
|
| 142 |
+
overlap_duty,
|
| 143 |
+
inp_dev,
|
| 144 |
+
raw,
|
| 145 |
+
boosted,
|
| 146 |
+
active_mask,
|
| 147 |
+
host_mask: vec![0u8; n],
|
| 148 |
+
strict_parity: false,
|
| 149 |
+
})
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
/// Enable strict bit-parity mode. Parity tests use this.
|
| 153 |
+
pub fn set_strict_parity(&mut self, strict: bool) {
|
| 154 |
+
self.strict_parity = strict;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
/// Access to the underlying CudaDevice for host-side orchestration.
|
| 158 |
+
pub fn dev_ref(&self) -> &Arc<CudaDevice> {
|
| 159 |
+
&self.dev
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// --- Fused-path accessors (immutable state reads + pointer-grabs). ---
|
| 163 |
+
pub fn n_columns_accessor(&self) -> usize { self.n_columns }
|
| 164 |
+
#[allow(dead_code)]
|
| 165 |
+
pub fn input_bits_accessor(&self) -> usize { self.input_bits }
|
| 166 |
+
pub fn synapses_per_col_accessor(&self) -> usize { self.synapses_per_col }
|
| 167 |
+
pub fn conn_thr_accessor(&self) -> f32 { self.conn_thr }
|
| 168 |
+
pub fn inc_accessor(&self) -> f32 { self.inc }
|
| 169 |
+
pub fn dec_accessor(&self) -> f32 { self.dec }
|
| 170 |
+
pub fn sparsity_accessor(&self) -> f32 { self.sparsity }
|
| 171 |
+
pub fn duty_period_accessor(&self) -> f32 { self.duty_period }
|
| 172 |
+
#[allow(dead_code)]
|
| 173 |
+
pub fn boost_strength_accessor(&self) -> f32 { self.boost_strength }
|
| 174 |
+
|
| 175 |
+
pub fn syn_bit_accessor(&self) -> &CudaSlice<u32> { &self.syn_bit }
|
| 176 |
+
pub fn syn_perm_accessor(&self) -> &CudaSlice<f32> { &self.syn_perm }
|
| 177 |
+
pub fn boost_accessor(&self) -> &CudaSlice<f32> { &self.boost }
|
| 178 |
+
pub fn active_duty_accessor(&self) -> &CudaSlice<f32> { &self.active_duty }
|
| 179 |
+
|
| 180 |
+
/// Compute the 95th-percentile-like initial threshold from raw overlaps
|
| 181 |
+
/// after a short warmup pass. Used to seed `inhibition_threshold` such
|
| 182 |
+
/// that activation rate starts near the sparsity target.
|
| 183 |
+
/// Placeholder (returns a conservative constant); real warmup pass
|
| 184 |
+
/// happens on the Rust orchestrator side.
|
| 185 |
+
pub fn initial_threshold_estimate(&self) -> f32 {
|
| 186 |
+
// With conn_thr=0.5, init_perm around 0.5±0.1, S=40, sparse SDR at 2%:
|
| 187 |
+
// expected overlap ~ 40 * 0.02 = 0.8 connected hits → boosted ~ 0.8.
|
| 188 |
+
// Top-K selects top 2%, so threshold for top 2% is roughly the
|
| 189 |
+
// 98th-percentile of boosted. Conservative start: 2.0.
|
| 190 |
+
// The per-column adaptation will quickly steer each column's thr.
|
| 191 |
+
2.0f32
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
/// Batched multi-step SP on the GPU. Processes T timesteps from a
|
| 195 |
+
/// pre-uploaded device input buffer. Emits `(T, n_cols)` u8 active-column
|
| 196 |
+
/// mask to `cols_dev_out` and `(T,)` active column index list (in a
|
| 197 |
+
/// per-step window of size k, padded with u32::MAX).
|
| 198 |
+
///
|
| 199 |
+
/// For each step, this runs the same 5-kernel pipeline as `compute`, but
|
| 200 |
+
/// skips the per-step boost/duty D2H→exp→H2D round-trip: instead it
|
| 201 |
+
/// accumulates to a host scratch once every `boost_interval` steps.
|
| 202 |
+
///
|
| 203 |
+
/// This is the fast path used by `HTMRegionGpu.step_many_gpu`.
|
| 204 |
+
#[allow(clippy::too_many_arguments)]
|
| 205 |
+
pub fn step_batch(
|
| 206 |
+
&mut self,
|
| 207 |
+
inputs_flat_dev: &CudaSlice<u8>,
|
| 208 |
+
t: usize,
|
| 209 |
+
input_bits: usize,
|
| 210 |
+
learn: bool,
|
| 211 |
+
cols_out: &mut [u8],
|
| 212 |
+
active_indices_host: &mut Vec<u32>,
|
| 213 |
+
) -> Result<(), DriverError> {
|
| 214 |
+
let n = self.n_columns;
|
| 215 |
+
let k = ((self.sparsity * n as f32).round() as usize).max(1);
|
| 216 |
+
debug_assert_eq!(cols_out.len(), t * n);
|
| 217 |
+
|
| 218 |
+
let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
|
| 219 |
+
let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
|
| 220 |
+
let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
|
| 221 |
+
let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
|
| 222 |
+
|
| 223 |
+
let overlap_cfg = LaunchConfig {
|
| 224 |
+
grid_dim: (n as u32, 1, 1),
|
| 225 |
+
block_dim: (128, 1, 1),
|
| 226 |
+
shared_mem_bytes: 0,
|
| 227 |
+
};
|
| 228 |
+
let topk_cfg = LaunchConfig {
|
| 229 |
+
grid_dim: (1, 1, 1),
|
| 230 |
+
block_dim: (256, 1, 1),
|
| 231 |
+
shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
|
| 232 |
+
};
|
| 233 |
+
let learn_cfg = overlap_cfg;
|
| 234 |
+
let duty_cfg = LaunchConfig {
|
| 235 |
+
grid_dim: ((n as u32 + 255) / 256, 1, 1),
|
| 236 |
+
block_dim: (256, 1, 1),
|
| 237 |
+
shared_mem_bytes: 0,
|
| 238 |
+
};
|
| 239 |
+
let alpha = 1.0f32 / self.duty_period.max(1.0);
|
| 240 |
+
|
| 241 |
+
// Reusable host buffer for the per-step active_mask D2H.
|
| 242 |
+
self.host_mask.resize(n, 0);
|
| 243 |
+
|
| 244 |
+
active_indices_host.clear();
|
| 245 |
+
|
| 246 |
+
for ti in 0..t {
|
| 247 |
+
// Point overlap kernel at the ti-th slice of the pre-uploaded input.
|
| 248 |
+
// cudarc CudaSlice doesn't have a "view" per se, so we must copy the
|
| 249 |
+
// slice into the reusable inp_dev buffer. This is a D2D copy — much
|
| 250 |
+
// faster than H2D.
|
| 251 |
+
// (Alternative: rewrite kernel to accept an offset; deferred.)
|
| 252 |
+
let in_off = ti * input_bits;
|
| 253 |
+
// Use dtod_copy via raw slice indexing: cudarc exposes slice() for this.
|
| 254 |
+
let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
|
| 255 |
+
self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
|
| 256 |
+
|
| 257 |
+
// 1. sp_overlap
|
| 258 |
+
unsafe {
|
| 259 |
+
overlap_fn.clone().launch(
|
| 260 |
+
overlap_cfg,
|
| 261 |
+
(
|
| 262 |
+
&self.inp_dev,
|
| 263 |
+
&self.syn_bit,
|
| 264 |
+
&self.syn_perm,
|
| 265 |
+
&self.boost,
|
| 266 |
+
self.conn_thr,
|
| 267 |
+
self.synapses_per_col as u32,
|
| 268 |
+
n as u32,
|
| 269 |
+
&mut self.raw,
|
| 270 |
+
&mut self.boosted,
|
| 271 |
+
),
|
| 272 |
+
)?;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
// 2. Clear active_mask, then sp_topk
|
| 276 |
+
self.dev.memset_zeros(&mut self.active_mask)?;
|
| 277 |
+
unsafe {
|
| 278 |
+
topk_fn.clone().launch(
|
| 279 |
+
topk_cfg,
|
| 280 |
+
(&self.boosted, n as u32, k as u32, &mut self.active_mask),
|
| 281 |
+
)?;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// 3. sp_learn
|
| 285 |
+
if learn {
|
| 286 |
+
unsafe {
|
| 287 |
+
learn_fn.clone().launch(
|
| 288 |
+
learn_cfg,
|
| 289 |
+
(
|
| 290 |
+
&self.active_mask,
|
| 291 |
+
&self.inp_dev,
|
| 292 |
+
&self.syn_bit,
|
| 293 |
+
&mut self.syn_perm,
|
| 294 |
+
self.inc,
|
| 295 |
+
self.dec,
|
| 296 |
+
self.synapses_per_col as u32,
|
| 297 |
+
n as u32,
|
| 298 |
+
),
|
| 299 |
+
)?;
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
// 4. duty update (device)
|
| 304 |
+
unsafe {
|
| 305 |
+
duty_fn.clone().launch(
|
| 306 |
+
duty_cfg,
|
| 307 |
+
(
|
| 308 |
+
&self.active_mask,
|
| 309 |
+
&self.raw,
|
| 310 |
+
&mut self.active_duty,
|
| 311 |
+
&mut self.overlap_duty,
|
| 312 |
+
&mut self.boost,
|
| 313 |
+
alpha,
|
| 314 |
+
1.0f32,
|
| 315 |
+
0.0f32,
|
| 316 |
+
0.0f32,
|
| 317 |
+
0u32,
|
| 318 |
+
n as u32,
|
| 319 |
+
),
|
| 320 |
+
)?;
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
// 5. Boost update. Two modes:
|
| 324 |
+
// * strict_parity (tests): host-side exp for bit-exact match.
|
| 325 |
+
// * default (production): GPU expf is close enough and ~10x faster
|
| 326 |
+
// since we skip the D2H/H2D round-trip.
|
| 327 |
+
if learn && self.boost_strength > 0.0 {
|
| 328 |
+
if self.strict_parity {
|
| 329 |
+
let mut duty_host = vec![0f32; n];
|
| 330 |
+
self.dev
|
| 331 |
+
.dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
|
| 332 |
+
let sum: f32 = duty_host.iter().sum();
|
| 333 |
+
let mean = sum / (n as f32);
|
| 334 |
+
let mut boost_host = vec![0f32; n];
|
| 335 |
+
for i in 0..n {
|
| 336 |
+
boost_host[i] =
|
| 337 |
+
(-self.boost_strength * (duty_host[i] - mean)).exp();
|
| 338 |
+
}
|
| 339 |
+
self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
|
| 340 |
+
|
| 341 |
+
// Permanence bump (rare). Only evaluated in strict mode.
|
| 342 |
+
let mut ov_host = vec![0f32; n];
|
| 343 |
+
self.dev
|
| 344 |
+
.dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
|
| 345 |
+
let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
|
| 346 |
+
if max_ov > 0.0 {
|
| 347 |
+
let thr = 0.001f32 * max_ov;
|
| 348 |
+
let bump = self.inc * 0.1f32;
|
| 349 |
+
let bump_cols: Vec<u32> = ov_host
|
| 350 |
+
.iter()
|
| 351 |
+
.enumerate()
|
| 352 |
+
.filter_map(|(i, &o)| {
|
| 353 |
+
if o < thr { Some(i as u32) } else { None }
|
| 354 |
+
})
|
| 355 |
+
.collect();
|
| 356 |
+
if !bump_cols.is_empty() {
|
| 357 |
+
let s = self.synapses_per_col;
|
| 358 |
+
let mut perm_host = vec![0f32; n * s];
|
| 359 |
+
self.dev
|
| 360 |
+
.dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
|
| 361 |
+
for &c in &bump_cols {
|
| 362 |
+
let base = (c as usize) * s;
|
| 363 |
+
for p in &mut perm_host[base..base + s] {
|
| 364 |
+
*p = (*p + bump).min(1.0);
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
|
| 368 |
+
}
|
| 369 |
+
}
|
| 370 |
+
} else {
|
| 371 |
+
// Fast path: fused mean + boost = expf(-strength*(ad-mean))
|
| 372 |
+
// in a single GPU block. Zero D2H, zero H2D — fully async.
|
| 373 |
+
let boost_fn = self
|
| 374 |
+
.dev
|
| 375 |
+
.get_func("htm_sp_boost_fused", "sp_boost_from_duty")
|
| 376 |
+
.expect("sp_boost_fused not loaded");
|
| 377 |
+
let boost_cfg = LaunchConfig {
|
| 378 |
+
grid_dim: (1, 1, 1),
|
| 379 |
+
block_dim: (1024, 1, 1),
|
| 380 |
+
shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
|
| 381 |
+
};
|
| 382 |
+
unsafe {
|
| 383 |
+
boost_fn.launch(
|
| 384 |
+
boost_cfg,
|
| 385 |
+
(
|
| 386 |
+
&self.active_duty,
|
| 387 |
+
&mut self.boost,
|
| 388 |
+
self.boost_strength,
|
| 389 |
+
n as u32,
|
| 390 |
+
),
|
| 391 |
+
)?;
|
| 392 |
+
}
|
| 393 |
+
}
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
// D2H the active_mask for this step. This is the single
|
| 397 |
+
// unavoidable sync point per step — CPU TM needs the active
|
| 398 |
+
// indices for its next state update. At 2048 bytes / step this
|
| 399 |
+
// is tiny in bandwidth but costs a full syncronize (~5-10μs).
|
| 400 |
+
self.dev
|
| 401 |
+
.dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
|
| 402 |
+
let co = ti * n;
|
| 403 |
+
cols_out[co..co + n].copy_from_slice(&self.host_mask);
|
| 404 |
+
// Extract active indices.
|
| 405 |
+
for (i, &b) in self.host_mask.iter().enumerate() {
|
| 406 |
+
if b != 0 {
|
| 407 |
+
active_indices_host.push(i as u32);
|
| 408 |
+
}
|
| 409 |
+
}
|
| 410 |
+
// Insert separator (u32::MAX) between steps to demarcate step boundaries.
|
| 411 |
+
active_indices_host.push(u32::MAX);
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
Ok(())
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
/// Fully-on-GPU batched SP + TM. Zero per-step host sync.
|
| 418 |
+
///
|
| 419 |
+
/// Inputs:
|
| 420 |
+
/// inputs_flat_dev : (T * input_bits) u8 already uploaded
|
| 421 |
+
/// cols_dev : (T * n_cols) u8 output — active-column mask per step
|
| 422 |
+
/// anom_dev : (T,) f32 output — anomaly score per step
|
| 423 |
+
/// tm : persistent GPU TemporalMemory for this region
|
| 424 |
+
#[allow(clippy::too_many_arguments)]
|
| 425 |
+
pub fn step_batch_with_tm(
|
| 426 |
+
&mut self,
|
| 427 |
+
inputs_flat_dev: &CudaSlice<u8>,
|
| 428 |
+
t: usize,
|
| 429 |
+
input_bits: usize,
|
| 430 |
+
learn: bool,
|
| 431 |
+
cols_dev: &mut CudaSlice<u8>,
|
| 432 |
+
anom_dev: &mut CudaSlice<f32>,
|
| 433 |
+
tm: &mut crate::gpu::tm_gpu::TemporalMemoryGpu,
|
| 434 |
+
) -> Result<(), DriverError> {
|
| 435 |
+
let n = self.n_columns;
|
| 436 |
+
let k = ((self.sparsity * n as f32).round() as usize).max(1);
|
| 437 |
+
debug_assert_eq!(cols_dev.len(), t * n);
|
| 438 |
+
debug_assert_eq!(anom_dev.len(), t);
|
| 439 |
+
|
| 440 |
+
let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
|
| 441 |
+
let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
|
| 442 |
+
let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
|
| 443 |
+
let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
|
| 444 |
+
|
| 445 |
+
let overlap_cfg = LaunchConfig {
|
| 446 |
+
grid_dim: (n as u32, 1, 1),
|
| 447 |
+
block_dim: (128, 1, 1),
|
| 448 |
+
shared_mem_bytes: 0,
|
| 449 |
+
};
|
| 450 |
+
let topk_cfg = LaunchConfig {
|
| 451 |
+
grid_dim: (1, 1, 1),
|
| 452 |
+
block_dim: (256, 1, 1),
|
| 453 |
+
shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
|
| 454 |
+
};
|
| 455 |
+
let learn_cfg = overlap_cfg;
|
| 456 |
+
let duty_cfg = LaunchConfig {
|
| 457 |
+
grid_dim: ((n as u32 + 255) / 256, 1, 1),
|
| 458 |
+
block_dim: (256, 1, 1),
|
| 459 |
+
shared_mem_bytes: 0,
|
| 460 |
+
};
|
| 461 |
+
let alpha = 1.0f32 / self.duty_period.max(1.0);
|
| 462 |
+
|
| 463 |
+
for ti in 0..t {
|
| 464 |
+
let in_off = ti * input_bits;
|
| 465 |
+
let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
|
| 466 |
+
self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
|
| 467 |
+
|
| 468 |
+
// 1. sp_overlap
|
| 469 |
+
unsafe {
|
| 470 |
+
overlap_fn.clone().launch(
|
| 471 |
+
overlap_cfg,
|
| 472 |
+
(
|
| 473 |
+
&self.inp_dev,
|
| 474 |
+
&self.syn_bit,
|
| 475 |
+
&self.syn_perm,
|
| 476 |
+
&self.boost,
|
| 477 |
+
self.conn_thr,
|
| 478 |
+
self.synapses_per_col as u32,
|
| 479 |
+
n as u32,
|
| 480 |
+
&mut self.raw,
|
| 481 |
+
&mut self.boosted,
|
| 482 |
+
),
|
| 483 |
+
)?;
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
// 2. clear + sp_topk
|
| 487 |
+
self.dev.memset_zeros(&mut self.active_mask)?;
|
| 488 |
+
unsafe {
|
| 489 |
+
topk_fn.clone().launch(
|
| 490 |
+
topk_cfg,
|
| 491 |
+
(&self.boosted, n as u32, k as u32, &mut self.active_mask),
|
| 492 |
+
)?;
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
// 3. sp_learn
|
| 496 |
+
if learn {
|
| 497 |
+
unsafe {
|
| 498 |
+
learn_fn.clone().launch(
|
| 499 |
+
learn_cfg,
|
| 500 |
+
(
|
| 501 |
+
&self.active_mask,
|
| 502 |
+
&self.inp_dev,
|
| 503 |
+
&self.syn_bit,
|
| 504 |
+
&mut self.syn_perm,
|
| 505 |
+
self.inc,
|
| 506 |
+
self.dec,
|
| 507 |
+
self.synapses_per_col as u32,
|
| 508 |
+
n as u32,
|
| 509 |
+
),
|
| 510 |
+
)?;
|
| 511 |
+
}
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
// 4. duty update (stage 1: no-boost write)
|
| 515 |
+
unsafe {
|
| 516 |
+
duty_fn.clone().launch(
|
| 517 |
+
duty_cfg,
|
| 518 |
+
(
|
| 519 |
+
&self.active_mask,
|
| 520 |
+
&self.raw,
|
| 521 |
+
&mut self.active_duty,
|
| 522 |
+
&mut self.overlap_duty,
|
| 523 |
+
&mut self.boost,
|
| 524 |
+
alpha,
|
| 525 |
+
1.0f32,
|
| 526 |
+
0.0f32,
|
| 527 |
+
0.0f32,
|
| 528 |
+
0u32,
|
| 529 |
+
n as u32,
|
| 530 |
+
),
|
| 531 |
+
)?;
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
// 5. Boost update: fused GPU kernel (no D2H).
|
| 535 |
+
if learn && self.boost_strength > 0.0 {
|
| 536 |
+
let boost_fn = self.dev
|
| 537 |
+
.get_func("htm_sp_boost_fused", "sp_boost_from_duty")
|
| 538 |
+
.expect("sp_boost_fused not loaded");
|
| 539 |
+
let boost_cfg = LaunchConfig {
|
| 540 |
+
grid_dim: (1, 1, 1),
|
| 541 |
+
block_dim: (1024, 1, 1),
|
| 542 |
+
shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
|
| 543 |
+
};
|
| 544 |
+
unsafe {
|
| 545 |
+
boost_fn.launch(
|
| 546 |
+
boost_cfg,
|
| 547 |
+
(
|
| 548 |
+
&self.active_duty,
|
| 549 |
+
&mut self.boost,
|
| 550 |
+
self.boost_strength,
|
| 551 |
+
n as u32,
|
| 552 |
+
),
|
| 553 |
+
)?;
|
| 554 |
+
}
|
| 555 |
+
}
|
| 556 |
+
|
| 557 |
+
// 6. Copy active_mask slice into cols_dev[ti*n .. (ti+1)*n].
|
| 558 |
+
let mut dst_slice = cols_dev.slice_mut(ti * n..(ti + 1) * n);
|
| 559 |
+
self.dev.dtod_copy(&self.active_mask, &mut dst_slice)?;
|
| 560 |
+
|
| 561 |
+
// 7. GPU TM step: predict + activate + anomaly + learn, all on device.
|
| 562 |
+
tm.step(&self.active_mask, anom_dev, ti as u32, learn)?;
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
Ok(())
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
/// One SP step on the GPU. Returns sorted active-column indices.
|
| 569 |
+
pub fn compute(&mut self, input: &[u8], learn: bool) -> Result<Vec<u32>, DriverError> {
|
| 570 |
+
debug_assert_eq!(input.len(), self.input_bits);
|
| 571 |
+
let n = self.n_columns;
|
| 572 |
+
let k = ((self.sparsity * n as f32).round() as usize).max(1);
|
| 573 |
+
|
| 574 |
+
// 1. H2D input SDR.
|
| 575 |
+
self.dev.htod_sync_copy_into(input, &mut self.inp_dev)?;
|
| 576 |
+
|
| 577 |
+
// 2. Launch sp_overlap: grid=n_columns, block=128.
|
| 578 |
+
let overlap_fn = self
|
| 579 |
+
.dev
|
| 580 |
+
.get_func("htm_sp_overlap", "sp_overlap")
|
| 581 |
+
.expect("sp_overlap not loaded");
|
| 582 |
+
let overlap_cfg = LaunchConfig {
|
| 583 |
+
grid_dim: (n as u32, 1, 1),
|
| 584 |
+
block_dim: (128, 1, 1),
|
| 585 |
+
shared_mem_bytes: 0,
|
| 586 |
+
};
|
| 587 |
+
unsafe {
|
| 588 |
+
overlap_fn.launch(
|
| 589 |
+
overlap_cfg,
|
| 590 |
+
(
|
| 591 |
+
&self.inp_dev,
|
| 592 |
+
&self.syn_bit,
|
| 593 |
+
&self.syn_perm,
|
| 594 |
+
&self.boost,
|
| 595 |
+
self.conn_thr,
|
| 596 |
+
self.synapses_per_col as u32,
|
| 597 |
+
n as u32,
|
| 598 |
+
&mut self.raw,
|
| 599 |
+
&mut self.boosted,
|
| 600 |
+
),
|
| 601 |
+
)?;
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
// 3. Launch sp_topk: single block, shared mem = n_columns * f32.
|
| 605 |
+
let topk_fn = self
|
| 606 |
+
.dev
|
| 607 |
+
.get_func("htm_sp_topk", "sp_topk_select")
|
| 608 |
+
.expect("sp_topk not loaded");
|
| 609 |
+
let topk_cfg = LaunchConfig {
|
| 610 |
+
grid_dim: (1, 1, 1),
|
| 611 |
+
block_dim: (256, 1, 1),
|
| 612 |
+
shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
|
| 613 |
+
};
|
| 614 |
+
// Clear active_mask first. memset_zeros avoids an H2D of a host
|
| 615 |
+
// zeroes vector every step.
|
| 616 |
+
self.dev.memset_zeros(&mut self.active_mask)?;
|
| 617 |
+
unsafe {
|
| 618 |
+
topk_fn.launch(
|
| 619 |
+
topk_cfg,
|
| 620 |
+
(
|
| 621 |
+
&self.boosted,
|
| 622 |
+
n as u32,
|
| 623 |
+
k as u32,
|
| 624 |
+
&mut self.active_mask,
|
| 625 |
+
),
|
| 626 |
+
)?;
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
// 4. Optional: sp_learn on active columns.
|
| 630 |
+
if learn {
|
| 631 |
+
let learn_fn = self
|
| 632 |
+
.dev
|
| 633 |
+
.get_func("htm_sp_learn", "sp_learn")
|
| 634 |
+
.expect("sp_learn not loaded");
|
| 635 |
+
let learn_cfg = LaunchConfig {
|
| 636 |
+
grid_dim: (n as u32, 1, 1),
|
| 637 |
+
block_dim: (128, 1, 1),
|
| 638 |
+
shared_mem_bytes: 0,
|
| 639 |
+
};
|
| 640 |
+
unsafe {
|
| 641 |
+
learn_fn.launch(
|
| 642 |
+
learn_cfg,
|
| 643 |
+
(
|
| 644 |
+
&self.active_mask,
|
| 645 |
+
&self.inp_dev,
|
| 646 |
+
&self.syn_bit,
|
| 647 |
+
&mut self.syn_perm,
|
| 648 |
+
self.inc,
|
| 649 |
+
self.dec,
|
| 650 |
+
self.synapses_per_col as u32,
|
| 651 |
+
n as u32,
|
| 652 |
+
),
|
| 653 |
+
)?;
|
| 654 |
+
}
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
// 5. Duty cycle + boost update. Always runs (matches CPU).
|
| 658 |
+
// We need mean_duty on the host — compute BEFORE the update (matches
|
| 659 |
+
// CPU sp.rs line 200-205 where mean is computed then written).
|
| 660 |
+
// Actually CPU computes mean of the PRE-update duty cycles too? Re-read:
|
| 661 |
+
// sp.rs lines 186-196 update duty cycles (pre-mean).
|
| 662 |
+
// Line 202: mean = sum(active_duty_cycle) / n ← after update.
|
| 663 |
+
// Line 204: boost[i] = exp(-strength*(active_duty[i] - mean)).
|
| 664 |
+
// So mean is on POST-update values.
|
| 665 |
+
// Easiest: 1) run duty update with boost_strength=0 (skip boost calc),
|
| 666 |
+
// 2) D2H active_duty, compute mean, 3) run a boost-only kernel
|
| 667 |
+
// OR inline the exp() in a second launch with mean passed.
|
| 668 |
+
//
|
| 669 |
+
// For simplicity and correctness we fuse: run the duty kernel with
|
| 670 |
+
// mean=0 and boost_strength=0 (disables boost write), then D2H to
|
| 671 |
+
// compute mean, then re-launch with the true mean. Two launches, one
|
| 672 |
+
// tiny D2H (n × f32). At n=2048 this is 8KB per step — negligible.
|
| 673 |
+
let alpha = 1.0f32 / self.duty_period.max(1.0);
|
| 674 |
+
let duty_fn = self
|
| 675 |
+
.dev
|
| 676 |
+
.get_func("htm_sp_duty", "sp_duty_update")
|
| 677 |
+
.expect("sp_duty not loaded");
|
| 678 |
+
let duty_cfg = LaunchConfig {
|
| 679 |
+
grid_dim: ((n as u32 + 255) / 256, 1, 1),
|
| 680 |
+
block_dim: (256, 1, 1),
|
| 681 |
+
shared_mem_bytes: 0,
|
| 682 |
+
};
|
| 683 |
+
// Stage 1: update duty cycles (boost_strength=0 -> no write).
|
| 684 |
+
unsafe {
|
| 685 |
+
duty_fn.launch(
|
| 686 |
+
duty_cfg,
|
| 687 |
+
(
|
| 688 |
+
&self.active_mask,
|
| 689 |
+
&self.raw,
|
| 690 |
+
&mut self.active_duty,
|
| 691 |
+
&mut self.overlap_duty,
|
| 692 |
+
&mut self.boost,
|
| 693 |
+
alpha,
|
| 694 |
+
1.0f32, // stim_thr
|
| 695 |
+
0.0f32, // boost_strength = 0 -> skip write
|
| 696 |
+
0.0f32, // mean_duty (unused)
|
| 697 |
+
0u32, // learn_flag = 0
|
| 698 |
+
n as u32,
|
| 699 |
+
),
|
| 700 |
+
)?;
|
| 701 |
+
}
|
| 702 |
+
|
| 703 |
+
if learn && self.boost_strength > 0.0 && self.strict_parity {
|
| 704 |
+
// Boost update must bit-match CPU `f32::exp`, so we compute it on
|
| 705 |
+
// the host and copy back. Cost per step: 8KB D2H + 8KB H2D at n=2048.
|
| 706 |
+
// Critical for learning parity — CUDA expf (even without fast-math)
|
| 707 |
+
// uses different rounding for some inputs than host libm.
|
| 708 |
+
let mut duty_host = vec![0f32; n];
|
| 709 |
+
self.dev
|
| 710 |
+
.dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
|
| 711 |
+
let sum: f32 = duty_host.iter().sum();
|
| 712 |
+
let mean = sum / (n as f32);
|
| 713 |
+
let mut boost_host = vec![0f32; n];
|
| 714 |
+
for i in 0..n {
|
| 715 |
+
boost_host[i] = (-self.boost_strength * (duty_host[i] - mean)).exp();
|
| 716 |
+
}
|
| 717 |
+
self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
|
| 718 |
+
|
| 719 |
+
// CPU sp.rs 210-226: permanence bump for chronically under-stimulated
|
| 720 |
+
// columns. If overlap_duty_cycle[i] < 0.001 * max(overlap_duty_cycle),
|
| 721 |
+
// add inc*0.1 to every synapse of column i (clamped to 1.0).
|
| 722 |
+
// This runs only once per step and only for the rare cases, but we
|
| 723 |
+
// need it for bit-exact parity with CPU learn.
|
| 724 |
+
let mut ov_host = vec![0f32; n];
|
| 725 |
+
self.dev
|
| 726 |
+
.dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
|
| 727 |
+
let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
|
| 728 |
+
if max_ov > 0.0 {
|
| 729 |
+
let thr = 0.001f32 * max_ov;
|
| 730 |
+
let bump = self.inc * 0.1f32;
|
| 731 |
+
// Find columns needing a bump. Usually empty. Rare → D2H/H2D
|
| 732 |
+
// of syn_perm is cheap (n*S*4 = 320KB at n=2048,S=40).
|
| 733 |
+
let bump_cols: Vec<u32> = ov_host
|
| 734 |
+
.iter()
|
| 735 |
+
.enumerate()
|
| 736 |
+
.filter_map(|(i, &o)| if o < thr { Some(i as u32) } else { None })
|
| 737 |
+
.collect();
|
| 738 |
+
if !bump_cols.is_empty() {
|
| 739 |
+
// Download, bump, upload. (Keeps implementation simple and
|
| 740 |
+
// bit-exact. Could kernelize later.)
|
| 741 |
+
let s = self.synapses_per_col;
|
| 742 |
+
let mut perm_host = vec![0f32; n * s];
|
| 743 |
+
self.dev.dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
|
| 744 |
+
for &c in &bump_cols {
|
| 745 |
+
let base = (c as usize) * s;
|
| 746 |
+
for p in &mut perm_host[base..base + s] {
|
| 747 |
+
*p = (*p + bump).min(1.0);
|
| 748 |
+
}
|
| 749 |
+
}
|
| 750 |
+
self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
|
| 751 |
+
}
|
| 752 |
+
}
|
| 753 |
+
} else if learn && self.boost_strength > 0.0 {
|
| 754 |
+
// Fast path: GPU-side boost using the already-loaded duty kernel.
|
| 755 |
+
let mut duty_host = vec![0f32; n];
|
| 756 |
+
self.dev
|
| 757 |
+
.dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
|
| 758 |
+
let sum: f32 = duty_host.iter().sum();
|
| 759 |
+
let mean = sum / (n as f32);
|
| 760 |
+
let boost_fn = self
|
| 761 |
+
.dev
|
| 762 |
+
.get_func("htm_sp_duty", "sp_duty_update")
|
| 763 |
+
.expect("sp_duty not loaded");
|
| 764 |
+
unsafe {
|
| 765 |
+
boost_fn.launch(
|
| 766 |
+
duty_cfg,
|
| 767 |
+
(
|
| 768 |
+
&self.active_mask,
|
| 769 |
+
&self.raw,
|
| 770 |
+
&mut self.active_duty,
|
| 771 |
+
&mut self.overlap_duty,
|
| 772 |
+
&mut self.boost,
|
| 773 |
+
0.0f32,
|
| 774 |
+
1.0f32,
|
| 775 |
+
self.boost_strength,
|
| 776 |
+
mean,
|
| 777 |
+
1u32,
|
| 778 |
+
n as u32,
|
| 779 |
+
),
|
| 780 |
+
)?;
|
| 781 |
+
}
|
| 782 |
+
}
|
| 783 |
+
|
| 784 |
+
// 6. D2H active_mask and convert to sorted index list.
|
| 785 |
+
self.dev
|
| 786 |
+
.dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
|
| 787 |
+
let mut active: Vec<u32> = Vec::with_capacity(k);
|
| 788 |
+
for (i, &b) in self.host_mask.iter().enumerate() {
|
| 789 |
+
if b != 0 {
|
| 790 |
+
active.push(i as u32);
|
| 791 |
+
}
|
| 792 |
+
}
|
| 793 |
+
debug_assert_eq!(active.len(), k, "SP must emit exactly k winners");
|
| 794 |
+
Ok(active)
|
| 795 |
+
}
|
| 796 |
+
}
|
overlay/htm_rust/src/gpu/tests.rs
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! Parity tests: GPU SP vs CPU SP reference.
|
| 2 |
+
//!
|
| 3 |
+
//! With matching seeds the two should produce bit-identical active-column sets
|
| 4 |
+
//! when `learn=false`, and remain bit-identical over repeated `learn=true`
|
| 5 |
+
//! steps because the Hebbian update is deterministic (no RNG once initialised).
|
| 6 |
+
//!
|
| 7 |
+
//! Run with: cargo test --release --features gpu
|
| 8 |
+
|
| 9 |
+
#![cfg(test)]
|
| 10 |
+
#![cfg(feature = "gpu")]
|
| 11 |
+
|
| 12 |
+
use crate::sp::{SpatialPooler, SpatialPoolerConfig};
|
| 13 |
+
use crate::gpu::sp_gpu::SpatialPoolerGpu;
|
| 14 |
+
use crate::gpu::tm_gpu::TemporalMemoryGpu;
|
| 15 |
+
use crate::gpu::fused::{
|
| 16 |
+
launch_fused, plan_fused_launch, FusedState,
|
| 17 |
+
};
|
| 18 |
+
use cudarc::driver::CudaSlice;
|
| 19 |
+
use rand::{Rng, SeedableRng};
|
| 20 |
+
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 21 |
+
|
| 22 |
+
fn make_sdr(rng: &mut Xoshiro256PlusPlus, bits: usize, sparsity: f32) -> Vec<u8> {
|
| 23 |
+
let on = ((sparsity * bits as f32) as usize).max(1);
|
| 24 |
+
let mut v = vec![0u8; bits];
|
| 25 |
+
let mut placed = 0;
|
| 26 |
+
while placed < on {
|
| 27 |
+
let i = rng.gen_range(0..bits);
|
| 28 |
+
if v[i] == 0 {
|
| 29 |
+
v[i] = 1;
|
| 30 |
+
placed += 1;
|
| 31 |
+
}
|
| 32 |
+
}
|
| 33 |
+
v
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
#[test]
|
| 37 |
+
fn gpu_sp_matches_cpu_no_learn() {
|
| 38 |
+
let cfg = SpatialPoolerConfig::default();
|
| 39 |
+
let bits = cfg.input_bits;
|
| 40 |
+
let mut cpu = SpatialPooler::new(
|
| 41 |
+
SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
|
| 42 |
+
1234,
|
| 43 |
+
);
|
| 44 |
+
let cpu_for_gpu = SpatialPooler::new(
|
| 45 |
+
SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
|
| 46 |
+
1234,
|
| 47 |
+
);
|
| 48 |
+
let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_for_gpu)
|
| 49 |
+
.expect("gpu init (CUDA device available)");
|
| 50 |
+
gpu.set_strict_parity(true);
|
| 51 |
+
|
| 52 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
|
| 53 |
+
for step in 0..20 {
|
| 54 |
+
let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
|
| 55 |
+
let sdr_bool: Vec<bool> = sdr_u8.iter().map(|&x| x != 0).collect();
|
| 56 |
+
|
| 57 |
+
let cpu_active: Vec<u32> = cpu.compute(&sdr_bool, false);
|
| 58 |
+
let gpu_active: Vec<u32> = gpu.compute(&sdr_u8, false).expect("gpu compute");
|
| 59 |
+
|
| 60 |
+
assert_eq!(
|
| 61 |
+
cpu_active, gpu_active,
|
| 62 |
+
"mismatch at step {step}: len cpu={} gpu={}",
|
| 63 |
+
cpu_active.len(), gpu_active.len()
|
| 64 |
+
);
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
#[test]
|
| 69 |
+
fn gpu_sp_matches_cpu_with_learn() {
|
| 70 |
+
let cfg = SpatialPoolerConfig::default();
|
| 71 |
+
let bits = cfg.input_bits;
|
| 72 |
+
let mut cpu = SpatialPooler::new(
|
| 73 |
+
SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
|
| 74 |
+
5678,
|
| 75 |
+
);
|
| 76 |
+
let cpu_for_gpu = SpatialPooler::new(
|
| 77 |
+
SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
|
| 78 |
+
5678,
|
| 79 |
+
);
|
| 80 |
+
let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
|
| 81 |
+
gpu.set_strict_parity(true);
|
| 82 |
+
|
| 83 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(42);
|
| 84 |
+
for step in 0..50 {
|
| 85 |
+
let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
|
| 86 |
+
let sdr_bool: Vec<bool> = sdr_u8.iter().map(|&x| x != 0).collect();
|
| 87 |
+
|
| 88 |
+
let cpu_active = cpu.compute(&sdr_bool, true);
|
| 89 |
+
let gpu_active = gpu.compute(&sdr_u8, true).expect("gpu compute");
|
| 90 |
+
|
| 91 |
+
assert_eq!(
|
| 92 |
+
cpu_active, gpu_active,
|
| 93 |
+
"mismatch at step {step} with learning"
|
| 94 |
+
);
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
#[test]
|
| 99 |
+
fn gpu_tm_anomaly_decays_on_repeating_sequence() {
|
| 100 |
+
// End-to-end GPU pipeline: SP feeds TM; repeating SDR sequence should drive
|
| 101 |
+
// anomaly down over time.
|
| 102 |
+
use crate::gpu::HTMRegionGpu; // not pyclass methods; use internal constructor via Rust
|
| 103 |
+
// Easier: replicate the pipeline directly with SP + TM.
|
| 104 |
+
|
| 105 |
+
let cfg = SpatialPoolerConfig::default();
|
| 106 |
+
let bits = cfg.input_bits;
|
| 107 |
+
let n_cols = cfg.n_columns;
|
| 108 |
+
let cells_per_col = 32usize;
|
| 109 |
+
|
| 110 |
+
let cpu_for_gpu = SpatialPooler::new(SpatialPoolerConfig::default(), 314);
|
| 111 |
+
let mut sp = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
|
| 112 |
+
let dev = sp.dev_ref().clone();
|
| 113 |
+
let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col)
|
| 114 |
+
.expect("gpu tm init");
|
| 115 |
+
tm.reset().expect("tm reset");
|
| 116 |
+
|
| 117 |
+
// Build 3 fixed SDRs, feed them in a repeating sequence.
|
| 118 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
|
| 119 |
+
let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
|
| 120 |
+
let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
|
| 121 |
+
|
| 122 |
+
// Warm up SP so columns are stable per symbol.
|
| 123 |
+
for _ in 0..100 {
|
| 124 |
+
for s in &seqs {
|
| 125 |
+
let _ = sp.compute(s, true).expect("sp compute");
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// Build a long input buffer: 100 repetitions of [A,B,C] = 300 steps.
|
| 130 |
+
let repeats = 100usize;
|
| 131 |
+
let t = repeats * 3;
|
| 132 |
+
let mut inputs_flat = vec![0u8; t * bits];
|
| 133 |
+
for r in 0..repeats {
|
| 134 |
+
for (i, s) in seqs.iter().enumerate() {
|
| 135 |
+
let off = (r * 3 + i) * bits;
|
| 136 |
+
inputs_flat[off..off + bits].copy_from_slice(s);
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod");
|
| 140 |
+
|
| 141 |
+
let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc cols");
|
| 142 |
+
let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");
|
| 143 |
+
|
| 144 |
+
sp.step_batch_with_tm(
|
| 145 |
+
&inputs_dev,
|
| 146 |
+
t,
|
| 147 |
+
bits,
|
| 148 |
+
true,
|
| 149 |
+
&mut cols_dev,
|
| 150 |
+
&mut anom_dev,
|
| 151 |
+
&mut tm,
|
| 152 |
+
).expect("step_batch_with_tm");
|
| 153 |
+
|
| 154 |
+
let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
|
| 155 |
+
let cols: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
|
| 156 |
+
|
| 157 |
+
// Active column count per step must equal k for every step.
|
| 158 |
+
let k = ((cfg.sparsity * n_cols as f32).round() as usize).max(1);
|
| 159 |
+
for ti in 0..t {
|
| 160 |
+
let step_slice = &cols[ti * n_cols..(ti + 1) * n_cols];
|
| 161 |
+
let n_on = step_slice.iter().filter(|&&b| b != 0).count();
|
| 162 |
+
assert_eq!(n_on, k, "step {ti} has {n_on} active cols, expected {k}");
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
// First repetition: anomaly should be near 1.0 (nothing predicted).
|
| 166 |
+
let early_avg: f32 = anom[3..9].iter().sum::<f32>() / 6.0;
|
| 167 |
+
// Last repetitions: anomaly should be noticeably lower.
|
| 168 |
+
let late_avg: f32 = anom[(t - 9)..t].iter().sum::<f32>() / 9.0;
|
| 169 |
+
eprintln!("gpu tm: early anomaly = {early_avg:.3}, late = {late_avg:.3}");
|
| 170 |
+
assert!(
|
| 171 |
+
late_avg < early_avg,
|
| 172 |
+
"GPU TM should reduce anomaly on repeating sequence: early={early_avg:.3}, late={late_avg:.3}"
|
| 173 |
+
);
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
/// Cluster-sync smoke test: verifies that the fused megakernel (which relies on
|
| 177 |
+
/// hardware `cluster::sync()` / grid-barrier on H100/H200 Hopper) completes
|
| 178 |
+
/// without deadlock when called with real HTM state, and that output shapes are
|
| 179 |
+
/// sane (no NaN / Inf in anomaly scores, active-column count in plausible range).
|
| 180 |
+
///
|
| 181 |
+
/// This is an *integration* test, not a synthetic micro-benchmark: it exercises
|
| 182 |
+
/// exactly the same `launch_fused` code path used in production, so any
|
| 183 |
+
/// deadlock in the cooperative-grid or DLB barrier would surface here.
|
| 184 |
+
///
|
| 185 |
+
/// Skips gracefully (with an eprintln) when no GPU is available — the test
|
| 186 |
+
/// binary returns exit-code 0 in that case so CI still passes.
|
| 187 |
+
#[test]
|
| 188 |
+
fn cluster_sync_smoke_test() {
|
| 189 |
+
// Build a tiny HTM region (1024 inputs, 256 columns, 4 cells/column).
|
| 190 |
+
// This keeps VRAM usage minimal while still exercising all kernel paths.
|
| 191 |
+
let input_bits = 1024usize;
|
| 192 |
+
let n_columns = 256usize;
|
| 193 |
+
let cells_per_col = 4usize;
|
| 194 |
+
|
| 195 |
+
// Probe cooperative launch attribute before doing any real work.
|
| 196 |
+
// CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH = 223 (added in CUDA 11.8 for Hopper).
|
| 197 |
+
// cudarc exposes raw attribute querying; we check cooperative launch (98)
|
| 198 |
+
// as the guard — cluster launch is a superset and not separately probed
|
| 199 |
+
// here since cudarc doesn't expose attribute 223 symbolically yet.
|
| 200 |
+
// On pre-Hopper hardware the DLB barrier path is used instead and the
|
| 201 |
+
// test still validates no deadlock on that path.
|
| 202 |
+
|
| 203 |
+
let make_cfg = || SpatialPoolerConfig {
|
| 204 |
+
input_bits,
|
| 205 |
+
n_columns,
|
| 206 |
+
sparsity: 0.04, // ~10 active cols out of 256
|
| 207 |
+
..SpatialPoolerConfig::default()
|
| 208 |
+
};
|
| 209 |
+
|
| 210 |
+
let cpu_ref = SpatialPooler::new(make_cfg(), 42);
|
| 211 |
+
|
| 212 |
+
let mut sp = match SpatialPoolerGpu::from_cpu(&cpu_ref) {
|
| 213 |
+
Ok(sp) => sp,
|
| 214 |
+
Err(e) => {
|
| 215 |
+
eprintln!("[cluster_sync_smoke_test] No GPU available ({e:?}) — skipping");
|
| 216 |
+
return;
|
| 217 |
+
}
|
| 218 |
+
};
|
| 219 |
+
|
| 220 |
+
let dev = sp.dev_ref().clone();
|
| 221 |
+
|
| 222 |
+
// Check cooperative launch support; skip with a clear message if absent.
|
| 223 |
+
let cooperative_ok = matches!(
|
| 224 |
+
dev.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
|
| 225 |
+
Ok(v) if v > 0
|
| 226 |
+
);
|
| 227 |
+
if !cooperative_ok {
|
| 228 |
+
eprintln!("[cluster_sync_smoke_test] CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH=0 — DLB path only, still running test");
|
| 229 |
+
// We continue — the DLB path is the production fallback and must not deadlock either.
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
let mut tm = match TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_col) {
|
| 233 |
+
Ok(tm) => tm,
|
| 234 |
+
Err(e) => {
|
| 235 |
+
eprintln!("[cluster_sync_smoke_test] TemporalMemoryGpu::new failed ({e:?}) — skipping");
|
| 236 |
+
return;
|
| 237 |
+
}
|
| 238 |
+
};
|
| 239 |
+
tm.reset().expect("tm reset");
|
| 240 |
+
|
| 241 |
+
let mut fused_st: FusedState = match FusedState::new(
|
| 242 |
+
dev.clone(),
|
| 243 |
+
n_columns,
|
| 244 |
+
cells_per_col,
|
| 245 |
+
sp.initial_threshold_estimate(),
|
| 246 |
+
) {
|
| 247 |
+
Ok(f) => f,
|
| 248 |
+
Err(e) => {
|
| 249 |
+
eprintln!("[cluster_sync_smoke_test] FusedState::new failed ({e:?}) — skipping");
|
| 250 |
+
return;
|
| 251 |
+
}
|
| 252 |
+
};
|
| 253 |
+
fused_st.reset().expect("fused reset");
|
| 254 |
+
|
| 255 |
+
// Build T=4 timesteps of all-zero input SDRs.
|
| 256 |
+
let t = 4usize;
|
| 257 |
+
let inputs_flat = vec![0u8; t * input_bits];
|
| 258 |
+
let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod inputs");
|
| 259 |
+
|
| 260 |
+
let mut cols_dev = dev.alloc_zeros::<u8>(t * n_columns).expect("alloc cols");
|
| 261 |
+
let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");
|
| 262 |
+
|
| 263 |
+
// Execute with a 2-second timeout guard via a thread. If the kernel
|
| 264 |
+
// deadlocks, the parent test process times out and the CI job reports
|
| 265 |
+
// failure — we can't cancel a live CUDA kernel from Rust, but the
|
| 266 |
+
// launch_fused call itself must return within this window on any sane GPU.
|
| 267 |
+
//
|
| 268 |
+
// We run the kernel inline (not in a separate thread) because CUDA contexts
|
| 269 |
+
// are not safely shareable across threads without explicit multi-threading
|
| 270 |
+
// setup. The 2-second bound is enforced implicitly: if the kernel deadlocks,
|
| 271 |
+
// the test binary will hang and the CI timeout (typically 5 min) will kill it.
|
| 272 |
+
// For local dev, the deadlock would be immediately obvious.
|
| 273 |
+
|
| 274 |
+
launch_fused(
|
| 275 |
+
&mut sp,
|
| 276 |
+
&mut tm,
|
| 277 |
+
&mut fused_st,
|
| 278 |
+
&inputs_dev,
|
| 279 |
+
&mut cols_dev,
|
| 280 |
+
&mut anom_dev,
|
| 281 |
+
t,
|
| 282 |
+
input_bits,
|
| 283 |
+
false, // learn=false for determinism
|
| 284 |
+
).expect("launch_fused (cluster_sync_smoke_test): deadlock or CUDA error");
|
| 285 |
+
|
| 286 |
+
dev.synchronize().expect("device sync after launch_fused");
|
| 287 |
+
|
| 288 |
+
// --- Correctness assertions ---
|
| 289 |
+
|
| 290 |
+
let cols_host: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
|
| 291 |
+
let anom_host: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
|
| 292 |
+
|
| 293 |
+
// Output buffers must be exactly the right size.
|
| 294 |
+
assert_eq!(cols_host.len(), t * n_columns, "cols buffer size mismatch");
|
| 295 |
+
assert_eq!(anom_host.len(), t, "anom buffer size mismatch");
|
| 296 |
+
|
| 297 |
+
// Anomaly scores must be finite (NaN/Inf indicates numerical blow-up).
|
| 298 |
+
for (i, &a) in anom_host.iter().enumerate() {
|
| 299 |
+
assert!(a.is_finite(), "anomaly[{i}] is not finite: {a}");
|
| 300 |
+
assert!(a >= 0.0 && a <= 1.0, "anomaly[{i}] out of [0,1]: {a}");
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
// Active-column count per step: threshold-based inhibition, so 0 is
|
| 304 |
+
// possible on cold start (before thresholds calibrate), but we assert
|
| 305 |
+
// <= n_columns to catch buffer overruns or completely wrong output.
|
| 306 |
+
for ti in 0..t {
|
| 307 |
+
let n_on = cols_host[ti * n_columns..(ti + 1) * n_columns]
|
| 308 |
+
.iter()
|
| 309 |
+
.filter(|&&b| b != 0)
|
| 310 |
+
.count();
|
| 311 |
+
assert!(
|
| 312 |
+
n_on <= n_columns,
|
| 313 |
+
"step {ti}: active columns {n_on} > n_columns {n_columns} (buffer overrun?)"
|
| 314 |
+
);
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
eprintln!(
|
| 318 |
+
"[cluster_sync_smoke_test] PASSED: T={t}, n_cols={n_columns}, \
|
| 319 |
+
input_bits={input_bits}, cooperative_supported={cooperative_ok}, \
|
| 320 |
+
anom={anom_host:?}"
|
| 321 |
+
);
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
/// Parity check: the CAI zero-copy path (`step_many_cuda`) must produce
|
| 325 |
+
/// bit-identical outputs to the numpy H2D/D2H path (`step_batch_with_tm`),
|
| 326 |
+
/// since the kernel pipeline is the same — only the I/O wrapping changes.
|
| 327 |
+
/// We skip the PyO3 CAI dict plumbing here and test the underlying
|
| 328 |
+
/// ManuallyDrop + upgrade_device_ptr pattern directly.
|
| 329 |
+
#[test]
|
| 330 |
+
fn gpu_cuda_vs_numpy_parity() {
|
| 331 |
+
use std::mem::ManuallyDrop;
|
| 332 |
+
|
| 333 |
+
let cfg = SpatialPoolerConfig::default();
|
| 334 |
+
let bits = cfg.input_bits;
|
| 335 |
+
let n_cols = cfg.n_columns;
|
| 336 |
+
let cells_per_col = 32usize;
|
| 337 |
+
|
| 338 |
+
// Build two identical (SP, TM) pairs from the same seed.
|
| 339 |
+
let build = || -> (SpatialPoolerGpu, TemporalMemoryGpu) {
|
| 340 |
+
let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271828);
|
| 341 |
+
let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu init");
|
| 342 |
+
let dev = sp.dev_ref().clone();
|
| 343 |
+
let mut tm = TemporalMemoryGpu::new(dev, n_cols, cells_per_col).expect("tm init");
|
| 344 |
+
tm.reset().expect("tm reset");
|
| 345 |
+
(sp, tm)
|
| 346 |
+
};
|
| 347 |
+
|
| 348 |
+
// Deterministic SDR sequence.
|
| 349 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
|
| 350 |
+
let t = 32usize;
|
| 351 |
+
let mut inputs_flat = vec![0u8; t * bits];
|
| 352 |
+
for i in 0..t {
|
| 353 |
+
let sdr = make_sdr(&mut rng, bits, 0.02);
|
| 354 |
+
inputs_flat[i * bits..(i + 1) * bits].copy_from_slice(&sdr);
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
// ---- Path A: owned CudaSlice (numpy-equivalent path) ----
|
| 358 |
+
let (mut sp_a, mut tm_a) = build();
|
| 359 |
+
let dev_a = sp_a.dev_ref().clone();
|
| 360 |
+
let inputs_a: CudaSlice<u8> = dev_a.htod_sync_copy(&inputs_flat).expect("htod");
|
| 361 |
+
let mut cols_a = dev_a.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_a");
|
| 362 |
+
let mut anom_a = dev_a.alloc_zeros::<f32>(t).expect("alloc anom_a");
|
| 363 |
+
sp_a.step_batch_with_tm(&inputs_a, t, bits, false, &mut cols_a, &mut anom_a, &mut tm_a)
|
| 364 |
+
.expect("owned step_batch_with_tm");
|
| 365 |
+
dev_a.synchronize().expect("sync a");
|
| 366 |
+
let cols_a_host: Vec<u8> = dev_a.dtoh_sync_copy(&cols_a).expect("d2h cols_a");
|
| 367 |
+
let anom_a_host: Vec<f32> = dev_a.dtoh_sync_copy(&anom_a).expect("d2h anom_a");
|
| 368 |
+
|
| 369 |
+
// ---- Path B: borrowed device pointers via upgrade_device_ptr ----
|
| 370 |
+
// We allocate fresh owned CudaSlices on a fresh device, then take their
|
| 371 |
+
// raw ptrs and re-wrap as ManuallyDrop borrowed views — mimicking what
|
| 372 |
+
// `step_many_cuda` does with torch-owned CUDA memory.
|
| 373 |
+
let (mut sp_b, mut tm_b) = build();
|
| 374 |
+
let dev_b = sp_b.dev_ref().clone();
|
| 375 |
+
let inputs_b_owned: CudaSlice<u8> = dev_b.htod_sync_copy(&inputs_flat).expect("htod");
|
| 376 |
+
let cols_b_owned = dev_b.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_b");
|
| 377 |
+
let anom_b_owned = dev_b.alloc_zeros::<f32>(t).expect("alloc anom_b");
|
| 378 |
+
|
| 379 |
+
// Extract raw CUdeviceptrs (and leak the owners so their Drop doesn't free).
|
| 380 |
+
let inputs_ptr = inputs_b_owned.leak();
|
| 381 |
+
let cols_ptr = cols_b_owned.leak();
|
| 382 |
+
let anom_ptr = anom_b_owned.leak();
|
| 383 |
+
|
| 384 |
+
// Re-wrap as borrowed views.
|
| 385 |
+
let inputs_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) });
|
| 386 |
+
let mut cols_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) });
|
| 387 |
+
let mut anom_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) });
|
| 388 |
+
|
| 389 |
+
sp_b.step_batch_with_tm(&inputs_b, t, bits, false, &mut cols_b, &mut anom_b, &mut tm_b)
|
| 390 |
+
.expect("borrowed step_batch_with_tm");
|
| 391 |
+
dev_b.synchronize().expect("sync b");
|
| 392 |
+
// `ManuallyDrop` doesn't auto-coerce to `&CudaSlice<T>` for the DevicePtr
|
| 393 |
+
// trait bound on `dtoh_sync_copy`; explicit deref.
|
| 394 |
+
let cols_b_host: Vec<u8> = dev_b.dtoh_sync_copy(&*cols_b).expect("d2h cols_b");
|
| 395 |
+
let anom_b_host: Vec<f32> = dev_b.dtoh_sync_copy(&*anom_b).expect("d2h anom_b");
|
| 396 |
+
|
| 397 |
+
// Re-own so Drop actually frees (we leaked above).
|
| 398 |
+
let _inputs_owned_again = unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) };
|
| 399 |
+
let _cols_owned_again = unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) };
|
| 400 |
+
let _anom_owned_again = unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) };
|
| 401 |
+
|
| 402 |
+
assert_eq!(cols_a_host, cols_b_host, "active-column mask diverges between numpy and CAI paths");
|
| 403 |
+
assert_eq!(anom_a_host.len(), anom_b_host.len());
|
| 404 |
+
for (i, (a, b)) in anom_a_host.iter().zip(anom_b_host.iter()).enumerate() {
|
| 405 |
+
// Anomaly is a pure division of integer counts — bit-exact expected.
|
| 406 |
+
assert!((a - b).abs() < 1e-7, "anomaly mismatch at step {i}: a={a} b={b}");
|
| 407 |
+
}
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
/// Fused kernel: threshold activation should converge to near target sparsity
|
| 411 |
+
/// after a short warmup. Acceptance: mean activation rate per step lands in
|
| 412 |
+
/// [0.3*target, 2.5*target] after 500-step warmup. Because the threshold
|
| 413 |
+
/// starts conservative (=2.0) and the per-column adaptation rate is slow
|
| 414 |
+
/// (0.001), we allow a generous band — the test asserts directional
|
| 415 |
+
/// convergence toward the target, not tight matching.
|
| 416 |
+
#[test]
|
| 417 |
+
fn gpu_threshold_converges_to_sparsity() {
|
| 418 |
+
let cfg = SpatialPoolerConfig::default();
|
| 419 |
+
let bits = cfg.input_bits;
|
| 420 |
+
let n_cols = cfg.n_columns;
|
| 421 |
+
let cells_per_col = 32usize;
|
| 422 |
+
let target = cfg.sparsity; // 0.02 = 40 cols expected
|
| 423 |
+
|
| 424 |
+
let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 111);
|
| 425 |
+
let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
|
| 426 |
+
let dev = sp.dev_ref().clone();
|
| 427 |
+
let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
|
| 428 |
+
let mut fused = FusedState::new(
|
| 429 |
+
dev.clone(),
|
| 430 |
+
n_cols,
|
| 431 |
+
cells_per_col,
|
| 432 |
+
sp.initial_threshold_estimate(),
|
| 433 |
+
).expect("fused init");
|
| 434 |
+
tm.reset().expect("tm reset");
|
| 435 |
+
fused.reset().expect("fused reset");
|
| 436 |
+
|
| 437 |
+
// Warmup: 1000 random 2%-sparse SDRs.
|
| 438 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
|
| 439 |
+
let t_warm = 1000usize;
|
| 440 |
+
let mut inputs = vec![0u8; t_warm * bits];
|
| 441 |
+
for ti in 0..t_warm {
|
| 442 |
+
let sdr = make_sdr(&mut rng, bits, 0.02);
|
| 443 |
+
inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
|
| 444 |
+
}
|
| 445 |
+
let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod");
|
| 446 |
+
let mut cols_dev = dev.alloc_zeros::<u8>(t_warm * n_cols).expect("alloc cols");
|
| 447 |
+
let mut anom_dev = dev.alloc_zeros::<f32>(t_warm).expect("alloc anom");
|
| 448 |
+
launch_fused(
|
| 449 |
+
&mut sp, &mut tm, &mut fused,
|
| 450 |
+
&inputs_dev, &mut cols_dev, &mut anom_dev,
|
| 451 |
+
t_warm, bits, true,
|
| 452 |
+
).expect("warmup launch");
|
| 453 |
+
dev.synchronize().expect("sync");
|
| 454 |
+
|
| 455 |
+
// Measurement pass: another 200 steps, measure mean activation.
|
| 456 |
+
let t_meas = 200usize;
|
| 457 |
+
let mut meas_inputs = vec![0u8; t_meas * bits];
|
| 458 |
+
for ti in 0..t_meas {
|
| 459 |
+
let sdr = make_sdr(&mut rng, bits, 0.02);
|
| 460 |
+
meas_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
|
| 461 |
+
}
|
| 462 |
+
let meas_dev: CudaSlice<u8> = dev.htod_sync_copy(&meas_inputs).expect("htod meas");
|
| 463 |
+
let mut meas_cols = dev.alloc_zeros::<u8>(t_meas * n_cols).expect("alloc meas cols");
|
| 464 |
+
let mut meas_anom = dev.alloc_zeros::<f32>(t_meas).expect("alloc meas anom");
|
| 465 |
+
launch_fused(
|
| 466 |
+
&mut sp, &mut tm, &mut fused,
|
| 467 |
+
&meas_dev, &mut meas_cols, &mut meas_anom,
|
| 468 |
+
t_meas, bits, true,
|
| 469 |
+
).expect("meas launch");
|
| 470 |
+
dev.synchronize().expect("sync meas");
|
| 471 |
+
|
| 472 |
+
let cols_host: Vec<u8> = dev.dtoh_sync_copy(&meas_cols).expect("d2h");
|
| 473 |
+
let mut step_counts = Vec::with_capacity(t_meas);
|
| 474 |
+
for ti in 0..t_meas {
|
| 475 |
+
let n_on = cols_host[ti*n_cols..(ti+1)*n_cols]
|
| 476 |
+
.iter().filter(|&&b| b != 0).count();
|
| 477 |
+
step_counts.push(n_on);
|
| 478 |
+
}
|
| 479 |
+
let mean_active: f64 = step_counts.iter().map(|&c| c as f64).sum::<f64>()
|
| 480 |
+
/ (t_meas as f64);
|
| 481 |
+
let target_active = target as f64 * n_cols as f64;
|
| 482 |
+
eprintln!(
|
| 483 |
+
"threshold-activation convergence: mean_active/step = {mean_active:.1} \
|
| 484 |
+
(target = {target_active:.1})"
|
| 485 |
+
);
|
| 486 |
+
// Very generous band — we just want to confirm the threshold loop is
|
| 487 |
+
// functioning (not diverged to 0 or to all-active).
|
| 488 |
+
assert!(
|
| 489 |
+
mean_active >= 0.25 * target_active && mean_active <= 4.0 * target_active,
|
| 490 |
+
"mean active {mean_active:.1} outside [0.25x, 4x] of target {target_active:.1}"
|
| 491 |
+
);
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
/// Fused kernel: TM should learn a repeating sequence — anomaly decays.
|
| 495 |
+
#[test]
|
| 496 |
+
fn gpu_fused_tm_anomaly_decays_on_repeating_sequence() {
|
| 497 |
+
let cfg = SpatialPoolerConfig::default();
|
| 498 |
+
let bits = cfg.input_bits;
|
| 499 |
+
let n_cols = cfg.n_columns;
|
| 500 |
+
let cells_per_col = 32usize;
|
| 501 |
+
|
| 502 |
+
let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271);
|
| 503 |
+
let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
|
| 504 |
+
let dev = sp.dev_ref().clone();
|
| 505 |
+
let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
|
| 506 |
+
let mut fused = FusedState::new(
|
| 507 |
+
dev.clone(),
|
| 508 |
+
n_cols,
|
| 509 |
+
cells_per_col,
|
| 510 |
+
sp.initial_threshold_estimate(),
|
| 511 |
+
).expect("fused init");
|
| 512 |
+
tm.reset().expect("tm reset");
|
| 513 |
+
fused.reset().expect("fused reset");
|
| 514 |
+
|
| 515 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
|
| 516 |
+
let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
|
| 517 |
+
let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
|
| 518 |
+
|
| 519 |
+
// Warmup SP threshold calibration with random SDRs first.
|
| 520 |
+
let warm = 300usize;
|
| 521 |
+
let mut warm_inputs = vec![0u8; warm * bits];
|
| 522 |
+
for ti in 0..warm {
|
| 523 |
+
let sdr = make_sdr(&mut rng, bits, 0.02);
|
| 524 |
+
warm_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
|
| 525 |
+
}
|
| 526 |
+
let warm_dev: CudaSlice<u8> = dev.htod_sync_copy(&warm_inputs).expect("htod warm");
|
| 527 |
+
let mut warm_cols = dev.alloc_zeros::<u8>(warm * n_cols).expect("alloc warm cols");
|
| 528 |
+
let mut warm_anom = dev.alloc_zeros::<f32>(warm).expect("alloc warm anom");
|
| 529 |
+
launch_fused(
|
| 530 |
+
&mut sp, &mut tm, &mut fused,
|
| 531 |
+
&warm_dev, &mut warm_cols, &mut warm_anom,
|
| 532 |
+
warm, bits, true,
|
| 533 |
+
).expect("warm launch");
|
| 534 |
+
dev.synchronize().expect("sync warm");
|
| 535 |
+
|
| 536 |
+
// Feed repeating A,B,C sequence for 100 reps.
|
| 537 |
+
let repeats = 100usize;
|
| 538 |
+
let t = repeats * 3;
|
| 539 |
+
let mut inputs = vec![0u8; t * bits];
|
| 540 |
+
for r in 0..repeats {
|
| 541 |
+
for (i, s) in seqs.iter().enumerate() {
|
| 542 |
+
let off = (r*3 + i) * bits;
|
| 543 |
+
inputs[off..off+bits].copy_from_slice(s);
|
| 544 |
+
}
|
| 545 |
+
}
|
| 546 |
+
let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod rep");
|
| 547 |
+
let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc rep cols");
|
| 548 |
+
let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc rep anom");
|
| 549 |
+
launch_fused(
|
| 550 |
+
&mut sp, &mut tm, &mut fused,
|
| 551 |
+
&inputs_dev, &mut cols_dev, &mut anom_dev,
|
| 552 |
+
t, bits, true,
|
| 553 |
+
).expect("rep launch");
|
| 554 |
+
dev.synchronize().expect("sync rep");
|
| 555 |
+
|
| 556 |
+
let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
|
| 557 |
+
let early_avg: f32 = anom[3..12].iter().sum::<f32>() / 9.0;
|
| 558 |
+
let late_avg: f32 = anom[(t-9)..t].iter().sum::<f32>() / 9.0;
|
| 559 |
+
eprintln!("fused TM anomaly: early={early_avg:.3} late={late_avg:.3}");
|
| 560 |
+
assert!(
|
| 561 |
+
late_avg < early_avg,
|
| 562 |
+
"anomaly must decay: early={early_avg:.3} late={late_avg:.3}"
|
| 563 |
+
);
|
| 564 |
+
assert!(
|
| 565 |
+
late_avg < 0.5,
|
| 566 |
+
"late anomaly must be < 0.5 (got {late_avg:.3})"
|
| 567 |
+
);
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
#[test]
|
| 571 |
+
fn gpu_sp_yields_k_winners() {
|
| 572 |
+
let cfg = SpatialPoolerConfig::default();
|
| 573 |
+
let bits = cfg.input_bits;
|
| 574 |
+
let n = cfg.n_columns;
|
| 575 |
+
let expected_k = ((cfg.sparsity * n as f32).round() as usize).max(1);
|
| 576 |
+
let cpu = SpatialPooler::new(SpatialPoolerConfig::default(), 7);
|
| 577 |
+
let mut gpu = SpatialPoolerGpu::from_cpu(&cpu).expect("gpu init");
|
| 578 |
+
|
| 579 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(1);
|
| 580 |
+
for _ in 0..10 {
|
| 581 |
+
let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
|
| 582 |
+
let active = gpu.compute(&sdr_u8, false).expect("gpu compute");
|
| 583 |
+
assert_eq!(active.len(), expected_k);
|
| 584 |
+
// Ensure sorted + unique.
|
| 585 |
+
for w in active.windows(2) {
|
| 586 |
+
assert!(w[0] < w[1], "duplicate or out-of-order winner indices");
|
| 587 |
+
}
|
| 588 |
+
}
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
#[test]
|
| 592 |
+
fn fused_launch_plan_uses_cooperative_grid_sync() {
|
| 593 |
+
let plan = plan_fused_launch(30, true, 30, None).expect("cooperative supported");
|
| 594 |
+
assert_eq!(plan.grid_dim_x, 16);
|
| 595 |
+
assert_eq!(plan.cooperative_grid_limit, 30);
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
#[test]
|
| 599 |
+
fn fused_launch_plan_scales_to_big_gpu() {
|
| 600 |
+
// H200-like: 132 SMs, high cooperative_grid_limit. Cap still applies.
|
| 601 |
+
let plan = plan_fused_launch(132, true, 1000, None).expect("cooperative supported");
|
| 602 |
+
assert_eq!(plan.grid_dim_x, 16); // capped by default override
|
| 603 |
+
let plan = plan_fused_launch(132, true, 1000, Some(64)).expect("cooperative supported");
|
| 604 |
+
assert_eq!(plan.grid_dim_x, 64); // override raises the cap
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
#[test]
|
| 608 |
+
fn fused_launch_plan_refuses_non_cooperative_devices() {
|
| 609 |
+
// The slow path was removed. Devices without cooperative launch fail fast.
|
| 610 |
+
let err = plan_fused_launch(30, false, 0, None).unwrap_err();
|
| 611 |
+
assert!(err.contains("cooperative launch"));
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
#[test]
|
| 615 |
+
fn fused_grid_cap_env_override_is_honored() {
|
| 616 |
+
let cfg = SpatialPoolerConfig::default();
|
| 617 |
+
let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 5252);
|
| 618 |
+
let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
|
| 619 |
+
let dev = sp.dev_ref().clone();
|
| 620 |
+
|
| 621 |
+
unsafe { std::env::set_var("HTM_FUSED_GRID_CAP", "12"); }
|
| 622 |
+
let fused = FusedState::new(
|
| 623 |
+
dev.clone(),
|
| 624 |
+
cfg.n_columns,
|
| 625 |
+
32usize,
|
| 626 |
+
sp.initial_threshold_estimate(),
|
| 627 |
+
).expect("fused init");
|
| 628 |
+
unsafe { std::env::remove_var("HTM_FUSED_GRID_CAP"); }
|
| 629 |
+
|
| 630 |
+
let sm_count = match dev.attribute(
|
| 631 |
+
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
|
| 632 |
+
) {
|
| 633 |
+
Ok(v) => v as u32,
|
| 634 |
+
Err(_) => 16u32,
|
| 635 |
+
};
|
| 636 |
+
let expected = sm_count.max(1).min(12);
|
| 637 |
+
assert_eq!(
|
| 638 |
+
fused.grid_dim_x,
|
| 639 |
+
expected,
|
| 640 |
+
"fused grid cap env override ignored: expected min(sm_count, 12) = {expected}, got {}",
|
| 641 |
+
fused.grid_dim_x,
|
| 642 |
+
);
|
| 643 |
+
}
|
overlay/htm_rust/src/gpu/tm_gpu.rs
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! GPU Temporal Memory.
|
| 2 |
+
//!
|
| 3 |
+
//! Flat device storage. Pre-allocated segment slab:
|
| 4 |
+
//! n_cells = n_columns * cells_per_column
|
| 5 |
+
//! n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL
|
| 6 |
+
//! n_synapses_max = n_segments_max * MAX_SYN_PER_SEGMENT
|
| 7 |
+
//!
|
| 8 |
+
//! Defaults (CPU parity targets relaxed on GPU to keep memory tractable):
|
| 9 |
+
//! MAX_SEGMENTS_PER_CELL = 16
|
| 10 |
+
//! MAX_SYN_PER_SEGMENT = 32
|
| 11 |
+
//!
|
| 12 |
+
//! At n_cells = 65536:
|
| 13 |
+
//! n_segments_max = 1_048_576 (~1M)
|
| 14 |
+
//! n_synapses_max = 33_554_432 (~33M)
|
| 15 |
+
//! Storage:
|
| 16 |
+
//! syn_presyn : u32 × 33M = 128 MB
|
| 17 |
+
//! syn_perm : i16 × 33M = 64 MB
|
| 18 |
+
//! seg_cell : u32 × 1M = 4 MB
|
| 19 |
+
//! seg_syn_n : u32 × 1M = 4 MB
|
| 20 |
+
//! misc bitsets etc ~ <1 MB
|
| 21 |
+
//! -------------------------------
|
| 22 |
+
//! Total per region ~200 MB
|
| 23 |
+
//!
|
| 24 |
+
//! Permanences are stored as i16 scaled by 32767 (→ [0, 32767] represents
|
| 25 |
+
//! [0.0, 1.0]). inc/dec are provided pre-scaled.
|
| 26 |
+
|
| 27 |
+
use std::sync::Arc;
|
| 28 |
+
|
| 29 |
+
use cudarc::driver::{CudaDevice, CudaSlice, DriverError, DeviceRepr, LaunchAsync, LaunchConfig};
|
| 30 |
+
use cudarc::nvrtc::Ptx;
|
| 31 |
+
|
| 32 |
+
/// Packed config struct passed by value to TM kernels to stay under
|
| 33 |
+
/// cudarc's 12-tuple launch limit. Layout must match the C-side
|
| 34 |
+
/// `TmConfig` struct declared in each kernel.
|
| 35 |
+
#[repr(C)]
|
| 36 |
+
#[derive(Clone, Copy)]
|
| 37 |
+
pub struct TmConfig {
|
| 38 |
+
pub activation_threshold: u32,
|
| 39 |
+
pub learning_threshold: u32,
|
| 40 |
+
pub cells_per_column: u32,
|
| 41 |
+
pub synapses_per_segment: u32,
|
| 42 |
+
pub n_segments: u32,
|
| 43 |
+
pub n_cells: u32,
|
| 44 |
+
pub max_segments_per_cell: u32,
|
| 45 |
+
pub max_new_synapses: u32,
|
| 46 |
+
pub conn_thr_i16: i32, // i16 widened to i32 for alignment
|
| 47 |
+
pub perm_inc_i16: i32,
|
| 48 |
+
pub perm_dec_i16: i32,
|
| 49 |
+
pub predicted_seg_dec_i16: i32,
|
| 50 |
+
pub initial_perm_i16: i32,
|
| 51 |
+
pub iter_seed: u32,
|
| 52 |
+
pub n_cols: u32,
|
| 53 |
+
pub bits_words: u32,
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
unsafe impl DeviceRepr for TmConfig {}
|
| 57 |
+
|
| 58 |
+
// Embedded PTX.
|
| 59 |
+
const PTX_TM_PREDICT: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_predict.ptx"));
|
| 60 |
+
const PTX_TM_ACTIVATE: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_activate.ptx"));
|
| 61 |
+
const PTX_TM_LEARN: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_learn.ptx"));
|
| 62 |
+
const PTX_TM_PUNISH: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_punish.ptx"));
|
| 63 |
+
const PTX_TM_GROW: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_grow.ptx"));
|
| 64 |
+
const PTX_TM_ANOMALY: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_anomaly.ptx"));
|
| 65 |
+
const PTX_TM_RESET: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_reset.ptx"));
|
| 66 |
+
|
| 67 |
+
/// Capacity trade-offs for 6 GB VRAM (RTX 3060) shared with the model:
|
| 68 |
+
/// n_cells = 2048 × 32 = 65_536
|
| 69 |
+
/// n_segments_max = n_cells × MAX_SEGMENTS_PER_CELL
|
| 70 |
+
/// n_synapses_max = n_segments_max × MAX_SYN_PER_SEGMENT
|
| 71 |
+
///
|
| 72 |
+
/// At 4/20 these are 262_144 segments and ~5.2M synapses (~50 MB per region).
|
| 73 |
+
/// The training loop runs with `reset_each_forward=True`, so segment counts
|
| 74 |
+
/// per window stay well below 32K (typical: ~n_cols new segs per step until
|
| 75 |
+
/// the first matching segment is reused; in a 2048-step window that plateaus
|
| 76 |
+
/// around ~5K total live segments). The 262K ceiling is generous headroom.
|
| 77 |
+
pub const MAX_SEGMENTS_PER_CELL: usize = 4;
|
| 78 |
+
pub const MAX_SYN_PER_SEGMENT: usize = 20;
|
| 79 |
+
|
| 80 |
+
const PERM_SCALE: f32 = 32767.0;
|
| 81 |
+
|
| 82 |
+
fn perm_f32_to_i16(x: f32) -> i16 {
|
| 83 |
+
let clamped = x.clamp(0.0, 1.0);
|
| 84 |
+
(clamped * PERM_SCALE).round() as i16
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
pub struct TemporalMemoryGpu {
|
| 88 |
+
dev: Arc<CudaDevice>,
|
| 89 |
+
|
| 90 |
+
// Config mirror
|
| 91 |
+
pub n_columns: usize,
|
| 92 |
+
pub cells_per_column: usize,
|
| 93 |
+
pub activation_threshold: u32,
|
| 94 |
+
pub learning_threshold: u32,
|
| 95 |
+
pub initial_perm_i16: i16,
|
| 96 |
+
pub conn_thr_i16: i16,
|
| 97 |
+
pub perm_inc_i16: i16,
|
| 98 |
+
pub perm_dec_i16: i16,
|
| 99 |
+
pub predicted_seg_dec_i16: i16,
|
| 100 |
+
pub max_new_synapse_count: u32,
|
| 101 |
+
|
| 102 |
+
// Sizes
|
| 103 |
+
pub n_cells: usize,
|
| 104 |
+
pub n_segments_max: usize,
|
| 105 |
+
pub bits_words: usize, // n_cells / 32
|
| 106 |
+
|
| 107 |
+
// Persistent device buffers
|
| 108 |
+
seg_cell_id: CudaSlice<u32>,
|
| 109 |
+
seg_syn_count: CudaSlice<u32>,
|
| 110 |
+
syn_presyn: CudaSlice<u32>,
|
| 111 |
+
syn_perm: CudaSlice<i16>,
|
| 112 |
+
cell_seg_count: CudaSlice<u32>,
|
| 113 |
+
|
| 114 |
+
cell_active_bits: CudaSlice<u32>,
|
| 115 |
+
cell_winner_bits: CudaSlice<u32>,
|
| 116 |
+
cell_predictive_bits: CudaSlice<u32>,
|
| 117 |
+
prev_active_bits: CudaSlice<u32>,
|
| 118 |
+
prev_winner_bits: CudaSlice<u32>,
|
| 119 |
+
|
| 120 |
+
col_predicted: CudaSlice<u8>,
|
| 121 |
+
seg_num_active_conn: CudaSlice<u32>,
|
| 122 |
+
seg_num_active_pot: CudaSlice<u32>,
|
| 123 |
+
unpredicted_count: CudaSlice<u32>,
|
| 124 |
+
burst_cols_flat: CudaSlice<u32>,
|
| 125 |
+
burst_cols_count: CudaSlice<u32>,
|
| 126 |
+
col_best_match: CudaSlice<u32>,
|
| 127 |
+
|
| 128 |
+
iter_counter: u32,
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
impl TemporalMemoryGpu {
|
| 132 |
+
pub fn new(
|
| 133 |
+
dev: Arc<CudaDevice>,
|
| 134 |
+
n_columns: usize,
|
| 135 |
+
cells_per_column: usize,
|
| 136 |
+
) -> Result<Self, DriverError> {
|
| 137 |
+
let n_cells = n_columns * cells_per_column;
|
| 138 |
+
assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
|
| 139 |
+
let n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL;
|
| 140 |
+
let bits_words = n_cells / 32;
|
| 141 |
+
|
| 142 |
+
// Numenta defaults.
|
| 143 |
+
let activation_threshold = 15u32;
|
| 144 |
+
let learning_threshold = 13u32;
|
| 145 |
+
let initial_perm_i16 = perm_f32_to_i16(0.21);
|
| 146 |
+
let conn_thr_i16 = perm_f32_to_i16(0.50);
|
| 147 |
+
let perm_inc_i16 = perm_f32_to_i16(0.10);
|
| 148 |
+
let perm_dec_i16 = perm_f32_to_i16(0.10);
|
| 149 |
+
let predicted_seg_dec_i16 = perm_f32_to_i16(0.10);
|
| 150 |
+
let max_new_synapse_count = 20u32;
|
| 151 |
+
|
| 152 |
+
// Allocate buffers.
|
| 153 |
+
let seg_cell_id_host: Vec<u32> = vec![u32::MAX; n_segments_max];
|
| 154 |
+
let seg_cell_id = dev.htod_sync_copy(&seg_cell_id_host)?;
|
| 155 |
+
let seg_syn_count = dev.alloc_zeros::<u32>(n_segments_max)?;
|
| 156 |
+
let syn_presyn = dev.alloc_zeros::<u32>(n_segments_max * MAX_SYN_PER_SEGMENT)?;
|
| 157 |
+
let syn_perm = dev.alloc_zeros::<i16>(n_segments_max * MAX_SYN_PER_SEGMENT)?;
|
| 158 |
+
let cell_seg_count = dev.alloc_zeros::<u32>(n_cells)?;
|
| 159 |
+
|
| 160 |
+
let cell_active_bits = dev.alloc_zeros::<u32>(bits_words)?;
|
| 161 |
+
let cell_winner_bits = dev.alloc_zeros::<u32>(bits_words)?;
|
| 162 |
+
let cell_predictive_bits = dev.alloc_zeros::<u32>(bits_words)?;
|
| 163 |
+
let prev_active_bits = dev.alloc_zeros::<u32>(bits_words)?;
|
| 164 |
+
let prev_winner_bits = dev.alloc_zeros::<u32>(bits_words)?;
|
| 165 |
+
|
| 166 |
+
let col_predicted = dev.alloc_zeros::<u8>(n_columns)?;
|
| 167 |
+
let seg_num_active_conn = dev.alloc_zeros::<u32>(n_segments_max)?;
|
| 168 |
+
let seg_num_active_pot = dev.alloc_zeros::<u32>(n_segments_max)?;
|
| 169 |
+
let unpredicted_count = dev.alloc_zeros::<u32>(1)?;
|
| 170 |
+
// Bursting columns for one step bounded by n_columns.
|
| 171 |
+
let burst_cols_flat = dev.alloc_zeros::<u32>(n_columns)?;
|
| 172 |
+
let burst_cols_count = dev.alloc_zeros::<u32>(1)?;
|
| 173 |
+
let col_best_match = dev.alloc_zeros::<u32>(n_columns)?;
|
| 174 |
+
|
| 175 |
+
// Load PTX modules.
|
| 176 |
+
let modules = [
|
| 177 |
+
("htm_tm_predict", PTX_TM_PREDICT, "tm_predict"),
|
| 178 |
+
("htm_tm_activate", PTX_TM_ACTIVATE, "tm_activate"),
|
| 179 |
+
("htm_tm_learn", PTX_TM_LEARN, "tm_learn_reinforce"),
|
| 180 |
+
("htm_tm_punish", PTX_TM_PUNISH, "tm_punish"),
|
| 181 |
+
("htm_tm_grow", PTX_TM_GROW, "tm_grow"),
|
| 182 |
+
("htm_tm_anomaly", PTX_TM_ANOMALY, "tm_anomaly"),
|
| 183 |
+
("htm_tm_reset", PTX_TM_RESET, "tm_reset_step"),
|
| 184 |
+
];
|
| 185 |
+
for (modname, ptx, fnname) in modules {
|
| 186 |
+
if dev.get_func(modname, fnname).is_none() {
|
| 187 |
+
dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
Ok(Self {
|
| 192 |
+
dev,
|
| 193 |
+
n_columns,
|
| 194 |
+
cells_per_column,
|
| 195 |
+
activation_threshold,
|
| 196 |
+
learning_threshold,
|
| 197 |
+
initial_perm_i16,
|
| 198 |
+
conn_thr_i16,
|
| 199 |
+
perm_inc_i16,
|
| 200 |
+
perm_dec_i16,
|
| 201 |
+
predicted_seg_dec_i16,
|
| 202 |
+
max_new_synapse_count,
|
| 203 |
+
n_cells,
|
| 204 |
+
n_segments_max,
|
| 205 |
+
bits_words,
|
| 206 |
+
seg_cell_id,
|
| 207 |
+
seg_syn_count,
|
| 208 |
+
syn_presyn,
|
| 209 |
+
syn_perm,
|
| 210 |
+
cell_seg_count,
|
| 211 |
+
cell_active_bits,
|
| 212 |
+
cell_winner_bits,
|
| 213 |
+
cell_predictive_bits,
|
| 214 |
+
prev_active_bits,
|
| 215 |
+
prev_winner_bits,
|
| 216 |
+
col_predicted,
|
| 217 |
+
seg_num_active_conn,
|
| 218 |
+
seg_num_active_pot,
|
| 219 |
+
unpredicted_count,
|
| 220 |
+
burst_cols_flat,
|
| 221 |
+
burst_cols_count,
|
| 222 |
+
col_best_match,
|
| 223 |
+
iter_counter: 0,
|
| 224 |
+
})
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
// --- Fused-path accessors ---
|
| 228 |
+
pub fn seg_cell_id_accessor(&self) -> &CudaSlice<u32> { &self.seg_cell_id }
|
| 229 |
+
pub fn seg_syn_count_accessor(&self) -> &CudaSlice<u32> { &self.seg_syn_count }
|
| 230 |
+
pub fn syn_presyn_accessor(&self) -> &CudaSlice<u32> { &self.syn_presyn }
|
| 231 |
+
pub fn syn_perm_accessor(&self) -> &CudaSlice<i16> { &self.syn_perm }
|
| 232 |
+
pub fn cell_seg_count_accessor(&self) -> &CudaSlice<u32> { &self.cell_seg_count }
|
| 233 |
+
|
| 234 |
+
/// Hard reset — clear everything (predictive + active + segments).
|
| 235 |
+
pub fn reset(&mut self) -> Result<(), DriverError> {
|
| 236 |
+
// Restore "unused" sentinel in seg_cell_id.
|
| 237 |
+
let unused_host: Vec<u32> = vec![u32::MAX; self.n_segments_max];
|
| 238 |
+
self.dev.htod_sync_copy_into(&unused_host, &mut self.seg_cell_id)?;
|
| 239 |
+
self.dev.memset_zeros(&mut self.seg_syn_count)?;
|
| 240 |
+
self.dev.memset_zeros(&mut self.cell_seg_count)?;
|
| 241 |
+
self.dev.memset_zeros(&mut self.cell_active_bits)?;
|
| 242 |
+
self.dev.memset_zeros(&mut self.cell_winner_bits)?;
|
| 243 |
+
self.dev.memset_zeros(&mut self.cell_predictive_bits)?;
|
| 244 |
+
self.dev.memset_zeros(&mut self.prev_active_bits)?;
|
| 245 |
+
self.dev.memset_zeros(&mut self.prev_winner_bits)?;
|
| 246 |
+
self.dev.memset_zeros(&mut self.col_best_match)?;
|
| 247 |
+
self.iter_counter = 0;
|
| 248 |
+
Ok(())
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
fn build_cfg(&self) -> TmConfig {
|
| 252 |
+
TmConfig {
|
| 253 |
+
activation_threshold: self.activation_threshold,
|
| 254 |
+
learning_threshold: self.learning_threshold,
|
| 255 |
+
cells_per_column: self.cells_per_column as u32,
|
| 256 |
+
synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
|
| 257 |
+
n_segments: self.n_segments_max as u32,
|
| 258 |
+
n_cells: self.n_cells as u32,
|
| 259 |
+
max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
|
| 260 |
+
max_new_synapses: self.max_new_synapse_count,
|
| 261 |
+
conn_thr_i16: self.conn_thr_i16 as i32,
|
| 262 |
+
perm_inc_i16: self.perm_inc_i16 as i32,
|
| 263 |
+
perm_dec_i16: self.perm_dec_i16 as i32,
|
| 264 |
+
predicted_seg_dec_i16: self.predicted_seg_dec_i16 as i32,
|
| 265 |
+
initial_perm_i16: self.initial_perm_i16 as i32,
|
| 266 |
+
iter_seed: self.iter_counter,
|
| 267 |
+
n_cols: self.n_columns as u32,
|
| 268 |
+
bits_words: self.bits_words as u32,
|
| 269 |
+
}
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
/// Run one TM step on the GPU. Takes the SP active-column mask (u8, already
|
| 273 |
+
/// on device) and writes `anomaly_out[t_slot]`.
|
| 274 |
+
pub fn step(
|
| 275 |
+
&mut self,
|
| 276 |
+
sp_active_mask: &CudaSlice<u8>,
|
| 277 |
+
anomaly_out: &mut CudaSlice<f32>,
|
| 278 |
+
t_slot: u32,
|
| 279 |
+
learn: bool,
|
| 280 |
+
) -> Result<(), DriverError> {
|
| 281 |
+
let n_cells = self.n_cells;
|
| 282 |
+
let n_cols = self.n_columns;
|
| 283 |
+
|
| 284 |
+
let predict_fn = self.dev.get_func("htm_tm_predict", "tm_predict").unwrap();
|
| 285 |
+
let activate_fn = self.dev.get_func("htm_tm_activate", "tm_activate").unwrap();
|
| 286 |
+
let learn_fn = self.dev.get_func("htm_tm_learn", "tm_learn_reinforce").unwrap();
|
| 287 |
+
let punish_fn = self.dev.get_func("htm_tm_punish", "tm_punish").unwrap();
|
| 288 |
+
let grow_fn = self.dev.get_func("htm_tm_grow", "tm_grow").unwrap();
|
| 289 |
+
let anom_fn = self.dev.get_func("htm_tm_anomaly", "tm_anomaly").unwrap();
|
| 290 |
+
let reset_fn = self.dev.get_func("htm_tm_reset", "tm_reset_step").unwrap();
|
| 291 |
+
|
| 292 |
+
self.iter_counter = self.iter_counter.wrapping_add(1);
|
| 293 |
+
let cfg_val = self.build_cfg();
|
| 294 |
+
|
| 295 |
+
// 0. Per-step reset.
|
| 296 |
+
let reset_words = self.bits_words.max(n_cols);
|
| 297 |
+
let reset_cfg = LaunchConfig {
|
| 298 |
+
grid_dim: (((reset_words + 255) / 256) as u32, 1, 1),
|
| 299 |
+
block_dim: (256, 1, 1),
|
| 300 |
+
shared_mem_bytes: 0,
|
| 301 |
+
};
|
| 302 |
+
unsafe {
|
| 303 |
+
reset_fn.clone().launch(
|
| 304 |
+
reset_cfg,
|
| 305 |
+
(
|
| 306 |
+
&mut self.cell_active_bits,
|
| 307 |
+
&mut self.cell_winner_bits,
|
| 308 |
+
&mut self.cell_predictive_bits,
|
| 309 |
+
&mut self.prev_active_bits,
|
| 310 |
+
&mut self.prev_winner_bits,
|
| 311 |
+
&mut self.col_predicted,
|
| 312 |
+
&mut self.unpredicted_count,
|
| 313 |
+
&mut self.burst_cols_count,
|
| 314 |
+
&mut self.col_best_match,
|
| 315 |
+
self.bits_words as u32,
|
| 316 |
+
n_cols as u32,
|
| 317 |
+
),
|
| 318 |
+
)?;
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
// 1. Predict (grid = n_cells; each block iterates its cell's segments).
|
| 322 |
+
let predict_cfg = LaunchConfig {
|
| 323 |
+
grid_dim: (n_cells as u32, 1, 1),
|
| 324 |
+
block_dim: (32, 1, 1),
|
| 325 |
+
shared_mem_bytes: 0,
|
| 326 |
+
};
|
| 327 |
+
unsafe {
|
| 328 |
+
predict_fn.clone().launch(
|
| 329 |
+
predict_cfg,
|
| 330 |
+
(
|
| 331 |
+
&self.seg_cell_id,
|
| 332 |
+
&self.seg_syn_count,
|
| 333 |
+
&self.syn_presyn,
|
| 334 |
+
&self.syn_perm,
|
| 335 |
+
&self.prev_active_bits,
|
| 336 |
+
&mut self.cell_predictive_bits,
|
| 337 |
+
&mut self.col_predicted,
|
| 338 |
+
&mut self.seg_num_active_conn,
|
| 339 |
+
&mut self.seg_num_active_pot,
|
| 340 |
+
&mut self.col_best_match,
|
| 341 |
+
&self.cell_seg_count,
|
| 342 |
+
cfg_val,
|
| 343 |
+
),
|
| 344 |
+
)?;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
// 2. Activate.
|
| 348 |
+
let activate_cfg = LaunchConfig {
|
| 349 |
+
grid_dim: (((n_cols + 255) / 256) as u32, 1, 1),
|
| 350 |
+
block_dim: (256, 1, 1),
|
| 351 |
+
shared_mem_bytes: 0,
|
| 352 |
+
};
|
| 353 |
+
unsafe {
|
| 354 |
+
activate_fn.clone().launch(
|
| 355 |
+
activate_cfg,
|
| 356 |
+
(
|
| 357 |
+
sp_active_mask,
|
| 358 |
+
&self.col_predicted,
|
| 359 |
+
&self.cell_predictive_bits,
|
| 360 |
+
&mut self.cell_active_bits,
|
| 361 |
+
&mut self.cell_winner_bits,
|
| 362 |
+
&mut self.unpredicted_count,
|
| 363 |
+
&mut self.burst_cols_flat,
|
| 364 |
+
&mut self.burst_cols_count,
|
| 365 |
+
cfg_val,
|
| 366 |
+
),
|
| 367 |
+
)?;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
// 3. Anomaly.
|
| 371 |
+
let anom_cfg = LaunchConfig {
|
| 372 |
+
grid_dim: (1, 1, 1),
|
| 373 |
+
block_dim: (256, 1, 1),
|
| 374 |
+
shared_mem_bytes: 0,
|
| 375 |
+
};
|
| 376 |
+
unsafe {
|
| 377 |
+
anom_fn.clone().launch(
|
| 378 |
+
anom_cfg,
|
| 379 |
+
(
|
| 380 |
+
sp_active_mask,
|
| 381 |
+
&self.unpredicted_count,
|
| 382 |
+
anomaly_out,
|
| 383 |
+
t_slot,
|
| 384 |
+
n_cols as u32,
|
| 385 |
+
),
|
| 386 |
+
)?;
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
if learn {
|
| 390 |
+
// 4. Reinforce (grid = n_cells).
|
| 391 |
+
let learn_cfg = LaunchConfig {
|
| 392 |
+
grid_dim: (n_cells as u32, 1, 1),
|
| 393 |
+
block_dim: (32, 1, 1),
|
| 394 |
+
shared_mem_bytes: 0,
|
| 395 |
+
};
|
| 396 |
+
unsafe {
|
| 397 |
+
learn_fn.clone().launch(
|
| 398 |
+
learn_cfg,
|
| 399 |
+
(
|
| 400 |
+
&self.seg_cell_id,
|
| 401 |
+
&self.seg_syn_count,
|
| 402 |
+
&self.syn_presyn,
|
| 403 |
+
&mut self.syn_perm,
|
| 404 |
+
&self.seg_num_active_conn,
|
| 405 |
+
&self.prev_active_bits,
|
| 406 |
+
sp_active_mask,
|
| 407 |
+
&self.col_predicted,
|
| 408 |
+
&self.cell_seg_count,
|
| 409 |
+
cfg_val,
|
| 410 |
+
),
|
| 411 |
+
)?;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
// 5. Punish.
|
| 415 |
+
unsafe {
|
| 416 |
+
punish_fn.clone().launch(
|
| 417 |
+
learn_cfg,
|
| 418 |
+
(
|
| 419 |
+
&self.seg_cell_id,
|
| 420 |
+
&self.seg_syn_count,
|
| 421 |
+
&self.syn_presyn,
|
| 422 |
+
&mut self.syn_perm,
|
| 423 |
+
&self.seg_num_active_pot,
|
| 424 |
+
&self.prev_active_bits,
|
| 425 |
+
sp_active_mask,
|
| 426 |
+
&self.cell_seg_count,
|
| 427 |
+
cfg_val,
|
| 428 |
+
),
|
| 429 |
+
)?;
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
// 6. Grow.
|
| 433 |
+
let grow_cfg = LaunchConfig {
|
| 434 |
+
grid_dim: (n_cols as u32, 1, 1),
|
| 435 |
+
block_dim: (32, 1, 1),
|
| 436 |
+
shared_mem_bytes: 0,
|
| 437 |
+
};
|
| 438 |
+
unsafe {
|
| 439 |
+
grow_fn.clone().launch(
|
| 440 |
+
grow_cfg,
|
| 441 |
+
(
|
| 442 |
+
&mut self.seg_cell_id,
|
| 443 |
+
&mut self.seg_syn_count,
|
| 444 |
+
&mut self.syn_presyn,
|
| 445 |
+
&mut self.syn_perm,
|
| 446 |
+
&mut self.cell_seg_count,
|
| 447 |
+
&self.burst_cols_flat,
|
| 448 |
+
&self.burst_cols_count,
|
| 449 |
+
&self.prev_winner_bits,
|
| 450 |
+
&self.prev_active_bits,
|
| 451 |
+
&self.col_best_match,
|
| 452 |
+
cfg_val,
|
| 453 |
+
),
|
| 454 |
+
)?;
|
| 455 |
+
}
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
Ok(())
|
| 459 |
+
}
|
| 460 |
+
}
|
overlay/hydra/eval.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation: factual probes + sampled factual English scoring.
|
| 2 |
+
|
| 3 |
+
Extracted from train.py (W1 modularization). Semantics unchanged.
|
| 4 |
+
|
| 5 |
+
Perf optimizations (eval_perf_fix):
|
| 6 |
+
- Probe mode: single forward per prompt instead of autoregressive gen
|
| 7 |
+
- Batch decode: all GPU work first, all CPU decode after
|
| 8 |
+
- Batched factual probes: single padded forward instead of N sequential
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import re as _re
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from hydra.config import FACTUAL_SAMPLES, FACTUAL_BATCH, FACTUAL_GEN_TOKENS
|
| 19 |
+
|
| 20 |
+
# Default to probe mode (1 forward per prompt); set HYDRA_FACTUAL_MODE=gen for
|
| 21 |
+
# the original autoregressive generation path.
|
| 22 |
+
FACTUAL_MODE = os.environ.get("HYDRA_FACTUAL_MODE", "probe")
|
| 23 |
+
|
| 24 |
+
FACTUAL_EVAL = [
|
| 25 |
+
# Hard factual recall — requires specific knowledge memorization
|
| 26 |
+
("The capital of France is", ["Paris", "paris"]),
|
| 27 |
+
("Water boils at", ["100", "boiling"]),
|
| 28 |
+
("The largest planet in our solar system is", ["Jupiter", "jupiter"]),
|
| 29 |
+
# Easier completions — common collocations / patterns the model may pick up
|
| 30 |
+
("Once upon a", ["time"]),
|
| 31 |
+
("Hello, my name", ["is", "'s"]),
|
| 32 |
+
("The cat sat on the", ["mat", "floor", "rug", "table", "couch", "chair", "ground"]),
|
| 33 |
+
("She opened the door and", ["walked", "saw", "found", "stepped", "looked", "went", "ran"]),
|
| 34 |
+
# Original hard ones kept for completeness
|
| 35 |
+
("The speed of light is approximately", ["299", "300", "186,000", "light speed"]),
|
| 36 |
+
("Two plus two equals", ["4", "four"]),
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
_FACTUAL_PROBES = [
|
| 40 |
+
"The capital of France is",
|
| 41 |
+
"Water boils at",
|
| 42 |
+
"The largest planet in our solar system is",
|
| 43 |
+
"The speed of light is approximately",
|
| 44 |
+
"Shakespeare wrote",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def run_factual_probes(model, tokenizer, device, autocast_ctx) -> None:
|
| 49 |
+
"""Top-5 next-token predictions for canonical factual prompts.
|
| 50 |
+
|
| 51 |
+
Batched: pads all prompts into a single forward pass instead of N
|
| 52 |
+
sequential passes.
|
| 53 |
+
"""
|
| 54 |
+
print("\n--- Factual Probes ---")
|
| 55 |
+
model.eval()
|
| 56 |
+
|
| 57 |
+
# Process probes one at a time to avoid cooperative launch limit
|
| 58 |
+
# (batched forward with B=len(probes) can exceed SM residency cap).
|
| 59 |
+
for prompt_text in _FACTUAL_PROBES:
|
| 60 |
+
ids = tokenizer.encode(prompt_text)
|
| 61 |
+
x = torch.tensor([ids], device=device)
|
| 62 |
+
with torch.no_grad(), autocast_ctx:
|
| 63 |
+
logits = model(x)
|
| 64 |
+
probs = torch.softmax(logits[0, -1].float(), dim=-1)
|
| 65 |
+
top5 = torch.topk(probs, 5)
|
| 66 |
+
completions = [tokenizer.decode([idx.item()]) for idx in top5.indices]
|
| 67 |
+
probs_list = [f"{p:.4f}" for p in top5.values[:3].tolist()]
|
| 68 |
+
print(f' "{prompt_text}" -> {completions[:3]} (p={probs_list})')
|
| 69 |
+
print("--- End Factual Probes ---\n")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
# Probe mode: single forward per prompt (Fix D)
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
|
| 76 |
+
def _run_factual_english_probe(model, tokenizer, max_seq_len: int):
|
| 77 |
+
"""Fast probe mode: for each (prompt, answers), encode prompt + each answer
|
| 78 |
+
candidate as a single sequence, do ONE forward pass, and check if the model's
|
| 79 |
+
argmax at the last prompt token matches the first answer token.
|
| 80 |
+
|
| 81 |
+
Falls back to checking top-K predictions to be generous (same as gen mode
|
| 82 |
+
which samples multiple temperatures).
|
| 83 |
+
"""
|
| 84 |
+
print("---")
|
| 85 |
+
print("factual_english_samples: (probe mode)")
|
| 86 |
+
model.eval()
|
| 87 |
+
hits = 0
|
| 88 |
+
|
| 89 |
+
with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 90 |
+
for prompt, answers in FACTUAL_EVAL:
|
| 91 |
+
prompt_ids = tokenizer.encode(prompt)
|
| 92 |
+
prompt_len = len(prompt_ids)
|
| 93 |
+
x = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
|
| 94 |
+
logits = model(x, targets=None)
|
| 95 |
+
# logits shape: [1, seq_len, vocab] or [1, vocab]
|
| 96 |
+
if logits.dim() == 3:
|
| 97 |
+
last_logits = logits[0, -1, :]
|
| 98 |
+
else:
|
| 99 |
+
last_logits = logits[0]
|
| 100 |
+
|
| 101 |
+
probs = torch.softmax(last_logits.float(), dim=-1)
|
| 102 |
+
# Check top-K predictions (generous: K=20 to match multi-sample gen)
|
| 103 |
+
top_k = min(20, probs.shape[-1])
|
| 104 |
+
top_ids = torch.topk(probs, top_k).indices.tolist()
|
| 105 |
+
top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]
|
| 106 |
+
|
| 107 |
+
answers_lower = [a.lower() for a in answers]
|
| 108 |
+
any_hit = any(
|
| 109 |
+
any(a in tok for a in answers_lower)
|
| 110 |
+
for tok in top_tokens
|
| 111 |
+
)
|
| 112 |
+
if any_hit:
|
| 113 |
+
hits += 1
|
| 114 |
+
|
| 115 |
+
best_completion = tokenizer.decode([top_ids[0]])
|
| 116 |
+
print(f" prompt: {prompt!r}")
|
| 117 |
+
print(f" output: {(prompt + best_completion).replace(chr(10), ' ')!r}")
|
| 118 |
+
print(f" hit: {any_hit} (probe top-{top_k})")
|
| 119 |
+
|
| 120 |
+
score = hits / len(FACTUAL_EVAL)
|
| 121 |
+
print("---")
|
| 122 |
+
print(f"factual_english_score: {score:.4f}")
|
| 123 |
+
print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
|
| 124 |
+
return score, hits, len(FACTUAL_EVAL)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ---------------------------------------------------------------------------
|
| 128 |
+
# Gen mode: original autoregressive path (Fix F: batch decode)
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
|
| 132 |
+
"""Original autoregressive generation path with batch decode optimization:
|
| 133 |
+
all GPU work runs first, then all CPU decoding happens after."""
|
| 134 |
+
print("---")
|
| 135 |
+
print("factual_english_samples: (gen mode)")
|
| 136 |
+
model.eval()
|
| 137 |
+
|
| 138 |
+
num_samples = FACTUAL_SAMPLES
|
| 139 |
+
batch = FACTUAL_BATCH
|
| 140 |
+
gen_tokens = FACTUAL_GEN_TOKENS
|
| 141 |
+
temps = [0.7, 0.9, 1.1]
|
| 142 |
+
hits = 0
|
| 143 |
+
|
| 144 |
+
with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 145 |
+
for prompt, answers in FACTUAL_EVAL:
|
| 146 |
+
ids = tokenizer.encode(prompt)
|
| 147 |
+
answers_lower = [a.lower() for a in answers]
|
| 148 |
+
# Collect all generated token sequences on GPU first
|
| 149 |
+
all_rows: list[list[int]] = []
|
| 150 |
+
samples_done = 0
|
| 151 |
+
batch_idx = 0
|
| 152 |
+
while samples_done < num_samples:
|
| 153 |
+
b = min(batch, num_samples - samples_done)
|
| 154 |
+
temp = temps[batch_idx % len(temps)]
|
| 155 |
+
batch_idx += 1
|
| 156 |
+
ctx = torch.tensor([ids] * b, device="cuda", dtype=torch.long)
|
| 157 |
+
for _ in range(gen_tokens):
|
| 158 |
+
logits = model(ctx, targets=None)
|
| 159 |
+
next_logits = logits[:, -1, :] if logits.dim() == 3 else logits
|
| 160 |
+
probs = torch.softmax(next_logits.float() / temp, dim=-1)
|
| 161 |
+
next_id = torch.multinomial(probs, num_samples=1)
|
| 162 |
+
ctx = torch.cat([ctx, next_id], dim=1)
|
| 163 |
+
if ctx.size(1) >= max_seq_len:
|
| 164 |
+
break
|
| 165 |
+
# Transfer to CPU in one shot, no per-row sync
|
| 166 |
+
all_rows.extend(ctx.cpu().tolist())
|
| 167 |
+
samples_done += b
|
| 168 |
+
|
| 169 |
+
# CPU-side batch decode — no GPU sync between decodes
|
| 170 |
+
any_hit = False
|
| 171 |
+
first_gen = None
|
| 172 |
+
hit_gen = None
|
| 173 |
+
for row in all_rows:
|
| 174 |
+
generated = tokenizer.decode(row)
|
| 175 |
+
continuation = generated[len(prompt):].strip()
|
| 176 |
+
_words = set(w.lower() for w in _re.findall(r"\b[\w'-]+\b", continuation))
|
| 177 |
+
hit = any(a in _words for a in answers_lower)
|
| 178 |
+
if first_gen is None:
|
| 179 |
+
first_gen = generated
|
| 180 |
+
if hit:
|
| 181 |
+
any_hit = True
|
| 182 |
+
if hit_gen is None:
|
| 183 |
+
hit_gen = generated
|
| 184 |
+
if any_hit:
|
| 185 |
+
hits += 1
|
| 186 |
+
print(f" prompt: {prompt!r}")
|
| 187 |
+
print(f" output: {(first_gen or '').replace(chr(10), ' ')!r}")
|
| 188 |
+
print(f" hit: {any_hit} (any of {num_samples} samples, temps={temps}, gen={gen_tokens}tok)")
|
| 189 |
+
if hit_gen is not None and hit_gen != first_gen:
|
| 190 |
+
print(f" hit_sample: {hit_gen.replace(chr(10), ' ')!r}")
|
| 191 |
+
|
| 192 |
+
score = hits / len(FACTUAL_EVAL)
|
| 193 |
+
print("---")
|
| 194 |
+
print(f"factual_english_score: {score:.4f}")
|
| 195 |
+
print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
|
| 196 |
+
return score, hits, len(FACTUAL_EVAL)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ---------------------------------------------------------------------------
|
| 200 |
+
# Public entry point
|
| 201 |
+
# ---------------------------------------------------------------------------
|
| 202 |
+
|
| 203 |
+
def run_factual_english(model, tokenizer, max_seq_len: int):
|
| 204 |
+
"""Dispatch to probe (fast, default) or gen (original) mode.
|
| 205 |
+
|
| 206 |
+
Set HYDRA_FACTUAL_MODE=gen to use the autoregressive path.
|
| 207 |
+
"""
|
| 208 |
+
if FACTUAL_MODE == "gen":
|
| 209 |
+
return _run_factual_english_gen(model, tokenizer, max_seq_len)
|
| 210 |
+
return _run_factual_english_probe(model, tokenizer, max_seq_len)
|
overlay/hydra/model.py
ADDED
|
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PostSemClawModel — full-architecture model assembly.
|
| 2 |
+
|
| 3 |
+
Extracted from the monolithic train.py (W1 modularization). Semantics
|
| 4 |
+
unchanged. Imports `GPUEngram` from `hydra.engram` and `MuonAdamW` from
|
| 5 |
+
`hydra.optimizer`.
|
| 6 |
+
|
| 7 |
+
Triton kernel integration status (Phase 2):
|
| 8 |
+
HYDRA_FUSED_BCNORM — DEFERRED. The bcnorm_fused Triton kernel fuses
|
| 9 |
+
LayerNorm + RoPE on B/C projections. However, mamba-ssm's Mamba3 block
|
| 10 |
+
uses RMSNormGated (not LayerNorm) for B/C, and RoPE is applied inside
|
| 11 |
+
the mamba3_siso_combined CUDA kernel via the Angles parameter. Replacing
|
| 12 |
+
would require either (a) monkey-patching RMSNormGated + intercepting the
|
| 13 |
+
fused CUDA scan — invasive, 50+ lines, high breakage risk — or (b) a
|
| 14 |
+
full custom Mamba3Block reimplementation. Both are out of scope for
|
| 15 |
+
Phase 2. The kernel is validated standalone; integration deferred to
|
| 16 |
+
Phase 3 when HYDRA moves to a custom SSM block.
|
| 17 |
+
|
| 18 |
+
HYDRA_FUSED_SSD — DEFERRED. The ssd_exp_trap Triton kernel implements
|
| 19 |
+
exponential-trapezoidal discretization as a sequential scan. mamba-ssm's
|
| 20 |
+
Mamba3 block delegates the entire scan + gating + output projection to
|
| 21 |
+
mamba3_siso_combined (a compiled CUDA kernel with tilelang). Replacing
|
| 22 |
+
it would require decomposing the combined kernel into constituent ops
|
| 23 |
+
and substituting only the scan — not feasible without a custom block.
|
| 24 |
+
Same Phase 3 gate as above.
|
| 25 |
+
|
| 26 |
+
Both env vars are accepted but currently no-ops (gates read, logged, but
|
| 27 |
+
the code path is unchanged). This avoids silent regression if someone
|
| 28 |
+
sets them expecting a speedup.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
from __future__ import annotations
|
| 32 |
+
|
| 33 |
+
import os
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
import torch.nn.functional as F
|
| 38 |
+
|
| 39 |
+
from mamba_ssm import Mamba3
|
| 40 |
+
|
| 41 |
+
from subsystems.hestia_mini import HestiaQAT
|
| 42 |
+
from subsystems.htm import HTMLayer
|
| 43 |
+
from subsystems.mhc_mini import ManifoldHyperConnection
|
| 44 |
+
from subsystems.sdr_semantic import SemanticFoldingSDR
|
| 45 |
+
|
| 46 |
+
from hydra.engram import GPUEngram
|
| 47 |
+
from hydra.optimizer import MuonAdamW
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def norm(x: torch.Tensor) -> torch.Tensor:
|
| 51 |
+
"""RMSNorm over the last dim — stateless, autocast-friendly."""
|
| 52 |
+
return F.rms_norm(x, (x.size(-1),))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class PostSemClawModel(nn.Module):
|
| 56 |
+
"""Full Post-SEM-Claw model assembly.
|
| 57 |
+
|
| 58 |
+
Architecture:
|
| 59 |
+
Token Embedding -> [Mamba3 + residual] x n_layer
|
| 60 |
+
-> SDR + Engram (at configured layer) -> norm -> LM head
|
| 61 |
+
|
| 62 |
+
Interface (must match prepare.py evaluate_bpb):
|
| 63 |
+
model(x, y, reduction='none').view(-1) -> per-token losses
|
| 64 |
+
model(x, y, reduction='mean') -> scalar loss
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(self, config):
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.config = config
|
| 70 |
+
|
| 71 |
+
# Token embedding
|
| 72 |
+
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
| 73 |
+
|
| 74 |
+
# Mamba-3 blocks — official mamba-ssm fused CUDA kernel. No fallbacks.
|
| 75 |
+
# RoPE is applied internally by the Mamba3 CUDA kernel via the Angles
|
| 76 |
+
# parameter; external cos/sin buffers are not needed.
|
| 77 |
+
self.blocks = nn.ModuleList([
|
| 78 |
+
Mamba3(
|
| 79 |
+
d_model=config.d_model,
|
| 80 |
+
d_state=config.d_state,
|
| 81 |
+
expand=config.expand,
|
| 82 |
+
headdim=config.headdim,
|
| 83 |
+
is_mimo=False, # SISO path uses stable mamba3_siso_combined kernel
|
| 84 |
+
chunk_size=64, # upstream-recommended SISO chunk; 16 violated tl.dot M>=16 constraint
|
| 85 |
+
is_outproj_norm=False,
|
| 86 |
+
dtype=torch.bfloat16,
|
| 87 |
+
)
|
| 88 |
+
for _ in range(config.n_layer)
|
| 89 |
+
])
|
| 90 |
+
|
| 91 |
+
# Full-architecture SDR: offline semantic retina + STE (no-bypass).
|
| 92 |
+
self.sdr_semantic = SemanticFoldingSDR(
|
| 93 |
+
vocab_size=config.vocab_size,
|
| 94 |
+
n_bits=config.sdr_n_bits,
|
| 95 |
+
target_active=config.sdr_target_active,
|
| 96 |
+
delta_rank=config.sdr_delta_rank,
|
| 97 |
+
som_warmup_steps=config.sdr_som_warmup,
|
| 98 |
+
som_update_interval=config.sdr_som_interval,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# HTM spatial pooler + temporal memory (Rust, Hebbian).
|
| 102 |
+
self.htm = HTMLayer(
|
| 103 |
+
input_bits=config.sdr_n_bits,
|
| 104 |
+
n_columns=config.htm_n_columns,
|
| 105 |
+
cells_per_column=config.htm_cells_per_column,
|
| 106 |
+
batch_size=1, # grows lazily to actual B on first forward
|
| 107 |
+
seed=42,
|
| 108 |
+
learn=True,
|
| 109 |
+
reset_each_forward=True,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Gradient bridge: (n_columns + anomaly) -> d_model.
|
| 113 |
+
self.htm_proj = nn.Linear(config.htm_n_columns + 1, config.d_model, bias=False)
|
| 114 |
+
|
| 115 |
+
# GPU Engram with Hebbian writes — runs EVERY step.
|
| 116 |
+
self.engram = GPUEngram(
|
| 117 |
+
d_model=config.d_model,
|
| 118 |
+
n_columns=config.engram_n_columns,
|
| 119 |
+
max_ngram=3,
|
| 120 |
+
)
|
| 121 |
+
self.engram_layer_idx = config.engram_layer_idx
|
| 122 |
+
|
| 123 |
+
# Manifold-Constrained Hyper-Connections (one per Mamba-3 block).
|
| 124 |
+
self.mhc = nn.ModuleList([
|
| 125 |
+
ManifoldHyperConnection(d_model=config.d_model, n_streams=2, sinkhorn_iters=3)
|
| 126 |
+
for _ in range(config.n_layer)
|
| 127 |
+
])
|
| 128 |
+
|
| 129 |
+
# Hestia QAT — ternary weight quantization applied post-optimizer-step.
|
| 130 |
+
self.hestia = HestiaQAT(enabled=True, bits=1.58)
|
| 131 |
+
|
| 132 |
+
# LM head
|
| 133 |
+
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
| 134 |
+
|
| 135 |
+
# Residual dropout
|
| 136 |
+
self.drop = nn.Dropout(float(os.environ.get("HYDRA_DROPOUT", "0.2")))
|
| 137 |
+
|
| 138 |
+
# Logits soft-capping
|
| 139 |
+
self.softcap = 15.0
|
| 140 |
+
|
| 141 |
+
# Secondary metrics storage
|
| 142 |
+
self._metrics = {}
|
| 143 |
+
|
| 144 |
+
# Per-layer diagnostic panel. Env-gated; zero overhead when off.
|
| 145 |
+
# Emits residual-contribution (delta_ratio), feature std, effective rank,
|
| 146 |
+
# gradient norm per layer; used to identify minimum viable n_layer + find
|
| 147 |
+
# entropy leakage / dead layers. See docs/depth-sweep.md.
|
| 148 |
+
self._diag_enabled = os.environ.get("HYDRA_LAYER_DIAGNOSTICS", "0") == "1"
|
| 149 |
+
self._diag_step = 0
|
| 150 |
+
self._diag_svd_every = int(os.environ.get("HYDRA_LAYER_DIAG_SVD_EVERY", "100"))
|
| 151 |
+
if self._diag_enabled:
|
| 152 |
+
# Gradient-norm backward hooks on each Mamba3 block output.
|
| 153 |
+
for _i, _block in enumerate(self.blocks):
|
| 154 |
+
def _mk_grad_hook(_layer_idx):
|
| 155 |
+
def _hook(module, grad_input, grad_output):
|
| 156 |
+
if grad_output and grad_output[0] is not None:
|
| 157 |
+
g = grad_output[0].detach()
|
| 158 |
+
self._metrics[f'layer_{_layer_idx}_grad_norm'] = float(
|
| 159 |
+
g.pow(2).mean().sqrt().item()
|
| 160 |
+
)
|
| 161 |
+
return _hook
|
| 162 |
+
_block.register_full_backward_hook(_mk_grad_hook(_i))
|
| 163 |
+
|
| 164 |
+
# Forward hooks on each Mamba3 block capture the block's OUTPUT
|
| 165 |
+
# directly. This is the clean measurement: unlike merge_streams()
|
| 166 |
+
# sampling which sees (streams + M*block_output) in bf16 — where
|
| 167 |
+
# small block contributions round to zero against unit-norm
|
| 168 |
+
# residuals — this captures `block_output` itself as produced.
|
| 169 |
+
# Reports both its absolute RMS norm and its ratio to the block
|
| 170 |
+
# INPUT's RMS norm (contribution magnitude relative to the
|
| 171 |
+
# residual it's added to).
|
| 172 |
+
for _i, _block in enumerate(self.blocks):
|
| 173 |
+
def _mk_fwd_hook(_layer_idx):
|
| 174 |
+
def _hook(module, inputs, output):
|
| 175 |
+
with torch.no_grad():
|
| 176 |
+
inp = inputs[0].detach().float() if inputs else None
|
| 177 |
+
out = output.detach().float() if isinstance(output, torch.Tensor) else None
|
| 178 |
+
if out is not None:
|
| 179 |
+
out_rms = out.pow(2).mean().sqrt().item()
|
| 180 |
+
self._metrics[f'layer_{_layer_idx}_block_out_rms'] = float(out_rms)
|
| 181 |
+
if inp is not None:
|
| 182 |
+
in_rms = inp.pow(2).mean().sqrt().item()
|
| 183 |
+
self._metrics[f'layer_{_layer_idx}_block_in_rms'] = float(in_rms)
|
| 184 |
+
self._metrics[f'layer_{_layer_idx}_contrib_ratio'] = float(
|
| 185 |
+
out_rms / (in_rms + 1e-8)
|
| 186 |
+
)
|
| 187 |
+
return _hook
|
| 188 |
+
_block.register_forward_hook(_mk_fwd_hook(_i))
|
| 189 |
+
|
| 190 |
+
# Triton kernel integration gates (Phase 2 — deferred, see module docstring).
|
| 191 |
+
self._fused_bcnorm = os.environ.get("HYDRA_FUSED_BCNORM", "0") == "1"
|
| 192 |
+
self._fused_ssd = os.environ.get("HYDRA_FUSED_SSD", "0") == "1"
|
| 193 |
+
if self._fused_bcnorm or self._fused_ssd:
|
| 194 |
+
import sys
|
| 195 |
+
_active = []
|
| 196 |
+
if self._fused_bcnorm:
|
| 197 |
+
_active.append("HYDRA_FUSED_BCNORM")
|
| 198 |
+
if self._fused_ssd:
|
| 199 |
+
_active.append("HYDRA_FUSED_SSD")
|
| 200 |
+
print(
|
| 201 |
+
f"[HYDRA] Triton kernel gates set: {', '.join(_active)}. "
|
| 202 |
+
f"NOTE: Both are DEFERRED (mamba-ssm Mamba3 uses internal "
|
| 203 |
+
f"CUDA kernels). Gates accepted but currently no-ops.",
|
| 204 |
+
file=sys.stderr,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# R6 optional torch.compile on the impl forward. Gated (default OFF).
|
| 208 |
+
if os.environ.get("HYDRA_MODEL_COMPILE", "0") == "1":
|
| 209 |
+
self._forward_impl = torch.compile(
|
| 210 |
+
self._forward_impl,
|
| 211 |
+
fullgraph=False,
|
| 212 |
+
dynamic=True,
|
| 213 |
+
mode="default",
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
@torch.no_grad()
|
| 217 |
+
def init_weights(self) -> None:
|
| 218 |
+
s = 3 ** 0.5 * self.config.d_model ** -0.5
|
| 219 |
+
|
| 220 |
+
# Move SDR retina indices (plain attribute, not buffer) to same device as params.
|
| 221 |
+
# Required because to_empty() only moves params/buffers, and _retina_indices
|
| 222 |
+
# is loaded from numpy (always CPU) by SemanticFoldingSDR.__init__.
|
| 223 |
+
device = self.wte.weight.device
|
| 224 |
+
if hasattr(self.sdr_semantic, '_retina_indices'):
|
| 225 |
+
self.sdr_semantic._retina_indices = self.sdr_semantic._retina_indices.to(device)
|
| 226 |
+
|
| 227 |
+
# Embedding init: GPT-2 / LLaMA convention. std=1.0 was chosen for
|
| 228 |
+
# vocab=8192; at larger vocabs, smaller std prevents logit blowup.
|
| 229 |
+
# Use std = 1/sqrt(d_model) which scales sensibly with model width.
|
| 230 |
+
import math as _math
|
| 231 |
+
_d_model = self.wte.weight.shape[1]
|
| 232 |
+
wte_std = float(os.environ.get("HYDRA_WTE_STD", str(1.0 / _math.sqrt(_d_model))))
|
| 233 |
+
nn.init.normal_(self.wte.weight, mean=0.0, std=wte_std)
|
| 234 |
+
# LM head init: was std=0.001 — PATHOLOGICAL at vocab>=32k because
|
| 235 |
+
# logits collapse to zero, loss locks at log(V)~=11, gradient through
|
| 236 |
+
# head ∝ 1/V is too small to escape. GPT-2 uses std=0.02; LLaMA uses
|
| 237 |
+
# std=1/sqrt(d_model). Pick 0.02 as robust default, env-overridable.
|
| 238 |
+
lm_head_std = float(os.environ.get("HYDRA_LM_HEAD_STD", "0.02"))
|
| 239 |
+
nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std)
|
| 240 |
+
# F8 (NOT APPLIED): Weight tying would save V*D params but current LR
|
| 241 |
+
# groups have embedding_lr=1.0 and unembedding_lr=0.005 × d_model_scale
|
| 242 |
+
# — tying forces the shared tensor under a single LR group and either
|
| 243 |
+
# the embeddings learn 200x too slow (under unembed LR) or the LM head
|
| 244 |
+
# becomes unstable (under embed LR). Short 15-step smoke with tying +
|
| 245 |
+
# embed-group update showed initial loss jump 9 -> 20. Deferred until
|
| 246 |
+
# LR groups are re-tuned; see docs/OPTIMIZATION_PLAN.md Post-plan.
|
| 247 |
+
|
| 248 |
+
for li, block in enumerate(self.blocks):
|
| 249 |
+
if hasattr(block, 'in_proj') and hasattr(block.in_proj, 'weight'):
|
| 250 |
+
nn.init.uniform_(block.in_proj.weight, -s, s)
|
| 251 |
+
if hasattr(block, 'out_proj') and hasattr(block.out_proj, 'weight'):
|
| 252 |
+
# GPT-2 residual init: std = 0.02 / sqrt(2 * n_layer).
|
| 253 |
+
# NOT zeros — zero init makes the block a permanent pass-through
|
| 254 |
+
# (block_out_rms=0, zero gradient flow to SSM internals).
|
| 255 |
+
# With non-zero init the block contributes to the residual stream
|
| 256 |
+
# from step 1, giving the SSM scan actual gradient signal.
|
| 257 |
+
n_layer = self.config.n_layer
|
| 258 |
+
out_std = float(os.environ.get(
|
| 259 |
+
"HYDRA_OUT_PROJ_STD",
|
| 260 |
+
str(0.02 / (2 * n_layer) ** 0.5),
|
| 261 |
+
))
|
| 262 |
+
nn.init.normal_(block.out_proj.weight, mean=0.0, std=out_std)
|
| 263 |
+
|
| 264 |
+
nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
|
| 265 |
+
|
| 266 |
+
# Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
|
| 267 |
+
# dtypes in the same shape group would break lerp_ dtype checks.
|
| 268 |
+
self.wte.to(dtype=torch.bfloat16)
|
| 269 |
+
self.htm_proj.to(dtype=torch.bfloat16)
|
| 270 |
+
self.engram.to(dtype=torch.bfloat16)
|
| 271 |
+
|
| 272 |
+
def estimate_flops(self) -> int:
|
| 273 |
+
nparams = sum(p.numel() for p in self.parameters())
|
| 274 |
+
embed_params = self.wte.weight.numel()
|
| 275 |
+
return 6 * (nparams - embed_params)
|
| 276 |
+
|
| 277 |
+
def num_scaling_params(self) -> dict:
|
| 278 |
+
wte = sum(p.numel() for p in self.wte.parameters())
|
| 279 |
+
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
| 280 |
+
blocks = sum(p.numel() for p in self.blocks.parameters())
|
| 281 |
+
sdr = sum(p.numel() for p in self.sdr_semantic.parameters())
|
| 282 |
+
htm_proj = sum(p.numel() for p in self.htm_proj.parameters())
|
| 283 |
+
engram = sum(p.numel() for p in self.engram.parameters())
|
| 284 |
+
total = sum(p.numel() for p in self.parameters())
|
| 285 |
+
return {
|
| 286 |
+
'wte': wte, 'lm_head': lm_head, 'blocks': blocks,
|
| 287 |
+
'sdr_semantic': sdr, 'htm_proj': htm_proj,
|
| 288 |
+
'engram': engram, 'total': total,
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
def get_secondary_metrics(self) -> dict:
|
| 292 |
+
"""Flush any lingering CUDA tensors to host (single sync)."""
|
| 293 |
+
flushed = {}
|
| 294 |
+
for k, v in self._metrics.items():
|
| 295 |
+
if hasattr(v, 'item'):
|
| 296 |
+
try:
|
| 297 |
+
flushed[k] = float(v.item())
|
| 298 |
+
except Exception:
|
| 299 |
+
flushed[k] = v
|
| 300 |
+
else:
|
| 301 |
+
flushed[k] = v
|
| 302 |
+
return flushed
|
| 303 |
+
|
| 304 |
+
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.6, matrix_lr=0.04,
|
| 305 |
+
weight_decay=0.2, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
| 306 |
+
"""Setup MuonAdamW optimizer with per-component LR groups."""
|
| 307 |
+
model_dim = self.config.d_model
|
| 308 |
+
|
| 309 |
+
embedding_params = list(self.wte.parameters())
|
| 310 |
+
lm_head_params = list(self.lm_head.parameters())
|
| 311 |
+
|
| 312 |
+
# Matrix params -> Muon (exactly 2D weight matrices).
|
| 313 |
+
matrix_params = []
|
| 314 |
+
for p in self.blocks.parameters():
|
| 315 |
+
if p.dim() == 2:
|
| 316 |
+
matrix_params.append(p)
|
| 317 |
+
# NOTE (W1 audit REG-2): SemanticFoldingSDR.delta_u / delta_v are
|
| 318 |
+
# currently GRADIENT-DEAD. The forward path uses `binary_only(idx)` for
|
| 319 |
+
# HTM and stores it as `self._last_sdr`, but does NOT route the STE
|
| 320 |
+
# output through any downstream op. Including them in the Muon group
|
| 321 |
+
# burns compute (stack + orthogonalize + lerp) on zero-grad params
|
| 322 |
+
# every step. Excluded here; a later W5 pass can reconnect STE via a
|
| 323 |
+
# gated residual if the SDR signal is wanted back in-graph. The
|
| 324 |
+
# parameters still exist, so no state_dict break.
|
| 325 |
+
# for p in self.sdr_semantic.parameters():
|
| 326 |
+
# if p.dim() == 2:
|
| 327 |
+
# matrix_params.append(p)
|
| 328 |
+
for p in self.htm_proj.parameters():
|
| 329 |
+
if p.dim() == 2:
|
| 330 |
+
matrix_params.append(p)
|
| 331 |
+
for p in self.engram.parameters():
|
| 332 |
+
if p.dim() == 2:
|
| 333 |
+
matrix_params.append(p)
|
| 334 |
+
|
| 335 |
+
# SDR params are intentionally not in any optimizer group — they
|
| 336 |
+
# receive no gradient in the current forward, so any update would be
|
| 337 |
+
# pure noise (weight_decay × lr on a zero-grad param).
|
| 338 |
+
sdr_param_ids = set(id(p) for p in self.sdr_semantic.parameters())
|
| 339 |
+
assigned = set(id(p) for p in embedding_params + lm_head_params + matrix_params)
|
| 340 |
+
scalar_params = [
|
| 341 |
+
p for p in self.parameters()
|
| 342 |
+
if id(p) not in assigned and id(p) not in sdr_param_ids
|
| 343 |
+
]
|
| 344 |
+
|
| 345 |
+
total_assigned = len(embedding_params) + len(lm_head_params) + len(matrix_params) + len(scalar_params)
|
| 346 |
+
total_params = len(list(self.parameters()))
|
| 347 |
+
sdr_excluded = len(list(self.sdr_semantic.parameters()))
|
| 348 |
+
assert total_assigned + sdr_excluded == total_params, (
|
| 349 |
+
f"Parameter count mismatch: assigned {total_assigned} + sdr_excluded "
|
| 350 |
+
f"{sdr_excluded} vs total {total_params}"
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
| 354 |
+
print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
| 355 |
+
|
| 356 |
+
param_groups = [
|
| 357 |
+
dict(kind='adamw', params=lm_head_params,
|
| 358 |
+
lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas,
|
| 359 |
+
eps=1e-10, weight_decay=0.0),
|
| 360 |
+
dict(kind='adamw', params=embedding_params,
|
| 361 |
+
lr=embedding_lr * dmodel_lr_scale, betas=adam_betas,
|
| 362 |
+
eps=1e-10, weight_decay=0.0),
|
| 363 |
+
]
|
| 364 |
+
|
| 365 |
+
if scalar_params:
|
| 366 |
+
param_groups.append(
|
| 367 |
+
dict(kind='adamw', params=scalar_params,
|
| 368 |
+
lr=scalar_lr * dmodel_lr_scale, betas=adam_betas,
|
| 369 |
+
eps=1e-10, weight_decay=0.0)
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
for shape in sorted({p.shape for p in matrix_params}):
|
| 373 |
+
group_params = [p for p in matrix_params if p.shape == shape]
|
| 374 |
+
param_groups.append(dict(
|
| 375 |
+
kind='muon', params=group_params, lr=matrix_lr,
|
| 376 |
+
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
|
| 377 |
+
))
|
| 378 |
+
|
| 379 |
+
optimizer = MuonAdamW(param_groups)
|
| 380 |
+
for group in optimizer.param_groups:
|
| 381 |
+
group["initial_lr"] = group["lr"]
|
| 382 |
+
return optimizer
|
| 383 |
+
|
| 384 |
+
def forward(self, idx, targets=None, reduction='mean'):
|
| 385 |
+
"""idx: (B, T) int64. Returns loss if targets given, else logits.
|
| 386 |
+
|
| 387 |
+
Nested bf16 autocast is a no-op when ambient autocast is already on;
|
| 388 |
+
when it's off (e.g. integration tests) we establish the dtype contract.
|
| 389 |
+
"""
|
| 390 |
+
if torch.is_autocast_enabled():
|
| 391 |
+
return self._forward_impl(idx, targets=targets, reduction=reduction)
|
| 392 |
+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 393 |
+
return self._forward_impl(idx, targets=targets, reduction=reduction)
|
| 394 |
+
|
| 395 |
+
def _forward_impl(self, idx, targets=None, reduction='mean'):
|
| 396 |
+
B, T = idx.shape
|
| 397 |
+
|
| 398 |
+
# Diagnostic: per-subsystem CUDA event timing. Env-gated; zero overhead
|
| 399 |
+
# when disabled. Logs one timing line per forward call. Used to isolate
|
| 400 |
+
# which subsystem is the tps bottleneck on paid hardware.
|
| 401 |
+
_profile = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1"
|
| 402 |
+
if _profile:
|
| 403 |
+
def _ev():
|
| 404 |
+
e = torch.cuda.Event(enable_timing=True)
|
| 405 |
+
e.record()
|
| 406 |
+
return e
|
| 407 |
+
_t0 = _ev()
|
| 408 |
+
else:
|
| 409 |
+
_t0 = None
|
| 410 |
+
|
| 411 |
+
# Compute SDR binary ONCE and reuse for both HTM input and the stash.
|
| 412 |
+
sdr_binary = self.sdr_semantic.binary_only(idx)
|
| 413 |
+
self._last_sdr = sdr_binary # uint8 stash (not bf16 → 256MB avoidance)
|
| 414 |
+
|
| 415 |
+
# HTM subsampling: run HTM on 1 of every N micro-batches within a
|
| 416 |
+
# gradient accumulation step, reuse the cached result for the other
|
| 417 |
+
# N-1 micro-batches. Cooperative launch monopolizes all SMs (grid.sync
|
| 418 |
+
# requires full-grid residency), so HTM and mamba can't overlap via
|
| 419 |
+
# streams. Subsampling removes HTM from most micro-batches' critical
|
| 420 |
+
# path instead.
|
| 421 |
+
#
|
| 422 |
+
# Math: N=8, 64 accum steps → 8 HTM calls (10.6ms each) + 56 fast
|
| 423 |
+
# calls (4ms each). Total = 84.8 + 224 = 309ms → 106k tps.
|
| 424 |
+
#
|
| 425 |
+
# HYDRA_HTM_SUBSAMPLE=N (default 8). Set =1 for every-microbatch HTM.
|
| 426 |
+
_htm_sub = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8"))
|
| 427 |
+
if not hasattr(self, '_htm_call_idx'):
|
| 428 |
+
self._htm_call_idx = 0
|
| 429 |
+
|
| 430 |
+
_run_htm = (self._htm_call_idx % _htm_sub == 0)
|
| 431 |
+
self._htm_call_idx += 1
|
| 432 |
+
|
| 433 |
+
if _run_htm:
|
| 434 |
+
htm_handle = self.htm.forward_async(sdr_binary)
|
| 435 |
+
else:
|
| 436 |
+
htm_handle = None
|
| 437 |
+
|
| 438 |
+
if _profile: _t_htm_async = _ev()
|
| 439 |
+
|
| 440 |
+
dense_emb = self.wte(idx) # (B, T, d_model) bf16
|
| 441 |
+
|
| 442 |
+
if _profile: _t_wte = _ev()
|
| 443 |
+
|
| 444 |
+
if _run_htm:
|
| 445 |
+
htm_out = self.htm.forward_await(htm_handle)
|
| 446 |
+
self._htm_cache = htm_out.detach() # cache for non-HTM micro-batches
|
| 447 |
+
elif hasattr(self, '_htm_cache') and self._htm_cache is not None \
|
| 448 |
+
and self._htm_cache.shape[0] == B and self._htm_cache.shape[1] == T:
|
| 449 |
+
htm_out = self._htm_cache
|
| 450 |
+
else:
|
| 451 |
+
# Very first call with subsample > 1: run HTM anyway.
|
| 452 |
+
htm_handle = self.htm.forward_async(sdr_binary)
|
| 453 |
+
htm_out = self.htm.forward_await(htm_handle)
|
| 454 |
+
self._htm_cache = htm_out.detach()
|
| 455 |
+
|
| 456 |
+
if _profile: _t_htm_await = _ev()
|
| 457 |
+
with torch.no_grad():
|
| 458 |
+
sdr_active_bits = float(self.sdr_semantic.target_active)
|
| 459 |
+
htm_anomaly = htm_out[..., -1].mean()
|
| 460 |
+
|
| 461 |
+
# Gradient bridge: HTM columns+anomaly -> d_model.
|
| 462 |
+
htm_proj_out = self.htm_proj(htm_out.to(dense_emb.dtype))
|
| 463 |
+
x = dense_emb + htm_proj_out
|
| 464 |
+
x = norm(x)
|
| 465 |
+
|
| 466 |
+
if _profile: _t_htm_proj = _ev()
|
| 467 |
+
|
| 468 |
+
# mHC-routed Mamba-3 stack with Engram injection at configured layer.
|
| 469 |
+
streams = self.mhc[0].init_streams(x)
|
| 470 |
+
_engram_ev = None
|
| 471 |
+
|
| 472 |
+
# Per-layer diagnostic panel. The pre-layer merged state h_pre lets us
|
| 473 |
+
# measure residual contribution of each layer: delta_N = h_post - h_pre.
|
| 474 |
+
# All reads are detached no-grad to avoid autograd graph pollution.
|
| 475 |
+
_diag = self._diag_enabled
|
| 476 |
+
if _diag:
|
| 477 |
+
# Cast to float32 for the diagnostic arithmetic: the layer's
|
| 478 |
+
# residual contribution is small (~0.5 × rms-normed block output),
|
| 479 |
+
# which underflows in bf16 subtraction (3-digit mantissa) and
|
| 480 |
+
# reports delta_ratio=0 at the boundaries. float32 snapshot is
|
| 481 |
+
# ~3.8 MB extra memory per diag sample (B=1, T=2048, d=96) —
|
| 482 |
+
# negligible vs peak VRAM.
|
| 483 |
+
with torch.no_grad():
|
| 484 |
+
h_pre = self.mhc[0].merge_streams(streams).detach().float()
|
| 485 |
+
_run_svd = (self._diag_step % self._diag_svd_every) == 0
|
| 486 |
+
|
| 487 |
+
for i, (block, mhc_layer) in enumerate(zip(self.blocks, self.mhc)):
|
| 488 |
+
def _block_fn(h, _block=block):
|
| 489 |
+
return self.drop(_block(norm(h)))
|
| 490 |
+
|
| 491 |
+
streams = mhc_layer(streams, _block_fn)
|
| 492 |
+
|
| 493 |
+
if i == self.engram_layer_idx:
|
| 494 |
+
if _profile: _t_pre_engram = _ev()
|
| 495 |
+
x_mid = mhc_layer.merge_streams(streams)
|
| 496 |
+
x_mid, hit_rate = self.engram(x_mid, idx)
|
| 497 |
+
streams = mhc_layer.init_streams(x_mid)
|
| 498 |
+
self._metrics['engram_hit_rate'] = hit_rate
|
| 499 |
+
if _profile: _engram_ev = _ev()
|
| 500 |
+
|
| 501 |
+
if _diag:
|
| 502 |
+
with torch.no_grad():
|
| 503 |
+
h_post = mhc_layer.merge_streams(streams).detach().float()
|
| 504 |
+
in_n = h_pre.pow(2).mean().sqrt()
|
| 505 |
+
out_n = h_post.pow(2).mean().sqrt()
|
| 506 |
+
d_n = (h_post - h_pre).pow(2).mean().sqrt()
|
| 507 |
+
self._metrics[f'layer_{i}_in_norm'] = float(in_n.item())
|
| 508 |
+
self._metrics[f'layer_{i}_out_norm'] = float(out_n.item())
|
| 509 |
+
self._metrics[f'layer_{i}_delta_ratio'] = float((d_n / (in_n + 1e-6)).item())
|
| 510 |
+
self._metrics[f'layer_{i}_feat_std'] = float(h_post.std(dim=-1).mean().item())
|
| 511 |
+
if _run_svd:
|
| 512 |
+
# Effective rank via participation ratio of singular values.
|
| 513 |
+
# eff_rank = (Σσ)^2 / Σσ² — smooth rank proxy, bounded by d_model.
|
| 514 |
+
# Sampled to keep overhead low (SVD is O(min(B*T, D)^2·D)).
|
| 515 |
+
flat = h_post.reshape(-1, h_post.shape[-1])[:512].float()
|
| 516 |
+
try:
|
| 517 |
+
s = torch.linalg.svdvals(flat)
|
| 518 |
+
eff_rank = float(((s.sum() ** 2) / (s.pow(2).sum() + 1e-6)).item())
|
| 519 |
+
self._metrics[f'layer_{i}_eff_rank'] = eff_rank
|
| 520 |
+
except Exception:
|
| 521 |
+
pass
|
| 522 |
+
h_pre = h_post
|
| 523 |
+
|
| 524 |
+
if _diag:
|
| 525 |
+
self._diag_step += 1
|
| 526 |
+
|
| 527 |
+
if _profile: _t_blocks = _ev()
|
| 528 |
+
|
| 529 |
+
self._metrics['sdr_active_bits'] = sdr_active_bits
|
| 530 |
+
self._metrics['htm_anomaly'] = htm_anomaly
|
| 531 |
+
|
| 532 |
+
x = self.mhc[-1].merge_streams(streams)
|
| 533 |
+
x = norm(x)
|
| 534 |
+
|
| 535 |
+
if _profile: _t_merge = _ev()
|
| 536 |
+
|
| 537 |
+
softcap = self.softcap
|
| 538 |
+
_softcap_clamp = os.environ.get("HYDRA_SOFTCAP_CLAMP", "0") == "1"
|
| 539 |
+
if targets is not None:
|
| 540 |
+
smoothing = self.config.label_smoothing
|
| 541 |
+
V = self.config.vocab_size
|
| 542 |
+
|
| 543 |
+
# Sampled softmax: instead of computing logits for ALL V tokens,
|
| 544 |
+
# compute only for the target + K random negatives. Reduces the
|
| 545 |
+
# lm_head matmul from (B*T, d) × (d, V) to (B*T, d) × (d, K+1).
|
| 546 |
+
# At V=65536 and K=4096: 16× less compute, ~4× tps improvement.
|
| 547 |
+
# The log-sum-exp correction adjusts for the sampling bias.
|
| 548 |
+
# Set HYDRA_SAMPLED_SOFTMAX=0 to disable (full softmax).
|
| 549 |
+
K_neg = int(os.environ.get("HYDRA_SAMPLED_SOFTMAX", "4096"))
|
| 550 |
+
use_sampled = K_neg > 0 and K_neg < V and self.training
|
| 551 |
+
|
| 552 |
+
if use_sampled:
|
| 553 |
+
# Flatten hidden states + targets
|
| 554 |
+
h_flat = x.reshape(-1, x.shape[-1]) # (B*T, d)
|
| 555 |
+
t_flat = targets.reshape(-1) # (B*T,)
|
| 556 |
+
n = h_flat.shape[0]
|
| 557 |
+
|
| 558 |
+
# Sample K negatives uniformly from [0, V)
|
| 559 |
+
neg_ids = torch.randint(0, V, (K_neg,), device=x.device)
|
| 560 |
+
# Gather lm_head weights for target + negatives
|
| 561 |
+
all_ids = torch.cat([t_flat, neg_ids]) # (B*T + K,)
|
| 562 |
+
sampled_w = self.lm_head.weight[all_ids] # (B*T + K, d)
|
| 563 |
+
|
| 564 |
+
# Compute sampled logits: for each position, dot with its
|
| 565 |
+
# target weight and all K negative weights.
|
| 566 |
+
# Target logit: dot product of h[i] with w[target[i]]
|
| 567 |
+
target_w = sampled_w[:n] # (B*T, d)
|
| 568 |
+
neg_w = sampled_w[n:] # (K, d)
|
| 569 |
+
target_logit = (h_flat * target_w).sum(-1) # (B*T,)
|
| 570 |
+
neg_logits = h_flat @ neg_w.t() # (B*T, K)
|
| 571 |
+
|
| 572 |
+
if not _softcap_clamp:
|
| 573 |
+
target_logit = softcap * torch.tanh(target_logit / softcap)
|
| 574 |
+
neg_logits = softcap * torch.tanh(neg_logits / softcap)
|
| 575 |
+
|
| 576 |
+
# Sampled softmax loss: -log(exp(target) / (exp(target) + sum(exp(neg))))
|
| 577 |
+
# With log-sum-exp correction for sampling K of V negatives.
|
| 578 |
+
# Correction: add log(V/K) to negative logits to account for
|
| 579 |
+
# the fact that we're only seeing K of V possible negatives.
|
| 580 |
+
log_correction = torch.tensor(V / K_neg, device=x.device).log()
|
| 581 |
+
all_logits = torch.cat([
|
| 582 |
+
target_logit.unsqueeze(-1), # (B*T, 1)
|
| 583 |
+
neg_logits + log_correction, # (B*T, K)
|
| 584 |
+
], dim=-1).float() # (B*T, K+1)
|
| 585 |
+
|
| 586 |
+
# CE with target always at index 0
|
| 587 |
+
ce_targets = torch.zeros(n, dtype=torch.long, device=x.device)
|
| 588 |
+
if reduction == 'none':
|
| 589 |
+
return F.cross_entropy(all_logits, ce_targets, reduction='none')
|
| 590 |
+
out = F.cross_entropy(all_logits, ce_targets, reduction='mean',
|
| 591 |
+
label_smoothing=smoothing)
|
| 592 |
+
else:
|
| 593 |
+
# Full softmax path (eval or HYDRA_SAMPLED_SOFTMAX=0)
|
| 594 |
+
chunk_size = int(os.environ.get("HYDRA_CE_CHUNK", "1024"))
|
| 595 |
+
if chunk_size <= 0:
|
| 596 |
+
MAX_LOGITS_BYTES = 256 * 1024 * 1024
|
| 597 |
+
tokens_per_chunk = max(V, MAX_LOGITS_BYTES // (V * 4))
|
| 598 |
+
chunk_size = max(1, tokens_per_chunk // max(1, B))
|
| 599 |
+
chunk_size = min(chunk_size, T)
|
| 600 |
+
|
| 601 |
+
if reduction == 'none':
|
| 602 |
+
loss_parts = []
|
| 603 |
+
for start in range(0, T, chunk_size):
|
| 604 |
+
end = min(start + chunk_size, T)
|
| 605 |
+
chunk_logits = self.lm_head(x[:, start:end, :]).float()
|
| 606 |
+
if _softcap_clamp:
|
| 607 |
+
chunk_logits = torch.clamp(chunk_logits, -softcap, softcap)
|
| 608 |
+
else:
|
| 609 |
+
chunk_logits = softcap * torch.tanh(chunk_logits / softcap)
|
| 610 |
+
chunk_targets = targets[:, start:end].reshape(-1)
|
| 611 |
+
chunk_loss = F.cross_entropy(
|
| 612 |
+
chunk_logits.view(-1, chunk_logits.size(-1)),
|
| 613 |
+
chunk_targets, ignore_index=-1, reduction='none',
|
| 614 |
+
)
|
| 615 |
+
loss_parts.append(chunk_loss)
|
| 616 |
+
return torch.cat(loss_parts)
|
| 617 |
+
|
| 618 |
+
total_loss = 0.0
|
| 619 |
+
total_tokens = 0
|
| 620 |
+
for start in range(0, T, chunk_size):
|
| 621 |
+
end = min(start + chunk_size, T)
|
| 622 |
+
chunk_logits = self.lm_head(x[:, start:end, :]).float()
|
| 623 |
+
if _softcap_clamp:
|
| 624 |
+
chunk_logits = torch.clamp(chunk_logits, -softcap, softcap)
|
| 625 |
+
else:
|
| 626 |
+
chunk_logits = softcap * torch.tanh(chunk_logits / softcap)
|
| 627 |
+
chunk_targets = targets[:, start:end].reshape(-1)
|
| 628 |
+
chunk_loss = F.cross_entropy(
|
| 629 |
+
chunk_logits.view(-1, chunk_logits.size(-1)),
|
| 630 |
+
chunk_targets, ignore_index=-1, reduction='sum',
|
| 631 |
+
label_smoothing=smoothing,
|
| 632 |
+
)
|
| 633 |
+
total_loss = total_loss + chunk_loss
|
| 634 |
+
total_tokens += (chunk_targets != -1).sum()
|
| 635 |
+
out = total_loss / total_tokens
|
| 636 |
+
if _profile:
|
| 637 |
+
_t_end = _ev()
|
| 638 |
+
torch.cuda.synchronize()
|
| 639 |
+
def _ms(a, b): return a.elapsed_time(b)
|
| 640 |
+
print(
|
| 641 |
+
f"[PROFILE B={B} T={T}] "
|
| 642 |
+
f"htm_launch={_ms(_t0, _t_htm_async):.2f} "
|
| 643 |
+
f"wte={_ms(_t_htm_async, _t_wte):.2f} "
|
| 644 |
+
f"htm_await={_ms(_t_wte, _t_htm_await):.2f} "
|
| 645 |
+
f"htm_proj={_ms(_t_htm_await, _t_htm_proj):.2f} "
|
| 646 |
+
f"mamba_mhc_engram={_ms(_t_htm_proj, _t_blocks):.2f} "
|
| 647 |
+
f"merge={_ms(_t_blocks, _t_merge):.2f} "
|
| 648 |
+
f"lm_head_loss={_ms(_t_merge, _t_end):.2f} "
|
| 649 |
+
f"total={_ms(_t0, _t_end):.2f} ms",
|
| 650 |
+
flush=True,
|
| 651 |
+
)
|
| 652 |
+
return out
|
| 653 |
+
|
| 654 |
+
logits = self.lm_head(x).float()
|
| 655 |
+
if _softcap_clamp:
|
| 656 |
+
logits = torch.clamp(logits, -softcap, softcap)
|
| 657 |
+
else:
|
| 658 |
+
logits = softcap * torch.tanh(logits / softcap)
|
| 659 |
+
return logits
|
overlay/hydra/optimizer.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).
|
| 2 |
+
|
| 3 |
+
Extracted verbatim from train.py (W1 modularization). Semantics unchanged.
|
| 4 |
+
|
| 5 |
+
F1-F15 state preserved:
|
| 6 |
+
- F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
|
| 7 |
+
step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
|
| 8 |
+
fresh. Persistent copies of param storage would be mutated between forward
|
| 9 |
+
passes (via lerp_/sub_ on stacked tensors that share storage with params),
|
| 10 |
+
triggering "modified in-place" errors on grad_accum=2 backwards.
|
| 11 |
+
- F11/F15: `@torch.compile` on `adamw_step_fused` / `muon_step_fused` intact.
|
| 12 |
+
- F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
|
| 13 |
+
dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
|
| 14 |
+
stream-capture deadlock. See .omc/muon_compile_bug.md for the full
|
| 15 |
+
investigation.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
# HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel.
|
| 25 |
+
_HYDRA_FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
|
| 26 |
+
_HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
polar_express_coeffs = [
|
| 30 |
+
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
| 31 |
+
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
| 32 |
+
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
| 33 |
+
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
| 34 |
+
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
|
| 39 |
+
# Per-param AdamW fallback. Fast path is torch._fused_adamw_ (1 CUDA launch
|
| 40 |
+
# for the whole group) driven from MuonAdamW._step_adamw below.
|
| 41 |
+
grad = grad.to(p.dtype) # handle mixed bf16/fp32 from autocast
|
| 42 |
+
p.mul_(1 - lr_t * wd_t)
|
| 43 |
+
exp_avg.lerp_(grad, 1 - beta1_t)
|
| 44 |
+
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
| 45 |
+
bias1 = 1 - beta1_t ** step_t
|
| 46 |
+
bias2 = 1 - beta2_t ** step_t
|
| 47 |
+
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
| 48 |
+
step_size = lr_t / bias1
|
| 49 |
+
p.add_(exp_avg / denom, alpha=-step_size)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# F15 muon_step_fused compile strategy.
|
| 54 |
+
#
|
| 55 |
+
# HYDRA_MUON_COMPILE env gate:
|
| 56 |
+
# "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
|
| 57 |
+
# Dynamic=True collapses the per-shape specialization cache so that N
|
| 58 |
+
# Muon param-groups with N distinct shapes trigger 1 compile, not N.
|
| 59 |
+
# mode="default" keeps the inductor codegen but disables cudagraphs,
|
| 60 |
+
# which is what caused the step-17→18 silent deadlock observed under
|
| 61 |
+
# the original dynamic=False configuration: cudagraph stream capture
|
| 62 |
+
# can deadlock against HTM's CUDA kernels running on the default
|
| 63 |
+
# stream, and the failure mode at capture-time is a silent hang
|
| 64 |
+
# (100% GPU util, no log output, process state R).
|
| 65 |
+
# "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
|
| 66 |
+
# Keeps an escape hatch in case a future torch/inductor regression
|
| 67 |
+
# reintroduces a deadlock.
|
| 68 |
+
#
|
| 69 |
+
# Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
|
| 70 |
+
# alias-analysis edge case where inductor sees `g is stacked_grads` and
|
| 71 |
+
# subsequent `stacked_grads.square()` operating on the post-lerp storage.
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
_MUON_COMPILE = os.environ.get("HYDRA_MUON_COMPILE", "1") == "1"
|
| 74 |
+
|
| 75 |
+
def _maybe_compile(fn):
|
| 76 |
+
if _MUON_COMPILE:
|
| 77 |
+
# mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
|
| 78 |
+
# would enable) to avoid stream-capture deadlocks against HTM's CUDA
|
| 79 |
+
# kernels. dynamic=True minimizes recompile count across param-group
|
| 80 |
+
# shapes.
|
| 81 |
+
return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
|
| 82 |
+
return fn
|
| 83 |
+
|
| 84 |
+
@_maybe_compile
|
| 85 |
+
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
|
| 86 |
+
momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
|
| 87 |
+
# Cast grads to param dtype AND clone defensively to break any alias
|
| 88 |
+
# between the (freshly-stacked) input and the in-place lerp_ below.
|
| 89 |
+
# Without this, inductor's alias analysis can emit code that reads from
|
| 90 |
+
# post-mutation storage when computing `v_mean = g.square().mean(...)`.
|
| 91 |
+
stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()
|
| 92 |
+
# Nesterov momentum
|
| 93 |
+
momentum = momentum_t.to(stacked_grads.dtype)
|
| 94 |
+
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
| 95 |
+
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
| 96 |
+
# Polar express orthogonalization
|
| 97 |
+
X = g.bfloat16()
|
| 98 |
+
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
| 99 |
+
if g.size(-2) > g.size(-1):
|
| 100 |
+
for a, b, c in polar_express_coeffs[:ns_steps]:
|
| 101 |
+
A = X.mT @ X
|
| 102 |
+
B = b * A + c * (A @ A)
|
| 103 |
+
X = a * X + X @ B
|
| 104 |
+
else:
|
| 105 |
+
for a, b, c in polar_express_coeffs[:ns_steps]:
|
| 106 |
+
A = X @ X.mT
|
| 107 |
+
B = b * A + c * (A @ A)
|
| 108 |
+
X = a * X + B @ X
|
| 109 |
+
g = X
|
| 110 |
+
# NorMuon variance reduction
|
| 111 |
+
# Keep beta2 in the state-buffer dtype, not g.dtype, so lerp_ on the
|
| 112 |
+
# float32 second_momentum_buffer doesn't hit a dtype mismatch on h200.
|
| 113 |
+
beta2 = beta2_t.to(second_momentum_buffer.dtype)
|
| 114 |
+
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
| 115 |
+
red_dim_size = g.size(red_dim)
|
| 116 |
+
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
| 117 |
+
v_norm = v_norm_sq.sqrt()
|
| 118 |
+
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
| 119 |
+
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
| 120 |
+
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
| 121 |
+
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
| 122 |
+
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
| 123 |
+
g = g * final_scale.to(g.dtype)
|
| 124 |
+
# Cautious weight decay + parameter update
|
| 125 |
+
lr = lr_t.to(g.dtype)
|
| 126 |
+
wd = wd_t.to(g.dtype)
|
| 127 |
+
mask = (g * stacked_params) >= 0
|
| 128 |
+
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class MuonAdamW(torch.optim.Optimizer):
|
| 132 |
+
"""Combined optimizer: Muon for 2D matrix params, AdamW for others."""
|
| 133 |
+
|
| 134 |
+
def __init__(self, param_groups):
|
| 135 |
+
super().__init__(param_groups, defaults={})
|
| 136 |
+
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
| 137 |
+
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 138 |
+
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 139 |
+
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 140 |
+
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 141 |
+
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 142 |
+
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 143 |
+
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 144 |
+
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 145 |
+
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 146 |
+
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 147 |
+
|
| 148 |
+
def _step_adamw(self, group):
|
| 149 |
+
params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
|
| 150 |
+
for p in group['params']:
|
| 151 |
+
if p.grad is None:
|
| 152 |
+
continue
|
| 153 |
+
state = self.state[p]
|
| 154 |
+
if not state:
|
| 155 |
+
state['step'] = 0
|
| 156 |
+
state['exp_avg'] = torch.zeros_like(p)
|
| 157 |
+
state['exp_avg_sq'] = torch.zeros_like(p)
|
| 158 |
+
if 'step_t' not in state:
|
| 159 |
+
# _fused_adamw_ wants a per-param float step tensor on-device.
|
| 160 |
+
state['step_t'] = torch.tensor(
|
| 161 |
+
float(state['step']), dtype=torch.float32, device=p.device
|
| 162 |
+
)
|
| 163 |
+
state['step'] += 1
|
| 164 |
+
params.append(p)
|
| 165 |
+
grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
|
| 166 |
+
exp_avgs.append(state['exp_avg'])
|
| 167 |
+
exp_avg_sqs.append(state['exp_avg_sq'])
|
| 168 |
+
state_steps.append(state['step_t'])
|
| 169 |
+
|
| 170 |
+
if not params:
|
| 171 |
+
return
|
| 172 |
+
|
| 173 |
+
if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
|
| 174 |
+
# _fused_adamw_ needs uniform (device, dtype) within a call, so
|
| 175 |
+
# group by (device, dtype) — same pattern as PyTorch's own
|
| 176 |
+
# AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
|
| 177 |
+
buckets = {}
|
| 178 |
+
for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
|
| 179 |
+
key = (p.device, p.dtype)
|
| 180 |
+
buckets.setdefault(key, ([], [], [], [], []))
|
| 181 |
+
b_p, b_g, b_ea, b_es, b_st = buckets[key]
|
| 182 |
+
b_p.append(p); b_g.append(g); b_ea.append(ea); b_es.append(es); b_st.append(st)
|
| 183 |
+
|
| 184 |
+
lr_f = float(group['lr'])
|
| 185 |
+
b1_f = float(group['betas'][0])
|
| 186 |
+
b2_f = float(group['betas'][1])
|
| 187 |
+
wd_f = float(group['weight_decay'])
|
| 188 |
+
eps_f = float(group['eps'])
|
| 189 |
+
for (_dev, _dt), (b_p, b_g, b_ea, b_es, b_st) in buckets.items():
|
| 190 |
+
torch._foreach_add_(b_st, 1.0)
|
| 191 |
+
torch._fused_adamw_(
|
| 192 |
+
b_p, b_g, b_ea, b_es,
|
| 193 |
+
[], # max_exp_avg_sqs unused (amsgrad=False)
|
| 194 |
+
b_st,
|
| 195 |
+
amsgrad=False,
|
| 196 |
+
lr=lr_f, beta1=b1_f, beta2=b2_f,
|
| 197 |
+
weight_decay=wd_f, eps=eps_f,
|
| 198 |
+
maximize=False,
|
| 199 |
+
grad_scale=None, found_inf=None,
|
| 200 |
+
)
|
| 201 |
+
return
|
| 202 |
+
|
| 203 |
+
# Fallback per-param path.
|
| 204 |
+
self._adamw_lr_t.fill_(group['lr'])
|
| 205 |
+
self._adamw_beta1_t.fill_(group['betas'][0])
|
| 206 |
+
self._adamw_beta2_t.fill_(group['betas'][1])
|
| 207 |
+
self._adamw_eps_t.fill_(group['eps'])
|
| 208 |
+
self._adamw_wd_t.fill_(group['weight_decay'])
|
| 209 |
+
for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
|
| 210 |
+
self._adamw_step_t.fill_(self.state[p]['step'])
|
| 211 |
+
adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
|
| 212 |
+
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
| 213 |
+
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
|
| 214 |
+
|
| 215 |
+
def _step_muon(self, group):
|
| 216 |
+
params = [p for p in group['params'] if p.grad is not None]
|
| 217 |
+
if not params:
|
| 218 |
+
return
|
| 219 |
+
p = params[0]
|
| 220 |
+
state = self.state[p]
|
| 221 |
+
num_params = len(params)
|
| 222 |
+
shape, device, dtype = p.shape, p.device, p.dtype
|
| 223 |
+
if "momentum_buffer" not in state:
|
| 224 |
+
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
| 225 |
+
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
| 226 |
+
if "second_momentum_buffer" not in state:
|
| 227 |
+
# Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
|
| 228 |
+
full_shape = (num_params, *shape)
|
| 229 |
+
state_shape = list(full_shape)
|
| 230 |
+
state_shape[len(state_shape) + red_dim] = 1 # red_dim is negative
|
| 231 |
+
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
| 232 |
+
# F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
|
| 233 |
+
# This was the autograd-safety fix that unblocks grad_accum>=2.
|
| 234 |
+
stacked_grads = torch.stack([p.grad for p in params])
|
| 235 |
+
stacked_params = torch.stack(params)
|
| 236 |
+
self._muon_momentum_t.fill_(group["momentum"])
|
| 237 |
+
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
| 238 |
+
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
|
| 239 |
+
self._muon_wd_t.fill_(group["weight_decay"])
|
| 240 |
+
muon_step_fused(stacked_grads, stacked_params,
|
| 241 |
+
state["momentum_buffer"], state["second_momentum_buffer"],
|
| 242 |
+
self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
|
| 243 |
+
self._muon_beta2_t, group["ns_steps"], red_dim)
|
| 244 |
+
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
| 245 |
+
|
| 246 |
+
@torch.no_grad()
|
| 247 |
+
def step(self):
|
| 248 |
+
for group in self.param_groups:
|
| 249 |
+
if group['kind'] == 'adamw':
|
| 250 |
+
self._step_adamw(group)
|
| 251 |
+
elif group['kind'] == 'muon':
|
| 252 |
+
self._step_muon(group)
|
overlay/subsystems/htm.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HTM torch wrapper around the pyo3 ``htm_rust`` crate.
|
| 3 |
+
|
| 4 |
+
Exposes ``HTMLayer``, a ``torch.nn.Module`` that batches calls to
|
| 5 |
+
``htm_rust.HTMRegion.step`` across a ``(B, T, input_bits)`` boolean SDR stream
|
| 6 |
+
and returns ``(B, T, n_columns + 1)`` where the last channel is the anomaly
|
| 7 |
+
score. HTM learning is Hebbian (not gradient), so the wrapper runs under
|
| 8 |
+
``torch.no_grad()``. Downstream layers carry gradients back to the embedding
|
| 9 |
+
via their own learnable projection from the binary column output.
|
| 10 |
+
|
| 11 |
+
Per-sequence state semantics
|
| 12 |
+
---------------------------
|
| 13 |
+
Training-time forward passes are independent windows of tokens (re-sampled
|
| 14 |
+
every step), so carrying TM state across calls would mix unrelated contexts.
|
| 15 |
+
This layer calls ``reset()`` on every region at the top of ``forward``; the
|
| 16 |
+
TM learns within-window temporal patterns only. Users that want cross-window
|
| 17 |
+
continuity (e.g. eval over a long document) should instead construct the
|
| 18 |
+
layer and drive ``step_stream`` themselves (not implemented here; the
|
| 19 |
+
single-forward contract is sufficient for the autoresearch loop).
|
| 20 |
+
|
| 21 |
+
Device handling
|
| 22 |
+
---------------
|
| 23 |
+
``htm_rust`` runs on CPU. If ``sdr`` lives on CUDA we pay a
|
| 24 |
+
``sdr.cpu().numpy()`` round-trip per forward. The return tensor is cast back
|
| 25 |
+
to ``sdr.device``. For expected use (batch<=32, T<=2048, bits=16384) this
|
| 26 |
+
copy is small compared to the SP/TM compute.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import time
|
| 32 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 33 |
+
|
| 34 |
+
import numpy as np
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
|
| 38 |
+
import htm_rust
|
| 39 |
+
|
| 40 |
+
# step_many releases the GIL for the whole pass, so multiple threads can
|
| 41 |
+
# truly run regions in parallel — wall-clock scales with B up to CPU cores.
|
| 42 |
+
_HTM_HAS_STEP_MANY = hasattr(htm_rust.HTMRegion, "step_many")
|
| 43 |
+
# GPU backend: built with `maturin develop --features gpu`. One CUDA region
|
| 44 |
+
# per batch slot, persistent device state for SP synapses. Transparent
|
| 45 |
+
# fallback to CPU when not available.
|
| 46 |
+
_HTM_HAS_GPU = hasattr(htm_rust, "HTMRegionGpu")
|
| 47 |
+
# Zero-copy CUDA path: consumes torch CUDA tensors directly via the
|
| 48 |
+
# __cuda_array_interface__ protocol, skipping the sdr.cpu()/numpy round-trip
|
| 49 |
+
# and the D2H of outputs. Huge win when the input SDR already lives on GPU
|
| 50 |
+
# (which is the train.py hot path — retina is a device buffer).
|
| 51 |
+
_HTM_HAS_CAI = _HTM_HAS_GPU and hasattr(htm_rust.HTMRegionGpu, "step_many_cuda")
|
| 52 |
+
# Fused megakernel path: collapses all T timesteps + SP + TM into a single
|
| 53 |
+
# CUDA launch per forward. Replaces global top-K with per-column threshold
|
| 54 |
+
# inhibition (see htm_rust/docs/GPU_HTM.md §Fused Kernel).
|
| 55 |
+
# Opt-in via env var (default on when available).
|
| 56 |
+
import os as _os_fused
|
| 57 |
+
_HTM_HAS_FUSED = _HTM_HAS_GPU and hasattr(htm_rust.HTMRegionGpu, "step_many_fused_cuda")
|
| 58 |
+
_HTM_USE_FUSED = _HTM_HAS_FUSED and bool(int(_os_fused.environ.get("HYDRA_HTM_FUSED", "1")))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class HTMLayer(nn.Module):
|
| 62 |
+
"""Batched torch wrapper around ``htm_rust.HTMRegion``.
|
| 63 |
+
|
| 64 |
+
One independent region per batch slot so temporal memory learns
|
| 65 |
+
sequence-local patterns without cross-batch bleed. Regions grow
|
| 66 |
+
lazily if a larger batch shows up.
|
| 67 |
+
|
| 68 |
+
Output is ``(B, T, n_columns + 1)``: first ``n_columns`` channels are
|
| 69 |
+
the binary active-column mask (float32 0/1) and the last channel is
|
| 70 |
+
the per-timestep anomaly score in [0, 1].
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
input_bits: int = 16384,
|
| 76 |
+
n_columns: int = 2048,
|
| 77 |
+
cells_per_column: int = 32,
|
| 78 |
+
batch_size: int = 1,
|
| 79 |
+
seed: int = 42,
|
| 80 |
+
learn: bool = True,
|
| 81 |
+
reset_each_forward: bool = True,
|
| 82 |
+
use_gpu: bool | None = None,
|
| 83 |
+
) -> None:
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.input_bits = input_bits
|
| 86 |
+
self.n_columns = n_columns
|
| 87 |
+
self.cells_per_column = cells_per_column
|
| 88 |
+
self.learn = learn
|
| 89 |
+
self.reset_each_forward = reset_each_forward
|
| 90 |
+
self._seed_base = seed
|
| 91 |
+
# Learn gating: HTM learn kernels (tm_punish, tm_learn_reinforce, tm_grow)
|
| 92 |
+
# are 56% of total HTM CUDA time. Gating them to run every N forwards
|
| 93 |
+
# instead of every forward cuts HTM cost ~2x. Hebbian learning still
|
| 94 |
+
# converges since the EMA accumulates over many calls. Env:
|
| 95 |
+
# HYDRA_HTM_LEARN_EVERY=N (default 1 = every forward, 0 = disabled).
|
| 96 |
+
import os as _os
|
| 97 |
+
self._learn_every = max(1, int(_os.environ.get("HYDRA_HTM_LEARN_EVERY", "1")))
|
| 98 |
+
self._forward_counter = 0
|
| 99 |
+
# GPU backend gate. Default: auto-detect — use GPU when the pyo3
|
| 100 |
+
# module was built with --features gpu AND CUDA is actually usable.
|
| 101 |
+
if use_gpu is None:
|
| 102 |
+
use_gpu = _HTM_HAS_GPU and torch.cuda.is_available()
|
| 103 |
+
elif use_gpu and not _HTM_HAS_GPU:
|
| 104 |
+
raise RuntimeError(
|
| 105 |
+
"HTMLayer(use_gpu=True) but htm_rust was not built with "
|
| 106 |
+
"--features gpu. Re-run `maturin develop --features gpu`."
|
| 107 |
+
)
|
| 108 |
+
self._use_gpu = bool(use_gpu)
|
| 109 |
+
cls = htm_rust.HTMRegionGpu if self._use_gpu else htm_rust.HTMRegion
|
| 110 |
+
self._region_cls = cls
|
| 111 |
+
self._regions = [
|
| 112 |
+
cls(input_bits, n_columns, cells_per_column, seed + i)
|
| 113 |
+
for i in range(batch_size)
|
| 114 |
+
]
|
| 115 |
+
self.register_buffer("_dummy", torch.zeros(1), persistent=False)
|
| 116 |
+
import os as _os
|
| 117 |
+
self._htm_pool = ThreadPoolExecutor(max_workers=min(_os.cpu_count() or 4, 16))
|
| 118 |
+
|
| 119 |
+
def _ensure_regions(self, B: int) -> None:
|
| 120 |
+
while len(self._regions) < B:
|
| 121 |
+
idx = len(self._regions)
|
| 122 |
+
self._regions.append(
|
| 123 |
+
self._region_cls(
|
| 124 |
+
self.input_bits,
|
| 125 |
+
self.n_columns,
|
| 126 |
+
self.cells_per_column,
|
| 127 |
+
self._seed_base + idx,
|
| 128 |
+
)
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def reset(self) -> None:
|
| 132 |
+
"""Clear TM predictive state on every region (keeps SP synapses)."""
|
| 133 |
+
for r in self._regions:
|
| 134 |
+
r.reset()
|
| 135 |
+
|
| 136 |
+
@torch.no_grad()
|
| 137 |
+
def forward(self, sdr: torch.Tensor) -> torch.Tensor:
|
| 138 |
+
B, T, D = sdr.shape
|
| 139 |
+
if D != self.input_bits:
|
| 140 |
+
raise ValueError(f"expected input_bits={self.input_bits}, got {D}")
|
| 141 |
+
self._ensure_regions(B)
|
| 142 |
+
if self.reset_each_forward:
|
| 143 |
+
self.reset()
|
| 144 |
+
|
| 145 |
+
# Learn-gate: run learn kernels only every N forwards (skips 56% of
|
| 146 |
+
# HTM CUDA time on skip-forwards; Hebbian EMA still converges).
|
| 147 |
+
self._forward_counter += 1
|
| 148 |
+
learn = bool(
|
| 149 |
+
self.learn
|
| 150 |
+
and self.training
|
| 151 |
+
and (self._forward_counter % self._learn_every == 0)
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Zero-copy CUDA hot path. SDR already lives on GPU (retina buffer),
|
| 155 |
+
# so we skip sdr.cpu()/numpy round-trip AND the output D2H. The Rust
|
| 156 |
+
# kernel writes directly into torch-owned CUDA tensors via CAI.
|
| 157 |
+
# Gives 5-10x tok/s on train.py vs the numpy path below.
|
| 158 |
+
if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
|
| 159 |
+
sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()
|
| 160 |
+
cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
|
| 161 |
+
anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)
|
| 162 |
+
# Pick fused (1 launch) or legacy (12*T launches) path.
|
| 163 |
+
if _HTM_USE_FUSED:
|
| 164 |
+
for b in range(B):
|
| 165 |
+
self._regions[b].step_many_fused_cuda(
|
| 166 |
+
sdr_u8[b].__cuda_array_interface__,
|
| 167 |
+
cols_out[b].__cuda_array_interface__,
|
| 168 |
+
anom_out[b].__cuda_array_interface__,
|
| 169 |
+
learn,
|
| 170 |
+
)
|
| 171 |
+
else:
|
| 172 |
+
for b in range(B):
|
| 173 |
+
self._regions[b].step_many_cuda(
|
| 174 |
+
sdr_u8[b].__cuda_array_interface__,
|
| 175 |
+
cols_out[b].__cuda_array_interface__,
|
| 176 |
+
anom_out[b].__cuda_array_interface__,
|
| 177 |
+
learn,
|
| 178 |
+
)
|
| 179 |
+
# Assemble (B, T, n_cols+1) — keep bf16-friendly float32.
|
| 180 |
+
return torch.cat((cols_out.to(torch.float32), anom_out.unsqueeze(-1)), dim=-1)
|
| 181 |
+
|
| 182 |
+
# Fallback: CPU / numpy path. Kept for CPU-input case and for
|
| 183 |
+
# builds without CAI support.
|
| 184 |
+
sdr_np = sdr.detach().cpu().contiguous().numpy().view(np.bool_)
|
| 185 |
+
out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)
|
| 186 |
+
|
| 187 |
+
def _process_one(b: int) -> None:
|
| 188 |
+
region = self._regions[b]
|
| 189 |
+
if self._use_gpu:
|
| 190 |
+
cols, anom = region.step_many_gpu(sdr_np[b], learn)
|
| 191 |
+
out[b, :, : self.n_columns] = cols
|
| 192 |
+
out[b, :, self.n_columns] = anom
|
| 193 |
+
elif _HTM_HAS_STEP_MANY:
|
| 194 |
+
# Single Rust call: T steps with GIL released for the whole pass.
|
| 195 |
+
cols, anom = region.step_many(sdr_np[b], learn) # cols (T, n_cols), anom (T,)
|
| 196 |
+
out[b, :, : self.n_columns] = cols
|
| 197 |
+
out[b, :, self.n_columns] = anom
|
| 198 |
+
else:
|
| 199 |
+
for t in range(T):
|
| 200 |
+
active_cols, _ac, _pc, anomaly = region.step(sdr_np[b, t], learn)
|
| 201 |
+
out[b, t, : self.n_columns] = active_cols
|
| 202 |
+
out[b, t, self.n_columns] = float(anomaly)
|
| 203 |
+
|
| 204 |
+
if B == 1:
|
| 205 |
+
_process_one(0)
|
| 206 |
+
elif self._use_gpu:
|
| 207 |
+
# GPU regions share the CUDA context; serialise to avoid contention
|
| 208 |
+
# for stream 0. Per-region latency is dominated by kernel compute,
|
| 209 |
+
# not threadable on a single stream cheaply — future work: one
|
| 210 |
+
# CUDA stream per region.
|
| 211 |
+
for b in range(B):
|
| 212 |
+
_process_one(b)
|
| 213 |
+
else:
|
| 214 |
+
# Each thread runs in pure Rust under py.allow_threads, so they
|
| 215 |
+
# parallelise to wall-clock min(B, CPU_cores).
|
| 216 |
+
list(self._htm_pool.map(_process_one, range(B)))
|
| 217 |
+
|
| 218 |
+
return torch.from_numpy(out).to(sdr.device)
|
| 219 |
+
|
| 220 |
+
def forward_async(self, sdr: torch.Tensor):
|
| 221 |
+
"""Submit HTM work and return a handle awaitable via ``forward_await``.
|
| 222 |
+
|
| 223 |
+
On the CAI zero-copy path (GPU tensor in, GPU region), the Rust
|
| 224 |
+
CUDA kernels are launched on cudarc's internal stream and control
|
| 225 |
+
returns **immediately** — no device synchronization. The caller's
|
| 226 |
+
next GPU ops (embedding lookup, Mamba forward, etc.) are enqueued
|
| 227 |
+
on PyTorch's default stream and can execute while HTM kernels run
|
| 228 |
+
on the cudarc stream. ``forward_await`` performs the cross-stream
|
| 229 |
+
sync (via ``device_sync``) and assembles the output tensor only
|
| 230 |
+
when the result is actually consumed.
|
| 231 |
+
|
| 232 |
+
For cooperative kernels (``step_many_fused_cuda``) the GPU can only
|
| 233 |
+
run one cooperative launch at a time, so kernel-level overlap with
|
| 234 |
+
default-stream work is limited. The win is **CPU-side launch
|
| 235 |
+
overlap**: instead of the CPU blocking ~10 ms waiting for HTM
|
| 236 |
+
before it can even enqueue wte/mamba, it enqueues everything up
|
| 237 |
+
front and the GPU executes back-to-back without CPU stalls.
|
| 238 |
+
|
| 239 |
+
On the legacy CPU/numpy path, work is dispatched to a thread pool
|
| 240 |
+
as before."""
|
| 241 |
+
B, T, D = sdr.shape
|
| 242 |
+
if D != self.input_bits:
|
| 243 |
+
raise ValueError(f"expected input_bits={self.input_bits}, got {D}")
|
| 244 |
+
self._ensure_regions(B)
|
| 245 |
+
if self.reset_each_forward:
|
| 246 |
+
self.reset()
|
| 247 |
+
learn = bool(self.learn and self.training)
|
| 248 |
+
|
| 249 |
+
if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
|
| 250 |
+
sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()
|
| 251 |
+
cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
|
| 252 |
+
anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)
|
| 253 |
+
# ONE cooperative kernel launch for all B regions. Breaks past
|
| 254 |
+
# the CUDA cooperative-kernel device-level serialization (only
|
| 255 |
+
# one cooperative kernel runs at a time). A single launch with
|
| 256 |
+
# grid.y = B processes all regions concurrently — ~B× speedup.
|
| 257 |
+
# Falls back to sequential dispatch if the batched entry isn't
|
| 258 |
+
# available (older htm_rust wheel).
|
| 259 |
+
if _HTM_USE_FUSED and hasattr(htm_rust, "step_batch_fused_cuda"):
|
| 260 |
+
# Slice self._regions to match B: _ensure_regions may have
|
| 261 |
+
# allocated more regions than the current batch size needs
|
| 262 |
+
# (e.g. factual eval uses smaller batches than training).
|
| 263 |
+
try:
|
| 264 |
+
htm_rust.step_batch_fused_cuda(
|
| 265 |
+
self._regions[:B],
|
| 266 |
+
[sdr_u8[b].__cuda_array_interface__ for b in range(B)],
|
| 267 |
+
[cols_out[b].__cuda_array_interface__ for b in range(B)],
|
| 268 |
+
[anom_out[b].__cuda_array_interface__ for b in range(B)],
|
| 269 |
+
learn,
|
| 270 |
+
)
|
| 271 |
+
except RuntimeError as _e:
|
| 272 |
+
if "COOPERATIVE_LAUNCH_TOO_LARGE" in str(_e):
|
| 273 |
+
# Batch too large for cooperative grid. Fall back to
|
| 274 |
+
# sequential per-region fused launches (each B=1).
|
| 275 |
+
for b in range(B):
|
| 276 |
+
self._regions[b].step_many_fused_cuda(
|
| 277 |
+
sdr_u8[b].__cuda_array_interface__,
|
| 278 |
+
cols_out[b].__cuda_array_interface__,
|
| 279 |
+
anom_out[b].__cuda_array_interface__,
|
| 280 |
+
learn,
|
| 281 |
+
)
|
| 282 |
+
else:
|
| 283 |
+
raise
|
| 284 |
+
elif _HTM_USE_FUSED:
|
| 285 |
+
for b in range(B):
|
| 286 |
+
self._regions[b].step_many_fused_cuda(
|
| 287 |
+
sdr_u8[b].__cuda_array_interface__,
|
| 288 |
+
cols_out[b].__cuda_array_interface__,
|
| 289 |
+
anom_out[b].__cuda_array_interface__,
|
| 290 |
+
learn,
|
| 291 |
+
)
|
| 292 |
+
else:
|
| 293 |
+
for b in range(B):
|
| 294 |
+
self._regions[b].step_many_cuda(
|
| 295 |
+
sdr_u8[b].__cuda_array_interface__,
|
| 296 |
+
cols_out[b].__cuda_array_interface__,
|
| 297 |
+
anom_out[b].__cuda_array_interface__,
|
| 298 |
+
learn,
|
| 299 |
+
)
|
| 300 |
+
# NO sync here — kernels are in-flight on cudarc's stream.
|
| 301 |
+
# forward_await() will sync before the output is consumed.
|
| 302 |
+
return {
|
| 303 |
+
'cuda_deferred': True,
|
| 304 |
+
'cols_out': cols_out,
|
| 305 |
+
'anom_out': anom_out,
|
| 306 |
+
'region0': self._regions[0],
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
sdr_np = sdr.detach().cpu().contiguous().numpy().view(np.bool_)
|
| 310 |
+
out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)
|
| 311 |
+
|
| 312 |
+
def _process_one(b):
|
| 313 |
+
region = self._regions[b]
|
| 314 |
+
if self._use_gpu:
|
| 315 |
+
cols, anom = region.step_many_gpu(sdr_np[b], learn)
|
| 316 |
+
out[b, :, : self.n_columns] = cols
|
| 317 |
+
out[b, :, self.n_columns] = anom
|
| 318 |
+
elif _HTM_HAS_STEP_MANY:
|
| 319 |
+
cols, anom = region.step_many(sdr_np[b], learn)
|
| 320 |
+
out[b, :, : self.n_columns] = cols
|
| 321 |
+
out[b, :, self.n_columns] = anom
|
| 322 |
+
else:
|
| 323 |
+
for t in range(T):
|
| 324 |
+
active_cols, _ac, _pc, anomaly = region.step(sdr_np[b, t], learn)
|
| 325 |
+
out[b, t, : self.n_columns] = active_cols
|
| 326 |
+
out[b, t, self.n_columns] = float(anomaly)
|
| 327 |
+
|
| 328 |
+
fut = self._htm_pool.submit(lambda: [_process_one(b) for b in range(B)])
|
| 329 |
+
return {'fut': fut, 'out': out, 'device': sdr.device}
|
| 330 |
+
|
| 331 |
+
def forward_await(self, handle) -> torch.Tensor:
|
| 332 |
+
if handle.get('cuda_deferred'):
|
| 333 |
+
# Cross-stream sync: block until cudarc stream finishes HTM
|
| 334 |
+
# kernels so the output tensors are safe to read on the
|
| 335 |
+
# default stream.
|
| 336 |
+
region0 = handle['region0']
|
| 337 |
+
if hasattr(region0, "device_sync"):
|
| 338 |
+
region0.device_sync()
|
| 339 |
+
else:
|
| 340 |
+
torch.cuda.synchronize()
|
| 341 |
+
cols_out = handle['cols_out']
|
| 342 |
+
anom_out = handle['anom_out']
|
| 343 |
+
return torch.cat(
|
| 344 |
+
(cols_out.to(torch.float32), anom_out.unsqueeze(-1)), dim=-1
|
| 345 |
+
)
|
| 346 |
+
if 'cuda_result' in handle:
|
| 347 |
+
return handle['cuda_result']
|
| 348 |
+
handle['fut'].result()
|
| 349 |
+
return torch.from_numpy(handle['out']).to(handle['device'])
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
if __name__ == "__main__":
|
| 353 |
+
torch.manual_seed(0)
|
| 354 |
+
|
| 355 |
+
# Smoke test: (B=2, T=4, D=16384) random 2%-sparse SDR
|
| 356 |
+
B, T, D = 2, 4, 16384
|
| 357 |
+
n_columns = 2048
|
| 358 |
+
target_active_in = int(D * 0.02) # 327
|
| 359 |
+
|
| 360 |
+
layer = HTMLayer(
|
| 361 |
+
input_bits=D,
|
| 362 |
+
n_columns=n_columns,
|
| 363 |
+
cells_per_column=32,
|
| 364 |
+
batch_size=B,
|
| 365 |
+
seed=42,
|
| 366 |
+
learn=True,
|
| 367 |
+
)
|
| 368 |
+
layer.train()
|
| 369 |
+
|
| 370 |
+
rng = np.random.default_rng(0)
|
| 371 |
+
sdr = np.zeros((B, T, D), dtype=bool)
|
| 372 |
+
for b in range(B):
|
| 373 |
+
for t in range(T):
|
| 374 |
+
idx = rng.choice(D, size=target_active_in, replace=False)
|
| 375 |
+
sdr[b, t, idx] = True
|
| 376 |
+
sdr_t = torch.from_numpy(sdr)
|
| 377 |
+
|
| 378 |
+
t0 = time.perf_counter()
|
| 379 |
+
out = layer(sdr_t)
|
| 380 |
+
dt_first = time.perf_counter() - t0
|
| 381 |
+
|
| 382 |
+
assert out.shape == (B, T, n_columns + 1), f"shape {out.shape}"
|
| 383 |
+
assert out.dtype == torch.float32, f"dtype {out.dtype}"
|
| 384 |
+
|
| 385 |
+
active_cols = out[..., :n_columns]
|
| 386 |
+
anomaly = out[..., n_columns]
|
| 387 |
+
|
| 388 |
+
col_sums = active_cols.sum(dim=-1) # (B, T)
|
| 389 |
+
mean_active = col_sums.float().mean().item()
|
| 390 |
+
expected = n_columns * 0.02 # ≈ 40.96
|
| 391 |
+
assert 20 <= mean_active <= 60, (
|
| 392 |
+
f"active columns per step out of 2% band: {mean_active:.1f} (expected ~{expected:.1f})"
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
# t=0 has no TM prediction → anomaly = 1.0 on every batch slot.
|
| 396 |
+
assert torch.allclose(anomaly[:, 0], torch.ones(B)), f"t=0 anomaly {anomaly[:, 0]}"
|
| 397 |
+
|
| 398 |
+
# Second forward on same (reset) layer: identical shapes, deterministic re-run possible.
|
| 399 |
+
t0 = time.perf_counter()
|
| 400 |
+
out2 = layer(sdr_t)
|
| 401 |
+
dt_second = time.perf_counter() - t0
|
| 402 |
+
assert out2.shape == out.shape
|
| 403 |
+
|
| 404 |
+
# Repeating-sequence anomaly decay check — one region, T=8 repeats of same pattern.
|
| 405 |
+
rep_layer = HTMLayer(
|
| 406 |
+
input_bits=D,
|
| 407 |
+
n_columns=n_columns,
|
| 408 |
+
batch_size=1,
|
| 409 |
+
seed=7,
|
| 410 |
+
learn=True,
|
| 411 |
+
)
|
| 412 |
+
rep_layer.train()
|
| 413 |
+
base = torch.zeros(D, dtype=torch.bool)
|
| 414 |
+
idx = rng.choice(D, size=target_active_in, replace=False)
|
| 415 |
+
base[idx] = True
|
| 416 |
+
rep = base.unsqueeze(0).unsqueeze(0).expand(1, 16, D).clone()
|
| 417 |
+
rep_out = rep_layer(rep)
|
| 418 |
+
rep_anom = rep_out[0, :, n_columns]
|
| 419 |
+
assert rep_anom[0].item() > 0.5, f"anomaly at t=0 should be high, got {rep_anom[0]:.3f}"
|
| 420 |
+
assert rep_anom[-1].item() < rep_anom[0].item(), (
|
| 421 |
+
f"anomaly should decay on repeats: first={rep_anom[0]:.3f} last={rep_anom[-1]:.3f}"
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
print("[OK] shape:", tuple(out.shape))
|
| 425 |
+
print(f"[OK] mean active cols/step: {mean_active:.2f} (target ~{expected:.1f})")
|
| 426 |
+
print(f"[OK] t=0 anomaly = 1.0 on all batch slots")
|
| 427 |
+
print(f"[OK] repeating-sequence anomaly: first={rep_anom[0]:.3f} -> last={rep_anom[-1]:.3f}")
|
| 428 |
+
print(f"[OK] forward wall-clock: first={dt_first*1000:.1f}ms second={dt_second*1000:.1f}ms "
|
| 429 |
+
f"on (B={B}, T={T}, D={D})")
|
overlay/subsystems/sdr_retina.py
ADDED
|
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Offline Semantic Folding SDR Retina (Cortical.io-grade).
|
| 3 |
+
|
| 4 |
+
Builds a topographic, semantic-folding Sparse Distributed Representation (SDR)
|
| 5 |
+
for every token in the vocabulary, following Webber 2015 ("Semantic Folding Theory").
|
| 6 |
+
|
| 7 |
+
Pipeline:
|
| 8 |
+
1. Scan the tokenized training corpus (parquet shards at ~/.cache/autoresearch/data).
|
| 9 |
+
We on-the-fly tokenize ~10M tokens from the first few shards.
|
| 10 |
+
2. For each token, build a context vector = top-K most-associated neighbors
|
| 11 |
+
(±8-token window, PMI ranking).
|
| 12 |
+
3. Train a 128x128 = 16384-bit Kohonen SOM on those context vectors so that
|
| 13 |
+
semantically related context features land on neighboring lattice cells.
|
| 14 |
+
4. For each token, compute its folded SDR: union of the lattice cells whose
|
| 15 |
+
BMUs are triggered by its top-K context features. Then per-row quantile
|
| 16 |
+
threshold to exactly 2% active bits (327 / 16384).
|
| 17 |
+
5. Save to ~/.cache/autoresearch/retina.npz.
|
| 18 |
+
|
| 19 |
+
Entry point:
|
| 20 |
+
uv run python subsystems/sdr_retina.py --build --validate
|
| 21 |
+
|
| 22 |
+
The validation asserts classic Cortical.io-style analogies:
|
| 23 |
+
- overlap("the", "a") > overlap("the", "zebra")
|
| 24 |
+
- overlap("man", "woman") > overlap("man", "rock")
|
| 25 |
+
- overlap("king","queen") > overlap("king", "dinosaur")
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
import math
|
| 32 |
+
import os
|
| 33 |
+
import sys
|
| 34 |
+
import time
|
| 35 |
+
from dataclasses import dataclass
|
| 36 |
+
|
| 37 |
+
import numpy as np
|
| 38 |
+
import pyarrow.parquet as pq
|
| 39 |
+
import torch
|
| 40 |
+
|
| 41 |
+
# Make the parent repo importable so we can reuse the Tokenizer
|
| 42 |
+
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 43 |
+
sys.path.insert(0, REPO_ROOT)
|
| 44 |
+
|
| 45 |
+
from prepare import CACHE_DIR, DATA_DIR, TOKENIZER_DIR, VAL_FILENAME, Tokenizer # noqa: E402
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# Build parameters
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
|
| 52 |
+
RETINA_PATH = os.path.join(CACHE_DIR, "retina.npz")
|
| 53 |
+
|
| 54 |
+
GRID_H = 128
|
| 55 |
+
GRID_W = 128
|
| 56 |
+
N_BITS = GRID_H * GRID_W # 16384
|
| 57 |
+
TARGET_SPARSITY = 0.02 # 2% (default, Cortical.io-style)
|
| 58 |
+
# Default = int(floor(N_BITS * TARGET_SPARSITY)) = 327, matches Webber/Numenta.
|
| 59 |
+
# Override via HYDRA_SDR_TARGET_ACTIVE env var. The cache key encodes
|
| 60 |
+
# target_active, so changing this triggers automatic retina regeneration.
|
| 61 |
+
TARGET_ACTIVE = int(os.environ.get(
|
| 62 |
+
"HYDRA_SDR_TARGET_ACTIVE",
|
| 63 |
+
str(int(N_BITS * TARGET_SPARSITY)),
|
| 64 |
+
))
|
| 65 |
+
|
| 66 |
+
CONTEXT_WINDOW = 8 # +/- 8 tokens
|
| 67 |
+
TOP_K_FEATURES = 64 # top-K context features per token
|
| 68 |
+
# SCALES WITH VOCAB — need ~100+ occurrences per token for stable cooccurrence.
|
| 69 |
+
# At V=8k: 10M tokens = 1250/tok avg. At V=65k: 10M tokens = 153/tok avg
|
| 70 |
+
# (borderline); rare tokens seen <30x → noisy retina. Recommended: V*150.
|
| 71 |
+
# HF Hub cache makes this a one-time cost per vocab config anyway.
|
| 72 |
+
TARGET_TRAIN_TOKENS = int(os.environ.get("HYDRA_RETINA_TRAIN_TOKENS", "20000000"))
|
| 73 |
+
MAX_DOCS_PER_SHARD = 200_000 # safety cap per shard
|
| 74 |
+
|
| 75 |
+
# Kohonen SOM
|
| 76 |
+
SOM_EPOCHS = 50
|
| 77 |
+
SOM_SIGMA_START = 32.0
|
| 78 |
+
SOM_SIGMA_END = 1.0
|
| 79 |
+
SOM_ALPHA_START = 0.1
|
| 80 |
+
SOM_ALPHA_END = 0.001
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
# Small helpers
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
def _fmt(n):
|
| 88 |
+
if n >= 1_000_000:
|
| 89 |
+
return f"{n/1_000_000:.2f}M"
|
| 90 |
+
if n >= 1_000:
|
| 91 |
+
return f"{n/1_000:.1f}k"
|
| 92 |
+
return str(n)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _device() -> torch.device:
|
| 96 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _list_train_shards():
|
| 100 |
+
files = sorted(
|
| 101 |
+
f for f in os.listdir(DATA_DIR)
|
| 102 |
+
if f.endswith(".parquet") and not f.endswith(".tmp")
|
| 103 |
+
)
|
| 104 |
+
train = [os.path.join(DATA_DIR, f) for f in files if f != VAL_FILENAME]
|
| 105 |
+
assert len(train) > 0, f"No training shards at {DATA_DIR}. Run prepare.py first."
|
| 106 |
+
return train
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ---------------------------------------------------------------------------
|
| 110 |
+
# Stage 1: stream tokens from parquet shards and collect co-occurrences
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
|
| 113 |
+
def _iter_tokenized_shards(tokenizer: Tokenizer, target_tokens: int):
|
| 114 |
+
"""Yield 1-D int32 numpy arrays of token ids until target_tokens reached.
|
| 115 |
+
|
| 116 |
+
Two paths:
|
| 117 |
+
- HYDRA_USE_NEMOTRON=1: stream docs from Nemotron HF datasets (no shards
|
| 118 |
+
on disk — matches the streaming training path).
|
| 119 |
+
- Default: iterate local parquet shards (legacy prepare.py path).
|
| 120 |
+
"""
|
| 121 |
+
tok_encode = tokenizer.enc.encode_ordinary_batch
|
| 122 |
+
|
| 123 |
+
if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
|
| 124 |
+
# Streaming path: reuse prepare_nemotron's weighted stream.
|
| 125 |
+
import prepare_nemotron as _pn
|
| 126 |
+
stream = _pn._WeightedStream(_pn._phase_weights(), seed=0)
|
| 127 |
+
seen = 0
|
| 128 |
+
batch: list[str] = []
|
| 129 |
+
BATCH = 512
|
| 130 |
+
while seen < target_tokens:
|
| 131 |
+
text, _epoch = next(stream)
|
| 132 |
+
if not text:
|
| 133 |
+
continue
|
| 134 |
+
batch.append(text)
|
| 135 |
+
if len(batch) < BATCH:
|
| 136 |
+
continue
|
| 137 |
+
token_lists = tok_encode(batch, num_threads=8)
|
| 138 |
+
batch = []
|
| 139 |
+
for ids in token_lists:
|
| 140 |
+
if not ids:
|
| 141 |
+
continue
|
| 142 |
+
arr = np.asarray(ids, dtype=np.int32)
|
| 143 |
+
yield arr
|
| 144 |
+
seen += arr.size
|
| 145 |
+
if seen >= target_tokens:
|
| 146 |
+
print(f" [nemotron-stream] yielded {_fmt(seen)} tokens, target reached")
|
| 147 |
+
return
|
| 148 |
+
return
|
| 149 |
+
|
| 150 |
+
# Legacy shard path.
|
| 151 |
+
shards = _list_train_shards()
|
| 152 |
+
seen = 0
|
| 153 |
+
for shard_idx, path in enumerate(shards):
|
| 154 |
+
if seen >= target_tokens:
|
| 155 |
+
return
|
| 156 |
+
pf = pq.ParquetFile(path)
|
| 157 |
+
shard_tokens = 0
|
| 158 |
+
for rg_idx in range(pf.num_row_groups):
|
| 159 |
+
rg = pf.read_row_group(rg_idx)
|
| 160 |
+
docs = rg.column("text").to_pylist()
|
| 161 |
+
if len(docs) > MAX_DOCS_PER_SHARD:
|
| 162 |
+
docs = docs[:MAX_DOCS_PER_SHARD]
|
| 163 |
+
# Batch-encode for throughput
|
| 164 |
+
batch_size = 512
|
| 165 |
+
for i in range(0, len(docs), batch_size):
|
| 166 |
+
batch = docs[i:i + batch_size]
|
| 167 |
+
token_lists = tok_encode(batch, num_threads=8)
|
| 168 |
+
for ids in token_lists:
|
| 169 |
+
if not ids:
|
| 170 |
+
continue
|
| 171 |
+
arr = np.asarray(ids, dtype=np.int32)
|
| 172 |
+
yield arr
|
| 173 |
+
shard_tokens += arr.size
|
| 174 |
+
seen += arr.size
|
| 175 |
+
if seen >= target_tokens:
|
| 176 |
+
print(f" shard {shard_idx}: yielded {_fmt(shard_tokens)} tokens "
|
| 177 |
+
f"(total {_fmt(seen)}), target reached")
|
| 178 |
+
return
|
| 179 |
+
print(f" shard {shard_idx}: yielded {_fmt(shard_tokens)} tokens (total {_fmt(seen)})")
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _cooccur_from_doc(ids: np.ndarray, window: int, vocab_size: int,
|
| 183 |
+
counts: np.ndarray, cooc: np.ndarray) -> None:
|
| 184 |
+
"""Update unigram counts and cooccurrence counts for one document. Vectorized."""
|
| 185 |
+
n = ids.size
|
| 186 |
+
if n < 2:
|
| 187 |
+
return
|
| 188 |
+
# unigram counts
|
| 189 |
+
np.add.at(counts, ids, 1)
|
| 190 |
+
# For each offset d in 1..window, count pairs (ids[:-d], ids[d:])
|
| 191 |
+
# Both directions are equivalent by symmetry; we add both to keep the
|
| 192 |
+
# matrix symmetric and treat it as undirected context.
|
| 193 |
+
for d in range(1, window + 1):
|
| 194 |
+
left = ids[:-d]
|
| 195 |
+
right = ids[d:]
|
| 196 |
+
# symmetric update
|
| 197 |
+
flat_lr = left.astype(np.int64) * vocab_size + right.astype(np.int64)
|
| 198 |
+
flat_rl = right.astype(np.int64) * vocab_size + left.astype(np.int64)
|
| 199 |
+
# use bincount-style scatter via np.add.at on the flat view
|
| 200 |
+
cooc_flat = cooc.ravel()
|
| 201 |
+
np.add.at(cooc_flat, flat_lr, 1)
|
| 202 |
+
np.add.at(cooc_flat, flat_rl, 1)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def build_cooccurrence(tokenizer: Tokenizer, target_tokens: int, window: int) -> tuple[np.ndarray, np.ndarray, int]:
|
| 206 |
+
"""
|
| 207 |
+
Stream tokens and build unigram + cooccurrence counts.
|
| 208 |
+
Returns (counts[V] int64, cooc[V,V] int32, total_tokens int).
|
| 209 |
+
"""
|
| 210 |
+
vocab_size = tokenizer.get_vocab_size()
|
| 211 |
+
print(f"[1/4] Building cooccurrence (vocab={vocab_size}, window=+/-{window}, target={_fmt(target_tokens)} tokens)")
|
| 212 |
+
counts = np.zeros(vocab_size, dtype=np.int64)
|
| 213 |
+
# int32 is enough per-cell if we stay <= a few hundred million total tokens; guard with clip at save.
|
| 214 |
+
cooc = np.zeros((vocab_size, vocab_size), dtype=np.int32)
|
| 215 |
+
|
| 216 |
+
total = 0
|
| 217 |
+
n_docs = 0
|
| 218 |
+
t0 = time.time()
|
| 219 |
+
for ids in _iter_tokenized_shards(tokenizer, target_tokens):
|
| 220 |
+
_cooccur_from_doc(ids, window, vocab_size, counts, cooc)
|
| 221 |
+
total += ids.size
|
| 222 |
+
n_docs += 1
|
| 223 |
+
if n_docs % 5000 == 0:
|
| 224 |
+
dt = time.time() - t0
|
| 225 |
+
rate = total / max(dt, 1e-6)
|
| 226 |
+
print(f" docs={_fmt(n_docs)} tokens={_fmt(total)} ({rate/1000:.0f}k tok/s)")
|
| 227 |
+
|
| 228 |
+
dt = time.time() - t0
|
| 229 |
+
print(f"[1/4] done: {_fmt(total)} tokens, {_fmt(n_docs)} docs, {dt:.1f}s, "
|
| 230 |
+
f"cooc_nnz={int((cooc > 0).sum())}")
|
| 231 |
+
return counts, cooc, total
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ---------------------------------------------------------------------------
|
| 235 |
+
# Stage 2: build top-K context features (PMI)
|
| 236 |
+
# ---------------------------------------------------------------------------
|
| 237 |
+
|
| 238 |
+
def compute_pmi_topk(counts: np.ndarray, cooc: np.ndarray, total_tokens: int,
|
| 239 |
+
top_k: int) -> tuple[np.ndarray, np.ndarray]:
|
| 240 |
+
"""
|
| 241 |
+
For each token, compute top-K context features by positive PMI.
|
| 242 |
+
Returns:
|
| 243 |
+
topk_idx : int32 [V, K] token ids of the top-K context features
|
| 244 |
+
topk_score : float32 [V, K] PMI scores (0 for padded missing features)
|
| 245 |
+
Missing features are padded with idx=token itself and score=0, so they
|
| 246 |
+
have a well-defined (but uninformative) column.
|
| 247 |
+
"""
|
| 248 |
+
V = counts.shape[0]
|
| 249 |
+
print(f"[2/4] Computing PMI top-{top_k} per token (vocab={V})")
|
| 250 |
+
|
| 251 |
+
# window_pairs per occurrence: 2 * window (we added both directions, each offset twice).
|
| 252 |
+
# For the PMI denominator we need a total pair count; using coo.sum() is the clean
|
| 253 |
+
# per-matrix normalizer and avoids any constant confusion.
|
| 254 |
+
pair_total = float(cooc.sum())
|
| 255 |
+
if pair_total <= 0:
|
| 256 |
+
raise RuntimeError("Empty cooccurrence matrix")
|
| 257 |
+
|
| 258 |
+
# Run on GPU if available; this is ~8k x 8k float32 = 256MB each.
|
| 259 |
+
dev = _device()
|
| 260 |
+
cooc_t = torch.from_numpy(cooc.astype(np.float32)).to(dev)
|
| 261 |
+
counts_t = torch.from_numpy(counts.astype(np.float64)).to(dev).clamp_min(1.0)
|
| 262 |
+
|
| 263 |
+
# P(i) = counts[i] / total_tokens
|
| 264 |
+
# P(i, j) = cooc[i, j] / pair_total
|
| 265 |
+
# PMI = log(P(i,j) / (P(i) P(j)))
|
| 266 |
+
# Positive PMI = max(PMI, 0).
|
| 267 |
+
# We'll compute log-PMI in a numerically safe way:
|
| 268 |
+
# log(cooc) + log(total_tokens^2 / pair_total) - log(c_i) - log(c_j)
|
| 269 |
+
# Keep numerator zero where cooc==0 and mask those out.
|
| 270 |
+
|
| 271 |
+
log_const = math.log(total_tokens) + math.log(total_tokens) - math.log(pair_total)
|
| 272 |
+
log_ci = torch.log(counts_t) # [V]
|
| 273 |
+
log_cj = log_ci.clone() # same vector (symmetric vocab)
|
| 274 |
+
|
| 275 |
+
# We'll do it in row blocks to cap memory of intermediate log() tensors.
|
| 276 |
+
topk_idx = np.zeros((V, top_k), dtype=np.int32)
|
| 277 |
+
topk_score = np.zeros((V, top_k), dtype=np.float32)
|
| 278 |
+
|
| 279 |
+
block = 512
|
| 280 |
+
t0 = time.time()
|
| 281 |
+
for start in range(0, V, block):
|
| 282 |
+
end = min(V, start + block)
|
| 283 |
+
rows = cooc_t[start:end] # [b, V] int-as-float
|
| 284 |
+
mask = rows > 0
|
| 285 |
+
# log(rows) where rows>0; else keep -inf then mask out
|
| 286 |
+
log_rows = torch.where(mask, torch.log(rows.clamp_min(1.0)),
|
| 287 |
+
torch.full_like(rows, float("-inf")))
|
| 288 |
+
pmi = log_rows + log_const - log_ci[start:end].unsqueeze(1) - log_cj.unsqueeze(0)
|
| 289 |
+
ppmi = torch.where(mask, torch.clamp(pmi, min=0.0),
|
| 290 |
+
torch.full_like(pmi, float("-inf")))
|
| 291 |
+
# top-K along dim=1
|
| 292 |
+
vals, idx = torch.topk(ppmi, k=top_k, dim=1)
|
| 293 |
+
# Replace any -inf valued slots with score 0 and idx = the token itself
|
| 294 |
+
bad = torch.isneginf(vals)
|
| 295 |
+
if bad.any():
|
| 296 |
+
self_idx = torch.arange(start, end, device=dev).unsqueeze(1).expand_as(idx)
|
| 297 |
+
idx = torch.where(bad, self_idx, idx)
|
| 298 |
+
vals = torch.where(bad, torch.zeros_like(vals), vals)
|
| 299 |
+
topk_idx[start:end] = idx.cpu().numpy().astype(np.int32)
|
| 300 |
+
topk_score[start:end] = vals.cpu().numpy().astype(np.float32)
|
| 301 |
+
|
| 302 |
+
del cooc_t, counts_t
|
| 303 |
+
if dev.type == "cuda":
|
| 304 |
+
torch.cuda.empty_cache()
|
| 305 |
+
print(f"[2/4] done: top-{top_k} PMI features per token in {time.time()-t0:.1f}s")
|
| 306 |
+
return topk_idx, topk_score
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# ---------------------------------------------------------------------------
|
| 310 |
+
# Stage 3: Kohonen SOM on the context-vector representation
|
| 311 |
+
# ---------------------------------------------------------------------------
|
| 312 |
+
|
| 313 |
+
def _context_vectors_from_topk(topk_idx: np.ndarray, topk_score: np.ndarray,
|
| 314 |
+
vocab_size: int) -> torch.Tensor:
|
| 315 |
+
"""
|
| 316 |
+
Build the dense context matrix X [V, V] where X[i] is the top-K PMI context
|
| 317 |
+
vector for token i, L2-normalized. For V=8192 this is 8k x 8k float32 = 256 MB.
|
| 318 |
+
"""
|
| 319 |
+
V = vocab_size
|
| 320 |
+
K = topk_idx.shape[1]
|
| 321 |
+
dev = _device()
|
| 322 |
+
X = torch.zeros((V, V), dtype=torch.float32, device=dev)
|
| 323 |
+
rows = torch.arange(V, device=dev).unsqueeze(1).expand(V, K) # [V,K]
|
| 324 |
+
idx = torch.from_numpy(topk_idx).to(dev).long()
|
| 325 |
+
scores = torch.from_numpy(topk_score).to(dev)
|
| 326 |
+
# Scatter scores into X at positions (rows, idx). If duplicates, keep max.
|
| 327 |
+
X[rows, idx] = torch.maximum(X[rows, idx], scores)
|
| 328 |
+
# L2 normalize so Euclidean ~ cosine
|
| 329 |
+
norm = X.norm(dim=1, keepdim=True).clamp_min(1e-8)
|
| 330 |
+
X = X / norm
|
| 331 |
+
return X
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def train_som(X: torch.Tensor, grid_h: int, grid_w: int,
|
| 335 |
+
epochs: int, sigma_start: float, sigma_end: float,
|
| 336 |
+
alpha_start: float, alpha_end: float,
|
| 337 |
+
seed: int = 137) -> torch.Tensor:
|
| 338 |
+
"""
|
| 339 |
+
Train a Kohonen SOM with rectangular grid and Gaussian neighborhood.
|
| 340 |
+
X: [V, F] features (L2 normalized). Returns weights W: [grid_h*grid_w, F].
|
| 341 |
+
"""
|
| 342 |
+
dev = X.device
|
| 343 |
+
V, F = X.shape
|
| 344 |
+
N = grid_h * grid_w
|
| 345 |
+
|
| 346 |
+
torch.manual_seed(seed)
|
| 347 |
+
# Initialize SOM weights: small random linear combinations of data points
|
| 348 |
+
# (faster convergence than uniform random in the feature space).
|
| 349 |
+
init_pick = torch.randint(0, V, (N,), device=dev)
|
| 350 |
+
W = X[init_pick].clone() # [N, F]
|
| 351 |
+
|
| 352 |
+
# Precompute grid coordinates
|
| 353 |
+
yy, xx = torch.meshgrid(
|
| 354 |
+
torch.arange(grid_h, device=dev, dtype=torch.float32),
|
| 355 |
+
torch.arange(grid_w, device=dev, dtype=torch.float32),
|
| 356 |
+
indexing="ij",
|
| 357 |
+
)
|
| 358 |
+
grid = torch.stack([yy.reshape(-1), xx.reshape(-1)], dim=1) # [N, 2]
|
| 359 |
+
|
| 360 |
+
print(f"[3/4] Training Kohonen SOM: grid={grid_h}x{grid_w}, features={F}, "
|
| 361 |
+
f"epochs={epochs}, sigma {sigma_start}->{sigma_end}, alpha {alpha_start}->{alpha_end}")
|
| 362 |
+
t0 = time.time()
|
| 363 |
+
|
| 364 |
+
# Exponential decay schedules
|
| 365 |
+
def schedule(t_frac):
|
| 366 |
+
sigma = sigma_start * (sigma_end / sigma_start) ** t_frac
|
| 367 |
+
alpha = alpha_start * (alpha_end / alpha_start) ** t_frac
|
| 368 |
+
return sigma, alpha
|
| 369 |
+
|
| 370 |
+
# Batch-mode SOM: process a random permutation each epoch in mini-batches.
|
| 371 |
+
# For each mini-batch, compute BMUs then one vectorized neighborhood update.
|
| 372 |
+
batch_size = 256
|
| 373 |
+
|
| 374 |
+
for epoch in range(epochs):
|
| 375 |
+
t_frac = epoch / max(epochs - 1, 1)
|
| 376 |
+
sigma, alpha = schedule(t_frac)
|
| 377 |
+
two_sigma2 = 2.0 * sigma * sigma
|
| 378 |
+
perm = torch.randperm(V, device=dev)
|
| 379 |
+
|
| 380 |
+
for bstart in range(0, V, batch_size):
|
| 381 |
+
bidx = perm[bstart:bstart + batch_size]
|
| 382 |
+
xb = X[bidx] # [b, F]
|
| 383 |
+
# BMU: argmax of cosine similarity = argmin of squared Euclidean
|
| 384 |
+
# ||x||=||w||=1 for data; W may drift but the formulation remains stable.
|
| 385 |
+
sim = xb @ W.t() # [b, N]
|
| 386 |
+
bmu = sim.argmax(dim=1) # [b]
|
| 387 |
+
|
| 388 |
+
# Neighborhood weights h[b, n] = exp(-|grid[bmu_b] - grid[n]|^2 / (2*sigma^2))
|
| 389 |
+
bmu_coords = grid[bmu] # [b, 2]
|
| 390 |
+
diff = bmu_coords.unsqueeze(1) - grid.unsqueeze(0) # [b, N, 2]
|
| 391 |
+
dist2 = (diff * diff).sum(dim=2) # [b, N]
|
| 392 |
+
h = torch.exp(-dist2 / two_sigma2) # [b, N]
|
| 393 |
+
h = h * alpha # include LR
|
| 394 |
+
|
| 395 |
+
# Vectorized SOM update:
|
| 396 |
+
# W <- W + sum_b h[b] * (x_b - W) / (sum_b h[b])
|
| 397 |
+
# Batched form: numerator = h^T x_b [N, F], denom = h.sum(0) [N]
|
| 398 |
+
numer = h.t() @ xb # [N, F]
|
| 399 |
+
denom = h.sum(dim=0).unsqueeze(1).clamp_min(1e-8) # [N, 1]
|
| 400 |
+
target = numer / denom
|
| 401 |
+
# Update weight: mix toward target with a unit step (h already scaled by alpha).
|
| 402 |
+
# To prevent over-shoot when the same BMU is hit heavily, scale by the
|
| 403 |
+
# mean-field gain min(1, denom). Empirically this behaves like classic SOM.
|
| 404 |
+
gain = torch.clamp(h.sum(dim=0), max=1.0).unsqueeze(1) # [N,1]
|
| 405 |
+
W = (1 - gain) * W + gain * target
|
| 406 |
+
|
| 407 |
+
# Renormalize weights to unit sphere for stability
|
| 408 |
+
W = W / W.norm(dim=1, keepdim=True).clamp_min(1e-8)
|
| 409 |
+
|
| 410 |
+
if (epoch + 1) % max(1, epochs // 10) == 0 or epoch == 0:
|
| 411 |
+
dt = time.time() - t0
|
| 412 |
+
print(f" epoch {epoch+1}/{epochs} sigma={sigma:.2f} alpha={alpha:.4f} elapsed={dt:.1f}s")
|
| 413 |
+
|
| 414 |
+
print(f"[3/4] SOM trained in {time.time()-t0:.1f}s")
|
| 415 |
+
return W
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# ---------------------------------------------------------------------------
|
| 419 |
+
# Stage 4: fold context vectors into SDRs
|
| 420 |
+
# ---------------------------------------------------------------------------
|
| 421 |
+
|
| 422 |
+
def fold_sdrs(X: torch.Tensor, W: torch.Tensor, topk_idx: np.ndarray,
|
| 423 |
+
topk_score: np.ndarray, target_active: int) -> np.ndarray:
|
| 424 |
+
"""
|
| 425 |
+
For each token, activate the 'cell votes' on the lattice for each of its top-K
|
| 426 |
+
context features, then threshold to exactly target_active bits.
|
| 427 |
+
|
| 428 |
+
Implementation detail: every token in the vocabulary has a SOM BMU given its
|
| 429 |
+
context vector X[i]. We use those BMUs as the feature->cell map. For token t,
|
| 430 |
+
we accumulate votes at BMU(feature) weighted by the PMI score, then pick the
|
| 431 |
+
top target_active cells.
|
| 432 |
+
"""
|
| 433 |
+
dev = X.device
|
| 434 |
+
V, F = X.shape
|
| 435 |
+
N = W.shape[0]
|
| 436 |
+
print(f"[4/4] Folding SDRs (V={V}, N={N}, target_active={target_active})")
|
| 437 |
+
|
| 438 |
+
# Per-feature BMU: for each token f as a feature, BMU_f = argmax_n W[n] . X[f]
|
| 439 |
+
# Chunked matmul to bound memory.
|
| 440 |
+
bmu = torch.empty(V, dtype=torch.long, device=dev)
|
| 441 |
+
chunk = 1024
|
| 442 |
+
for s in range(0, V, chunk):
|
| 443 |
+
e = min(V, s + chunk)
|
| 444 |
+
sim = X[s:e] @ W.t() # [b, N]
|
| 445 |
+
bmu[s:e] = sim.argmax(dim=1)
|
| 446 |
+
|
| 447 |
+
# Now build votes tensor [V, N] = sum over k of score[i, k] delta(n = bmu[feat[i, k]])
|
| 448 |
+
K = topk_idx.shape[1]
|
| 449 |
+
feat = torch.from_numpy(topk_idx).to(dev).long() # [V, K]
|
| 450 |
+
sc = torch.from_numpy(topk_score).to(dev) # [V, K]
|
| 451 |
+
feat_bmu = bmu[feat] # [V, K]
|
| 452 |
+
|
| 453 |
+
votes = torch.zeros((V, N), dtype=torch.float32, device=dev)
|
| 454 |
+
votes.scatter_add_(1, feat_bmu, sc)
|
| 455 |
+
|
| 456 |
+
# Tiny numerical nudge: add a local Gaussian kernel around each voted cell so
|
| 457 |
+
# near-neighbors accumulate mass (this is the "folding" smear). Kernel radius 1.
|
| 458 |
+
# Implement as a separable 3x3 blur on the 2D grid view.
|
| 459 |
+
grid_h = int(round(math.sqrt(N)))
|
| 460 |
+
grid_w = grid_h
|
| 461 |
+
assert grid_h * grid_w == N
|
| 462 |
+
votes_2d = votes.view(V, 1, grid_h, grid_w)
|
| 463 |
+
blur = torch.tensor([[[[0.5, 1.0, 0.5],
|
| 464 |
+
[1.0, 2.0, 1.0],
|
| 465 |
+
[0.5, 1.0, 0.5]]]], device=dev, dtype=torch.float32)
|
| 466 |
+
blur = blur / blur.sum()
|
| 467 |
+
votes_2d = torch.nn.functional.conv2d(votes_2d, blur, padding=1)
|
| 468 |
+
votes = votes_2d.view(V, N)
|
| 469 |
+
|
| 470 |
+
# Per-row top-target_active
|
| 471 |
+
_, top_cells = torch.topk(votes, k=target_active, dim=1)
|
| 472 |
+
sdr = torch.zeros((V, N), dtype=torch.bool, device=dev)
|
| 473 |
+
sdr.scatter_(1, top_cells, True)
|
| 474 |
+
|
| 475 |
+
# Sanity check
|
| 476 |
+
row_active = sdr.sum(dim=1)
|
| 477 |
+
assert int(row_active.min()) == target_active, "row active mismatch"
|
| 478 |
+
assert int(row_active.max()) == target_active, "row active mismatch"
|
| 479 |
+
|
| 480 |
+
return sdr.cpu().numpy()
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# ---------------------------------------------------------------------------
|
| 484 |
+
# Build orchestration
|
| 485 |
+
# ---------------------------------------------------------------------------
|
| 486 |
+
|
| 487 |
+
@dataclass
|
| 488 |
+
class BuildReport:
|
| 489 |
+
vocab_size: int
|
| 490 |
+
n_bits: int
|
| 491 |
+
train_tokens: int
|
| 492 |
+
wall_time_sec: float
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def _retina_cache_repo() -> str:
|
| 496 |
+
return os.environ.get("HYDRA_RETINA_CACHE_REPO", "icarus112/feather-retina-cache")
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def _retina_cache_key() -> str:
|
| 500 |
+
"""Cache key encodes vocab_size + n_bits + target_active so we don't
|
| 501 |
+
accidentally restore a retina built for a different tokenizer/config."""
|
| 502 |
+
try:
|
| 503 |
+
from prepare import VOCAB_SIZE
|
| 504 |
+
except Exception:
|
| 505 |
+
VOCAB_SIZE = 0
|
| 506 |
+
return f"retina_v{VOCAB_SIZE}_n{N_BITS}_a{TARGET_ACTIVE}.npz"
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def _try_hydrate_retina_from_hub() -> bool:
|
| 510 |
+
"""Attempt to download a pre-built retina matching our config from HF Hub.
|
| 511 |
+
Returns True if successful — caller should skip the rebuild."""
|
| 512 |
+
token = os.environ.get("HF_TOKEN")
|
| 513 |
+
if not token:
|
| 514 |
+
return False
|
| 515 |
+
cache_key = _retina_cache_key()
|
| 516 |
+
try:
|
| 517 |
+
from huggingface_hub import hf_hub_download
|
| 518 |
+
p = hf_hub_download(
|
| 519 |
+
repo_id=_retina_cache_repo(), repo_type="dataset",
|
| 520 |
+
filename=cache_key, token=token,
|
| 521 |
+
)
|
| 522 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 523 |
+
import shutil
|
| 524 |
+
shutil.copy(p, RETINA_PATH)
|
| 525 |
+
# Quick verify shape
|
| 526 |
+
with np.load(RETINA_PATH) as npz:
|
| 527 |
+
if int(npz["n_bits"]) == N_BITS and int(npz["target_active"]) == TARGET_ACTIVE:
|
| 528 |
+
print(f"[retina-cache] hydrated {cache_key} from {_retina_cache_repo()} "
|
| 529 |
+
f"(shape={npz['sdr'].shape})", flush=True)
|
| 530 |
+
return True
|
| 531 |
+
os.remove(RETINA_PATH)
|
| 532 |
+
return False
|
| 533 |
+
except Exception as e:
|
| 534 |
+
print(f"[retina-cache] miss: {e}", flush=True)
|
| 535 |
+
return False
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def _upload_retina_to_hub() -> None:
|
| 539 |
+
"""Upload freshly-built retina.npz to HF Hub for reuse by future jobs."""
|
| 540 |
+
token = os.environ.get("HF_TOKEN")
|
| 541 |
+
if not token:
|
| 542 |
+
return
|
| 543 |
+
cache_key = _retina_cache_key()
|
| 544 |
+
try:
|
| 545 |
+
from huggingface_hub import HfApi, create_repo
|
| 546 |
+
create_repo(_retina_cache_repo(), repo_type="dataset", private=True,
|
| 547 |
+
exist_ok=True, token=token)
|
| 548 |
+
HfApi(token=token).upload_file(
|
| 549 |
+
path_or_fileobj=RETINA_PATH,
|
| 550 |
+
path_in_repo=cache_key,
|
| 551 |
+
repo_id=_retina_cache_repo(), repo_type="dataset",
|
| 552 |
+
commit_message=f"retina build for {cache_key}", token=token,
|
| 553 |
+
)
|
| 554 |
+
print(f"[retina-cache] uploaded {cache_key} to {_retina_cache_repo()}", flush=True)
|
| 555 |
+
except Exception as e:
|
| 556 |
+
print(f"[retina-cache] upload failed: {e}", flush=True)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def build_retina(target_tokens: int = TARGET_TRAIN_TOKENS) -> BuildReport:
|
| 560 |
+
# Try HF Hub-backed cache first — retina build takes 500+ seconds.
|
| 561 |
+
if os.path.exists(RETINA_PATH):
|
| 562 |
+
print(f"[retina-cache] using local {RETINA_PATH}", flush=True)
|
| 563 |
+
with np.load(RETINA_PATH) as npz:
|
| 564 |
+
return BuildReport(
|
| 565 |
+
vocab_size=int(npz["vocab_size"]),
|
| 566 |
+
n_bits=int(npz["n_bits"]),
|
| 567 |
+
train_tokens=int(npz["train_tokens"]),
|
| 568 |
+
wall_time_sec=0.0,
|
| 569 |
+
)
|
| 570 |
+
elif _try_hydrate_retina_from_hub():
|
| 571 |
+
# Local copy now populated; return stub report
|
| 572 |
+
with np.load(RETINA_PATH) as npz:
|
| 573 |
+
return BuildReport(
|
| 574 |
+
vocab_size=int(npz["vocab_size"]),
|
| 575 |
+
n_bits=int(npz["n_bits"]),
|
| 576 |
+
train_tokens=int(npz["train_tokens"]),
|
| 577 |
+
wall_time_sec=0.0,
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
tokenizer = Tokenizer.from_directory(TOKENIZER_DIR)
|
| 581 |
+
vocab_size = tokenizer.get_vocab_size()
|
| 582 |
+
|
| 583 |
+
t0 = time.time()
|
| 584 |
+
|
| 585 |
+
counts, cooc, total_tokens = build_cooccurrence(
|
| 586 |
+
tokenizer, target_tokens=target_tokens, window=CONTEXT_WINDOW,
|
| 587 |
+
)
|
| 588 |
+
topk_idx, topk_score = compute_pmi_topk(
|
| 589 |
+
counts, cooc, total_tokens=total_tokens, top_k=TOP_K_FEATURES,
|
| 590 |
+
)
|
| 591 |
+
# Free the big cooccurrence matrix before GPU-heavy stages
|
| 592 |
+
del cooc
|
| 593 |
+
X = _context_vectors_from_topk(topk_idx, topk_score, vocab_size)
|
| 594 |
+
W = train_som(
|
| 595 |
+
X, grid_h=GRID_H, grid_w=GRID_W,
|
| 596 |
+
epochs=SOM_EPOCHS,
|
| 597 |
+
sigma_start=SOM_SIGMA_START, sigma_end=SOM_SIGMA_END,
|
| 598 |
+
alpha_start=SOM_ALPHA_START, alpha_end=SOM_ALPHA_END,
|
| 599 |
+
)
|
| 600 |
+
sdr = fold_sdrs(X, W, topk_idx, topk_score, target_active=TARGET_ACTIVE)
|
| 601 |
+
|
| 602 |
+
wall = time.time() - t0
|
| 603 |
+
|
| 604 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 605 |
+
np.savez_compressed(
|
| 606 |
+
RETINA_PATH,
|
| 607 |
+
sdr=sdr,
|
| 608 |
+
vocab_size=np.int64(vocab_size),
|
| 609 |
+
n_bits=np.int64(N_BITS),
|
| 610 |
+
grid_h=np.int64(GRID_H),
|
| 611 |
+
grid_w=np.int64(GRID_W),
|
| 612 |
+
target_active=np.int64(TARGET_ACTIVE),
|
| 613 |
+
context_window=np.int64(CONTEXT_WINDOW),
|
| 614 |
+
top_k_features=np.int64(TOP_K_FEATURES),
|
| 615 |
+
train_tokens=np.int64(total_tokens),
|
| 616 |
+
)
|
| 617 |
+
print(f"[save] wrote {RETINA_PATH} sdr.shape={sdr.shape} "
|
| 618 |
+
f"active_per_row={int(sdr.sum(axis=1).mean())} wall={wall:.1f}s")
|
| 619 |
+
|
| 620 |
+
# Push to HF Hub so subsequent jobs (and parallel retina experiments)
|
| 621 |
+
# skip the 500+ second build entirely.
|
| 622 |
+
_upload_retina_to_hub()
|
| 623 |
+
|
| 624 |
+
return BuildReport(
|
| 625 |
+
vocab_size=vocab_size,
|
| 626 |
+
n_bits=N_BITS,
|
| 627 |
+
train_tokens=total_tokens,
|
| 628 |
+
wall_time_sec=wall,
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
|