Jackoatmon commited on
Commit
e317e25
·
verified ·
1 Parent(s): 7de795d

Update Feather h200 training runtime image

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +16 -20
  2. Dockerfile +118 -122
  3. entrypoint.py +267 -227
  4. mamba_ssm_init.py +62 -94
  5. overlay/.dockerignore +20 -20
  6. overlay/configs/__init__.py +5 -5
  7. overlay/configs/hardware_config.py +104 -104
  8. overlay/configs/harness_config.py +63 -63
  9. overlay/configs/model_config.py +80 -80
  10. overlay/harness/__init__.py +21 -21
  11. overlay/harness/eval_agent.py +129 -257
  12. overlay/harness/git_utils.py +94 -94
  13. overlay/harness/health_monitor.py +86 -86
  14. overlay/harness/meta_agent.py +139 -139
  15. overlay/harness/orchestrator.py +281 -284
  16. overlay/harness/search_strategy.py +153 -153
  17. overlay/htm_rust/Cargo.lock +383 -383
  18. overlay/htm_rust/Cargo.toml +37 -37
  19. overlay/htm_rust/build.rs +168 -160
  20. overlay/htm_rust/pyproject.toml +17 -17
  21. overlay/htm_rust/src/gpu/fused.rs +702 -663
  22. overlay/htm_rust/src/gpu/kernels/htm_fused_step.cu +677 -677
  23. overlay/htm_rust/src/gpu/tests.rs +663 -643
  24. overlay/htm_rust/src/lib.rs +198 -198
  25. overlay/htm_rust/src/region.rs +94 -94
  26. overlay/htm_rust/src/sp.rs +302 -302
  27. overlay/htm_rust/src/tm.rs +545 -545
  28. overlay/hydra/__init__.py +37 -31
  29. overlay/hydra/config.py +225 -220
  30. overlay/hydra/data_module.py +288 -288
  31. overlay/hydra/diffusion_loss.py +236 -236
  32. overlay/hydra/engram.py +160 -175
  33. overlay/hydra/eval.py +210 -217
  34. overlay/hydra/gdn_block.py +126 -126
  35. overlay/hydra/hyena_block.py +68 -68
  36. overlay/hydra/lightning_module.py +326 -326
  37. overlay/hydra/model.py +0 -0
  38. overlay/hydra/optimizer.py +252 -252
  39. overlay/hydra/reality_bridge.py +71 -0
  40. overlay/hydra/training.py +965 -946
  41. overlay/kernels/cuda/decode_kernels.cu +10 -10
  42. overlay/kernels/cuda/flashfftconv/LICENSE +201 -201
  43. overlay/kernels/cuda/flashfftconv/README.md +57 -57
  44. overlay/kernels/cuda/flashfftconv/UPSTREAM_COMMIT +1 -1
  45. overlay/kernels/cuda/flashfftconv/csrc/.gitignore +9 -9
  46. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly.h +373 -373
  47. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda.cu +698 -698
  48. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda_bf16.cu +724 -724
  49. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda.cu +723 -723
  50. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda_bf16.cu +705 -705
.dockerignore CHANGED
@@ -1,20 +1,16 @@
1
- .git
2
- .github
3
- .venv
4
- .remember
5
- .letta
6
- .claude
7
- __pycache__
8
- *.pyc
9
- *.pyo
10
- *.pyd
11
- *.log
12
- run_*.log
13
- run*.log
14
- *.txt
15
- WORKER_COMPLETE
16
- autoresearch_loop.log
17
- overlay/data/
18
- overlay/state_store/
19
- overlay/htm_rust/target/
20
- overlay/hydra-core/target/
 
1
+ # Keep HF runtime image context deterministic and small.
2
+ **/__pycache__/
3
+ **/*.py[cod]
4
+ **/.pytest_cache/
5
+ **/.mypy_cache/
6
+ **/.ruff_cache/
7
+ **/.venv/
8
+ **/target/
9
+ **/logs/
10
+ **/*.log
11
+ **/*.out
12
+ **/*.pt
13
+ **/*.safetensors
14
+ **/*.parquet
15
+ **/*.npz
16
+ **/.git/
 
 
 
 
Dockerfile CHANGED
@@ -1,128 +1,124 @@
1
- FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
2
 
3
- ARG HTM_CUDA_ARCH=sm_86
4
-
5
- ENV DEBIAN_FRONTEND=noninteractive \
6
- PIP_NO_CACHE_DIR=1 \
7
- PYTHONUNBUFFERED=1 \
8
- CARGO_HOME=/root/.cargo \
9
- RUSTUP_HOME=/root/.rustup \
10
- PATH=/root/.cargo/bin:${PATH}
11
-
12
- RUN apt-get update && apt-get install -y --no-install-recommends \
13
- git curl ca-certificates build-essential pkg-config libssl-dev && \
14
- rm -rf /var/lib/apt/lists/*
15
-
16
- RUN curl https://sh.rustup.rs -sSf | bash -s -- -y --profile minimal --default-toolchain stable
17
-
18
- RUN pip install --upgrade pip setuptools wheel && \
19
- pip install \
20
- maturin \
21
- huggingface_hub \
22
- datasets \
23
- requests \
24
- pyarrow \
25
- rustbpe \
26
- pandas \
27
- tiktoken \
28
- pydantic \
29
- ninja \
30
- packaging \
31
- einops
32
-
33
- # Mamba-3 fused CUDA kernel stack (mandatory — NO fallback allowed).
34
- #
35
- # We install PRE-BUILT manylinux wheels from the official state-spaces/mamba
36
- # and Dao-AILab/causal-conv1d GitHub releases. Compiling mamba_ssm from source
37
- # on HF Spaces' cpu-basic builder (~16GB RAM) OOMKills even with MAX_JOBS=1 —
38
- # nvcc on the templated selective-scan/chunk-scan kernels needs 8–12GB per TU.
39
- #
40
- # Wheel selection for base image pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel:
41
- # - Python 3.11 (cp311) — matches PyTorch 2.6.0 image
42
- # - CUDA 12.x wheels (cu12) matches host CUDA 12.4
43
- # - PyTorch 2.6 ABI (torch2.6) — exact torch match
44
- # - cxx11abiFALSE — standard PyTorch pip build
45
- #
46
- # Versions: mamba_ssm 2.3.1 (first stable with Mamba3 class) + causal_conv1d
47
- # 1.6.1.post4 (matching ABI). Both are CUDA-compiled, no build toolchain needed
48
- # on the Space builder.
49
- #
50
- # Step A: install the published v2.3.1 prebuilt wheel (compiled CUDA ops
51
- # for selective_scan, layernorm_gated, ssd_*, causal_conv1d, etc).
52
- RUN pip install \
53
- 'https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \
54
- 'https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
55
- python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))"
56
-
57
- #
58
- # Step B: graft the Mamba3 class + its pure-Triton ops subtree from mamba-ssm
59
- # main. v2.3.1 is the latest release but Mamba3 landed post-release; the new
60
- # files under ops/triton/mamba3/ are ALL pure Python @triton.jit kernels with
61
- # zero compiled-CUDA dependencies (verified: every import in that subtree is
62
- # triton/torch/python — no .so files, no nvcc). So we install the v2.3.1 wheel
63
- # (for its compiled ops) and overlay the main-branch Mamba3 sources on top.
64
- #
65
- # This avoids the source-build OOM on the cpu-basic HF Space builder and the
66
- # missing-file error the smoke hit on the last attempt.
67
- # Download grafted mamba3 module + triton ops subtree
68
- RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
69
- BASE=https://raw.githubusercontent.com/state-spaces/mamba/main && \
70
- curl -fsSL "$BASE/mamba_ssm/modules/mamba3.py" -o "$SITE/modules/mamba3.py" && \
71
- mkdir -p "$SITE/ops/triton/mamba3" && \
72
- for f in __init__.py angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
73
- curl -fsSL "$BASE/mamba_ssm/ops/triton/mamba3/$f" -o "$SITE/ops/triton/mamba3/$f"; \
74
- done
75
-
76
- # Replace mamba_ssm/__init__.py with a minimal one that only imports Mamba3
77
- # (pure-Triton, works). The shipped __init__.py eagerly imports
78
- # selective_scan_cuda.so which has a libtorch C++ ABI mismatch on this base
79
- # image ("undefined symbol: _ZN3c107WarningC1E..."). Since training only needs
80
- # Mamba3 (grafted from main), we skip all compiled-CUDA imports.
81
- COPY mamba_ssm_init.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py
82
-
83
- # Structural check (no triton init triton has no GPU on the builder)
84
- RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
85
- test -f "$SITE/modules/mamba3.py" && \
86
- test -f "$SITE/ops/triton/mamba3/mamba3_siso_combined.py" && \
87
- test -s "$SITE/__init__.py" && \
88
- echo "mamba3 graft + __init__ override verified"
89
-
90
- # Optional tilelang for MIMO path — pure-python, cheap; SISO Mamba3 works without.
91
- RUN pip install tilelang || echo "[dockerfile] tilelang optional install failed — continuing"
92
-
93
- # Triton version decision: FORCE 3.5.1 — the only version with both mamba3
94
- # APIs (set_allocator + tl.make_tensor_descriptor). torch 2.6's _inductor
95
- # imports AttrsDescriptor from triton.compiler.compiler which was removed in
96
- # triton 3.4+, but mamba_ssm/__init__.py shims AttrsDescriptor as a stub
97
- # before any torch._inductor import path runs, so the incompatibility is
98
- # neutralized. Build-time assert verifies mamba3's two required APIs.
99
- RUN pip install --force-reinstall --no-deps 'triton==3.5.1' && \
100
- python -c "import triton; from triton import language as tl; \
101
- assert hasattr(triton, 'set_allocator'), 'missing triton.set_allocator'; \
102
- assert hasattr(tl, 'make_tensor_descriptor'), 'missing tl.make_tensor_descriptor'; \
103
- print(f'triton={triton.__version__} set_allocator+make_tensor_descriptor OK, AttrsDescriptor shimmed in mamba_ssm/__init__.py')"
104
-
105
- WORKDIR /workspace
106
- COPY overlay /workspace/feather
107
- COPY entrypoint.py /app/entrypoint.py
108
- WORKDIR /workspace/feather
109
-
110
- RUN python - <<'PY'
111
- from pathlib import Path
112
- for sh in Path('/workspace/feather/scripts').glob('*.sh'):
113
- raw = sh.read_bytes()
114
- norm = raw.replace(b'\r\n', b'\n')
115
- if norm != raw:
116
- sh.write_bytes(norm)
117
- PY
118
 
119
  RUN python -m py_compile hydra/training.py prepare.py train.py && \
120
  bash -n scripts/run_domain_expanded_pretrain.sh
121
-
122
  RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
123
- export HTM_CUDA_ARCH=${HTM_CUDA_ARCH} && \
124
- export CARGO_BUILD_JOBS=1 && \
125
- maturin build --release -j 1 --features gpu --manifest-path htm_rust/Cargo.toml && \
126
  pip install htm_rust/target/wheels/htm_rust-*.whl
127
-
128
- CMD ["python", "/app/entrypoint.py"]
 
1
+ FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel
2
 
3
+ # Default target is HF Jobs a10g-large (NVIDIA A10G, Ampere GA102, sm_86).
4
+ # Override at build time for other cards, e.g. --build-arg FEATHER_GPU_ARCH=sm_90a.
5
+ ARG FEATHER_GPU_ARCH=sm_86
6
+ ARG FEATHER_TORCH_CUDA_ARCH_LIST=8.6
7
+
8
+ ENV DEBIAN_FRONTEND=noninteractive \
9
+ PIP_NO_CACHE_DIR=1 \
10
+ PYTHONUNBUFFERED=1 \
11
+ CARGO_HOME=/root/.cargo \
12
+ RUSTUP_HOME=/root/.rustup \
13
+ HTM_CUDA_ARCH=${FEATHER_GPU_ARCH} \
14
+ TORCH_CUDA_ARCH_LIST=${FEATHER_TORCH_CUDA_ARCH_LIST} \
15
+ PATH=/root/.cargo/bin:${PATH}
16
+
17
+ RUN apt-get update && apt-get install -y --no-install-recommends \
18
+ git curl ca-certificates build-essential pkg-config libssl-dev && \
19
+ rm -rf /var/lib/apt/lists/*
20
+
21
+ RUN curl https://sh.rustup.rs -sSf | bash -s -- -y --profile minimal --default-toolchain stable
22
+
23
+ RUN pip install --upgrade pip setuptools wheel && \
24
+ pip install \
25
+ maturin \
26
+ huggingface_hub \
27
+ datasets \
28
+ requests \
29
+ pyarrow \
30
+ rustbpe \
31
+ pandas \
32
+ tiktoken \
33
+ pydantic \
34
+ ninja \
35
+ packaging \
36
+ einops
37
+
38
+ # Mamba-3 fused CUDA kernel stack (mandatory NO fallback allowed).
39
+ #
40
+ # We install PRE-BUILT manylinux wheels from the official state-spaces/mamba
41
+ # and Dao-AILab/causal-conv1d GitHub releases. Compiling mamba_ssm from source
42
+ # on HF Spaces' cpu-basic builder (~16GB RAM) OOMKills even with MAX_JOBS=1 —
43
+ # nvcc on the templated selective-scan/chunk-scan kernels needs 8–12GB per TU.
44
+ #
45
+ # Wheel selection for base image pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel:
46
+ # - Python 3.11 (cp311) — matches PyTorch 2.5.1 image
47
+ # - CUDA 12.x wheels (cu12) compatible with CUDA 12.1 base
48
+ # - PyTorch 2.5 ABI (torch2.5) — exact torch match
49
+ # - cxx11abiFALSE — standard PyTorch pip build
50
+ #
51
+ # Versions: mamba_ssm 2.3.0 + causal_conv1d 1.6.0 (matching torch2.5 ABI).
52
+ # Both are CUDA-compiled, no build toolchain needed
53
+ # on the Space builder.
54
+ #
55
+ # Step A: install the published v2.3.0 prebuilt wheel (compiled CUDA ops
56
+ # for selective_scan, layernorm_gated, ssd_*, causal_conv1d, etc).
57
+ RUN pip install \
58
+ 'https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.0/causal_conv1d-1.6.0+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \
59
+ 'https://github.com/state-spaces/mamba/releases/download/v2.3.0/mamba_ssm-2.3.0+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
60
+ python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))"
61
+
62
+ #
63
+ # Step B: graft the Mamba3 class + its pure-Triton ops subtree from mamba-ssm
64
+ # main. v2.3.1 is the latest release but Mamba3 landed post-release; the new
65
+ # files under ops/triton/mamba3/ are ALL pure Python @triton.jit kernels with
66
+ # zero compiled-CUDA dependencies (verified: every import in that subtree is
67
+ # triton/torch/python no .so files, no nvcc). So we install the v2.3.1 wheel
68
+ # (for its compiled ops) and overlay the main-branch Mamba3 sources on top.
69
+ #
70
+ # This avoids the source-build OOM on the cpu-basic HF Space builder and the
71
+ # missing-file error the smoke hit on the last attempt.
72
+ # Download grafted mamba3 module + triton ops subtree
73
+ RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
74
+ BASE=https://raw.githubusercontent.com/state-spaces/mamba/main && \
75
+ curl -fsSL "$BASE/mamba_ssm/modules/mamba3.py" -o "$SITE/modules/mamba3.py" && \
76
+ mkdir -p "$SITE/ops/triton/mamba3" && \
77
+ for f in __init__.py angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
78
+ curl -fsSL "$BASE/mamba_ssm/ops/triton/mamba3/$f" -o "$SITE/ops/triton/mamba3/$f"; \
79
+ done
80
+
81
+ # Replace mamba_ssm/__init__.py with a minimal one that only imports Mamba3
82
+ # (pure-Triton, works). The shipped __init__.py eagerly imports
83
+ # selective_scan_cuda.so which has a libtorch C++ ABI mismatch on this base
84
+ # image ("undefined symbol: _ZN3c107WarningC1E..."). Since training only needs
85
+ # Mamba3 (grafted from main), we skip all compiled-CUDA imports.
86
+ COPY mamba_ssm_init.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py
87
+
88
+ # Structural check (no triton init — triton has no GPU on the builder)
89
+ RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
90
+ test -f "$SITE/modules/mamba3.py" && \
91
+ test -f "$SITE/ops/triton/mamba3/mamba3_siso_combined.py" && \
92
+ test -s "$SITE/__init__.py" && \
93
+ echo "mamba3 graft + __init__ override verified"
94
+
95
+ # Optional tilelang for MIMO path pure-python, cheap; SISO Mamba3 works without.
96
+ RUN pip install tilelang || echo "[dockerfile] tilelang optional install failed — continuing"
97
+
98
+ # Triton version decision: FORCE 3.4.0 first line with both mamba3
99
+ # APIs (set_allocator + tl.make_tensor_descriptor) while avoiding the 3.5.x
100
+ # driver-discovery regression seen on HF A10G (`0 active drivers` despite
101
+ # torch.cuda being available). torch 2.5's _inductor expects older Triton
102
+ # internals, but mamba_ssm/__init__.py shims AttrsDescriptor as a stub
103
+ # before any torch._inductor import path runs, so the incompatibility is
104
+ # neutralized. Build-time assert verifies mamba3's two required APIs.
105
+ RUN pip install --force-reinstall --no-deps 'triton==3.4.0' && \
106
+ python -c "import triton; from triton import language as tl; \
107
+ assert hasattr(triton, 'set_allocator'), 'missing triton.set_allocator'; \
108
+ assert hasattr(tl, 'make_tensor_descriptor'), 'missing tl.make_tensor_descriptor'; \
109
+ print(f'triton={triton.__version__} set_allocator+make_tensor_descriptor OK, AttrsDescriptor shimmed in mamba_ssm/__init__.py')"
110
+
111
+ WORKDIR /workspace
112
+ COPY overlay /workspace/feather
113
+ COPY entrypoint.py /app/entrypoint.py
114
+ WORKDIR /workspace/feather
 
 
 
115
 
116
  RUN python -m py_compile hydra/training.py prepare.py train.py && \
117
  bash -n scripts/run_domain_expanded_pretrain.sh
118
+
119
  RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
120
+ echo "building htm_rust GPU kernels for HTM_CUDA_ARCH=${HTM_CUDA_ARCH} TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}" && \
121
+ maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml && \
 
122
  pip install htm_rust/target/wheels/htm_rust-*.whl
123
+
124
+ CMD ["python", "/app/entrypoint.py"]
entrypoint.py CHANGED
@@ -1,227 +1,267 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import json
5
- import os
6
- import subprocess
7
- import sys
8
- import time
9
- from http.server import BaseHTTPRequestHandler, HTTPServer
10
- from pathlib import Path
11
- from threading import Thread
12
-
13
-
14
- # =============================================================================
15
- # EARLY CUDA FABRIC MANAGER KICK (before ANY CUDA-touching imports)
16
- # =============================================================================
17
- # On H200 hosts, cudaGetDeviceCount can return Error 802 "system not yet
18
- # initialized" on first use, because nvidia-fabricmanager on the host
19
- # synchronizes with the container's first driver call. Once any NVML/CUDA
20
- # call succeeds once (even just nvidia-smi), the fabric is up for the rest
21
- # of the container lifetime.
22
- #
23
- # Our previous approach (wait in a subprocess before training) didn't work
24
- # because the "initialization failed" state persisted across calls in the
25
- # same container. The real fix: kick the driver exactly once with
26
- # nvidia-smi, which is what successfully-working baseline containers do
27
- # implicitly via their first torch.cuda call.
28
- #
29
- # Must happen BEFORE `import torch` (because any import that eagerly calls
30
- # cudaGetDeviceCount will cache the Error 802 state).
31
- def _early_cuda_kick() -> None:
32
- deadline = time.time() + 120.0
33
- attempt = 0
34
- while time.time() < deadline:
35
- attempt += 1
36
- r = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=30)
37
- if r.returncode == 0 and 'H200' in (r.stdout or '') or 'H100' in (r.stdout or '') \
38
- or 'A100' in (r.stdout or '') or r.returncode == 0:
39
- print(f'[boot] nvidia-smi OK on attempt {attempt}', flush=True)
40
- break
41
- print(f'[boot] nvidia-smi attempt {attempt} rc={r.returncode} stderr={(r.stderr or "")[:120]}',
42
- flush=True)
43
- time.sleep(2)
44
- # After nvidia-smi, probe torch in a subprocess so any latent error state
45
- # doesn't leak into the main process's CUDA context.
46
- probe = 'import torch; import sys; sys.exit(0 if torch.cuda.is_available() else 1)'
47
- torch_deadline = time.time() + 120.0
48
- t_attempt = 0
49
- while time.time() < torch_deadline:
50
- t_attempt += 1
51
- r = subprocess.run([sys.executable, '-c', probe], capture_output=True, text=True, timeout=60)
52
- if r.returncode == 0:
53
- print(f'[boot] torch.cuda.is_available() = True after {t_attempt} probe(s)', flush=True)
54
- return
55
- if t_attempt == 1:
56
- print(f'[boot] torch cuda probe {t_attempt}: {(r.stderr or "")[:200]}', flush=True)
57
- time.sleep(2)
58
- print('[boot] WARNING: torch.cuda never became ready — training will likely fail', flush=True)
59
-
60
-
61
- _early_cuda_kick()
62
-
63
- # Hydrate triton compilation cache from HF Hub before any triton/mamba_ssm import.
64
- # triton_cache_setup.py is copied next to this file by the job bash command.
65
- try:
66
- import triton_cache_setup as _tcs
67
- _tcs.setup()
68
- except ImportError:
69
- print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)
70
-
71
- from huggingface_hub import HfApi # noqa: E402 (import after cuda kick)
72
-
73
- REPO_ROOT = Path('/workspace/feather')
74
- CACHE_ROOT = Path.home() / '.cache' / 'autoresearch'
75
- LOG_FILE = REPO_ROOT / 'run_domain_expanded.log'
76
- JOB_ID = os.environ.get('JOB_ID', 'local-job')
77
- OUTPUT_REPO = os.environ.get('HF_REPO_ID', 'icarus112/feather-pretrain-checkpoints')
78
- TOKEN = os.environ.get('HF_TOKEN')
79
- RUNTIME_MODE = os.environ.get('FEATHER_RUNTIME_MODE', 'space')
80
- APP_PORT = int(os.environ.get('PORT', '7860'))
81
-
82
-
83
- class _HealthHandler(BaseHTTPRequestHandler):
84
- def do_GET(self):
85
- if self.path in ('/', '/health', '/healthz', '/ready'):
86
- payload = {
87
- 'status': 'ok',
88
- 'mode': RUNTIME_MODE,
89
- 'job_id': JOB_ID,
90
- }
91
- body = json.dumps(payload).encode('utf-8')
92
- self.send_response(200)
93
- self.send_header('Content-Type', 'application/json')
94
- self.send_header('Content-Length', str(len(body)))
95
- self.end_headers()
96
- self.wfile.write(body)
97
- return
98
- self.send_response(404)
99
- self.end_headers()
100
-
101
- def log_message(self, format, *args):
102
- return
103
-
104
-
105
- def _start_health_server() -> HTTPServer:
106
- server = HTTPServer(('0.0.0.0', APP_PORT), _HealthHandler)
107
- thread = Thread(target=server.serve_forever, daemon=True)
108
- thread.start()
109
- print(f'[space] health server listening on 0.0.0.0:{APP_PORT}', flush=True)
110
- return server
111
-
112
-
113
- def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
114
- if not path.exists():
115
- print(f'[upload] skip missing {path}', flush=True)
116
- return
117
- api.upload_file(
118
- path_or_fileobj=str(path),
119
- path_in_repo=dest,
120
- repo_id=OUTPUT_REPO,
121
- repo_type='model',
122
- )
123
- print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
124
-
125
-
126
- def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
127
- """Block until CUDA is fully initialized or timeout.
128
-
129
- On H200 hosts with NVSwitch/fabric manager, nvidia driver setup can race
130
- with container start. cudaGetDeviceCount can return CUDA_ERROR_SYSTEM_NOT_READY
131
- (error 802) for the first few seconds, and any import that triggers
132
- @triton.autotune (e.g. mamba_ssm, torch amp utilities) blows up with
133
- "0 active drivers" if it happens during that window.
134
-
135
- We pre-init CUDA in a throwaway Python subprocess (so any error state does
136
- not leak into the main training process) and retry until torch.cuda
137
- reports ready.
138
- """
139
- import time as _t
140
- probe = (
141
- "import torch; "
142
- "import sys; "
143
- "avail = torch.cuda.is_available(); "
144
- "count = torch.cuda.device_count() if avail else 0; "
145
- "sys.exit(0 if (avail and count > 0) else 1)"
146
- )
147
- deadline = _t.time() + timeout_s
148
- attempt = 0
149
- while _t.time() < deadline:
150
- attempt += 1
151
- r = subprocess.run(['python', '-c', probe], capture_output=True, text=True)
152
- if r.returncode == 0:
153
- print(f'[job] CUDA ready after {attempt} probe(s)', flush=True)
154
- return
155
- if attempt == 1:
156
- print(f'[job] CUDA not ready yet (will retry up to {timeout_s}s): {r.stderr.strip()[:200]}', flush=True)
157
- _t.sleep(2)
158
- print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)
159
-
160
-
161
- def run_job_mode() -> int:
162
- os.chdir(REPO_ROOT)
163
- os.environ.setdefault('HYDRA_TIME_BUDGET', '43200')
164
- os.environ.setdefault('HYDRA_TARGET_SHARDS', '2048')
165
- os.environ.setdefault('HYDRA_DOWNLOAD_WORKERS', '16')
166
- os.environ.setdefault('HYDRA_CKPT_INTERVAL', '1000')
167
- os.environ.setdefault('HYDRA_RESUME_CKPT', str(CACHE_ROOT / 'latest.pt'))
168
-
169
- # CUDA readiness was kicked at module import via _early_cuda_kick. Keep
170
- # the wait as a second safety net — no-op if CUDA already ready.
171
- _wait_for_cuda_ready()
172
-
173
- cmd = [
174
- 'bash',
175
- './scripts/run_domain_expanded_pretrain.sh',
176
- '--target-shards', os.environ['HYDRA_TARGET_SHARDS'],
177
- '--download-workers', os.environ['HYDRA_DOWNLOAD_WORKERS'],
178
- ]
179
- print('[job] starting Feather domain-expanded pretrain', flush=True)
180
- print(f'[job] command={cmd}', flush=True)
181
- proc = subprocess.run(cmd, check=False)
182
-
183
- # Push triton compilation cache back to HF Hub for next run.
184
- try:
185
- import triton_cache_setup as _tcs
186
- _tcs.teardown()
187
- except Exception as _tcs_err:
188
- print(f'[triton_cache] teardown error (non-fatal): {_tcs_err}', flush=True)
189
-
190
- if TOKEN:
191
- api = HfApi(token=TOKEN)
192
- try:
193
- api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
194
- except Exception as e:
195
- print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
196
- prefix = f'jobs/{JOB_ID}'
197
- try:
198
- upload_artifact(api, LOG_FILE, f'{prefix}/run_domain_expanded.log')
199
- upload_artifact(api, CACHE_ROOT / 'latest.pt', f'{prefix}/latest.pt')
200
- upload_artifact(api, CACHE_ROOT / 'pretrain_final.pt', f'{prefix}/pretrain_final.pt')
201
- except Exception as e:
202
- print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
203
- else:
204
- print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)
205
-
206
- return proc.returncode
207
-
208
-
209
- def run_space_mode() -> int:
210
- server = _start_health_server()
211
- print('[space] Feather runtime image ready', flush=True)
212
- try:
213
- while True:
214
- time.sleep(3600)
215
- finally:
216
- server.shutdown()
217
- server.server_close()
218
-
219
-
220
- def main() -> int:
221
- if RUNTIME_MODE == 'job':
222
- return run_job_mode()
223
- return run_space_mode()
224
-
225
-
226
- if __name__ == '__main__':
227
- raise SystemExit(main())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import os
6
+ import subprocess
7
+ import sys
8
+ import time
9
+ from http.server import BaseHTTPRequestHandler, HTTPServer
10
+ from pathlib import Path
11
+ from threading import Thread
12
+
13
+
14
+ def _prepend_library_path(*paths: str) -> None:
15
+ """Expose injected NVIDIA driver libraries before torch/triton imports."""
16
+ existing = [p for p in os.environ.get('LD_LIBRARY_PATH', '').split(':') if p]
17
+ merged = []
18
+ for p in paths:
19
+ if p and p not in merged:
20
+ merged.append(p)
21
+ for p in existing:
22
+ if p not in merged:
23
+ merged.append(p)
24
+ os.environ['LD_LIBRARY_PATH'] = ':'.join(merged)
25
+
26
+
27
+ _prepend_library_path(
28
+ # HF Jobs injects the host driver under /usr/local/nvidia. Prefer that
29
+ # over CUDA toolkit/compat libcuda stubs; using /usr/local/cuda/compat here
30
+ # made A10G PyTorch report Error 803 despite nvidia-smi working.
31
+ '/usr/local/nvidia/lib64',
32
+ '/usr/local/nvidia/lib',
33
+ '/usr/lib/x86_64-linux-gnu',
34
+ )
35
+
36
+
37
+ # =============================================================================
38
+ # EARLY CUDA FABRIC MANAGER KICK (before ANY CUDA-touching imports)
39
+ # =============================================================================
40
+ # On HF GPU hosts, cudaGetDeviceCount can transiently return not-ready errors
41
+ # on first use. H200 fabric-manager is the worst case; A10G is usually ready
42
+ # immediately, but the same early kick keeps the runtime deterministic.
43
+ # synchronizes with the container's first driver call. Once any NVML/CUDA
44
+ # call succeeds once (even just nvidia-smi), the fabric is up for the rest
45
+ # of the container lifetime.
46
+ #
47
+ # Our previous approach (wait in a subprocess before training) didn't work
48
+ # because the "initialization failed" state persisted across calls in the
49
+ # same container. The real fix: kick the driver exactly once with
50
+ # nvidia-smi, which is what successfully-working baseline containers do
51
+ # implicitly via their first torch.cuda call.
52
+ #
53
+ # Must happen BEFORE `import torch` (because any import that eagerly calls
54
+ # cudaGetDeviceCount will cache the Error 802 state).
55
+ def _early_cuda_kick() -> None:
56
+ deadline = time.time() + 120.0
57
+ attempt = 0
58
+ while time.time() < deadline:
59
+ attempt += 1
60
+ r = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=30)
61
+ if r.returncode == 0:
62
+ gpu_line = next((ln.strip() for ln in (r.stdout or '').splitlines() if any(g in ln for g in ('A10', 'A100', 'H100', 'H200', 'RTX'))), 'gpu=unknown')
63
+ print(f'[boot] nvidia-smi OK on attempt {attempt}: {gpu_line}', flush=True)
64
+ break
65
+ print(f'[boot] nvidia-smi attempt {attempt} rc={r.returncode} stderr={(r.stderr or "")[:120]}',
66
+ flush=True)
67
+ time.sleep(2)
68
+ # After nvidia-smi, probe torch in a subprocess so any latent error state
69
+ # doesn't leak into the main process's CUDA context.
70
+ probe = 'import torch; import sys; sys.exit(0 if torch.cuda.is_available() else 1)'
71
+ torch_deadline = time.time() + 120.0
72
+ t_attempt = 0
73
+ while time.time() < torch_deadline:
74
+ t_attempt += 1
75
+ r = subprocess.run([sys.executable, '-c', probe], capture_output=True, text=True, timeout=60)
76
+ if r.returncode == 0:
77
+ print(f'[boot] torch.cuda.is_available() = True after {t_attempt} probe(s)', flush=True)
78
+ return
79
+ if t_attempt == 1:
80
+ print(f'[boot] torch cuda probe {t_attempt}: {(r.stderr or "")[:200]}', flush=True)
81
+ time.sleep(2)
82
+ print('[boot] WARNING: torch.cuda never became ready — training will likely fail', flush=True)
83
+
84
+
85
+ _early_cuda_kick()
86
+
87
+ # Hydrate triton compilation cache from HF Hub before any triton/mamba_ssm import.
88
+ # triton_cache_setup.py is copied next to this file by the job bash command.
89
+ try:
90
+ import triton_cache_setup as _tcs
91
+ _tcs.setup()
92
+ except ImportError:
93
+ print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)
94
+
95
+ from huggingface_hub import HfApi # noqa: E402 (import after cuda kick)
96
+
97
+ REPO_ROOT = Path('/workspace/feather')
98
+ CACHE_ROOT = Path.home() / '.cache' / 'autoresearch'
99
+ LOG_FILE = REPO_ROOT / 'run_domain_expanded.log'
100
+ JOB_ID = os.environ.get('JOB_ID', 'local-job')
101
+ OUTPUT_REPO = os.environ.get('HF_REPO_ID', 'icarus112/feather-pretrain-checkpoints')
102
+ TOKEN = os.environ.get('HF_TOKEN')
103
+ RUNTIME_MODE = os.environ.get('FEATHER_RUNTIME_MODE', 'space')
104
+ APP_PORT = int(os.environ.get('PORT', '7860'))
105
+
106
+
107
+ class _HealthHandler(BaseHTTPRequestHandler):
108
+ def do_GET(self):
109
+ if self.path in ('/', '/health', '/healthz', '/ready'):
110
+ payload = {
111
+ 'status': 'ok',
112
+ 'mode': RUNTIME_MODE,
113
+ 'job_id': JOB_ID,
114
+ }
115
+ body = json.dumps(payload).encode('utf-8')
116
+ self.send_response(200)
117
+ self.send_header('Content-Type', 'application/json')
118
+ self.send_header('Content-Length', str(len(body)))
119
+ self.end_headers()
120
+ self.wfile.write(body)
121
+ return
122
+ self.send_response(404)
123
+ self.end_headers()
124
+
125
+ def log_message(self, format, *args):
126
+ return
127
+
128
+
129
+ def _start_health_server() -> HTTPServer:
130
+ server = HTTPServer(('0.0.0.0', APP_PORT), _HealthHandler)
131
+ thread = Thread(target=server.serve_forever, daemon=True)
132
+ thread.start()
133
+ print(f'[space] health server listening on 0.0.0.0:{APP_PORT}', flush=True)
134
+ return server
135
+
136
+
137
+ def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
138
+ if not path.exists():
139
+ print(f'[upload] skip missing {path}', flush=True)
140
+ return
141
+ api.upload_file(
142
+ path_or_fileobj=str(path),
143
+ path_in_repo=dest,
144
+ repo_id=OUTPUT_REPO,
145
+ repo_type='model',
146
+ )
147
+ print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
148
+
149
+
150
+ def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
151
+ """Block until CUDA is fully initialized or timeout.
152
+
153
+ On H200 hosts with NVSwitch/fabric manager, nvidia driver setup can race
154
+ with container start. cudaGetDeviceCount can return CUDA_ERROR_SYSTEM_NOT_READY
155
+ (error 802) for the first few seconds, and any import that triggers
156
+ @triton.autotune (e.g. mamba_ssm, torch amp utilities) blows up with
157
+ "0 active drivers" if it happens during that window.
158
+
159
+ We pre-init CUDA in a throwaway Python subprocess (so any error state does
160
+ not leak into the main training process) and retry until torch.cuda
161
+ reports ready.
162
+ """
163
+ import time as _t
164
+ probe = (
165
+ "import torch; "
166
+ "import sys; "
167
+ "avail = torch.cuda.is_available(); "
168
+ "count = torch.cuda.device_count() if avail else 0; "
169
+ "torch.empty(1, device='cuda') if (avail and count > 0) else None; "
170
+ "from triton.runtime import driver; "
171
+ "driver.active.get_current_device(); "
172
+ "sys.exit(0 if (avail and count > 0) else 1)"
173
+ )
174
+ deadline = _t.time() + timeout_s
175
+ attempt = 0
176
+ while _t.time() < deadline:
177
+ attempt += 1
178
+ r = subprocess.run(['python', '-c', probe], capture_output=True, text=True)
179
+ if r.returncode == 0:
180
+ print(f'[job] CUDA/Triton ready after {attempt} probe(s)', flush=True)
181
+ return
182
+ if attempt == 1:
183
+ print(f'[job] CUDA not ready yet (will retry up to {timeout_s}s): {r.stderr.strip()[:200]}', flush=True)
184
+ _t.sleep(2)
185
+ print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)
186
+
187
+
188
+ def run_job_mode() -> int:
189
+ os.chdir(REPO_ROOT)
190
+ os.environ.setdefault('HYDRA_TIME_BUDGET', '43200')
191
+ os.environ.setdefault('HYDRA_TARGET_SHARDS', '2048')
192
+ os.environ.setdefault('HYDRA_DOWNLOAD_WORKERS', '16')
193
+ os.environ.setdefault('HYDRA_CKPT_INTERVAL', '1000')
194
+ os.environ.setdefault('HYDRA_RESUME_CKPT', str(CACHE_ROOT / 'latest.pt'))
195
+ os.environ.setdefault('FEATHER_GPU_PROFILE', 'a10g-large')
196
+ os.environ.setdefault('HTM_CUDA_ARCH', 'sm_86')
197
+ os.environ.setdefault('TORCH_CUDA_ARCH_LIST', '8.6')
198
+ os.environ.setdefault('TRITON_CACHE_DIR', f"/workspace/triton_cache/{os.environ['FEATHER_GPU_PROFILE']}")
199
+ os.environ.setdefault('TRITON_CACHE_REPO', f"icarus112/feather-triton-cache-{os.environ['FEATHER_GPU_PROFILE']}")
200
+ print(f"[job] gpu_profile={os.environ['FEATHER_GPU_PROFILE']} htm_cuda_arch={os.environ['HTM_CUDA_ARCH']} torch_cuda_arch={os.environ['TORCH_CUDA_ARCH_LIST']}", flush=True)
201
+
202
+ # CUDA readiness was kicked at module import via _early_cuda_kick. Keep
203
+ # the wait as a second safety net — no-op if CUDA already ready.
204
+ _wait_for_cuda_ready()
205
+
206
+ cmd = [
207
+ 'bash',
208
+ './scripts/run_domain_expanded_pretrain.sh',
209
+ '--target-shards', os.environ['HYDRA_TARGET_SHARDS'],
210
+ '--download-workers', os.environ['HYDRA_DOWNLOAD_WORKERS'],
211
+ ]
212
+ print('[job] ensuring retina.npz before training...', flush=True)
213
+ try:
214
+ sys.path.insert(0, str(REPO_ROOT))
215
+ from subsystems.sdr_retina import build_retina
216
+ build_retina()
217
+ except Exception as _retina_err:
218
+ print(f'[job] retina bootstrap warning (train.py may still build it): {_retina_err}', flush=True)
219
+ print('[job] starting Feather domain-expanded pretrain', flush=True)
220
+ print(f'[job] command={cmd}', flush=True)
221
+ proc = subprocess.run(cmd, check=False)
222
+
223
+ # Push triton compilation cache back to HF Hub for next run.
224
+ try:
225
+ import triton_cache_setup as _tcs
226
+ _tcs.teardown()
227
+ except Exception as _tcs_err:
228
+ print(f'[triton_cache] teardown error (non-fatal): {_tcs_err}', flush=True)
229
+
230
+ if TOKEN:
231
+ api = HfApi(token=TOKEN)
232
+ try:
233
+ api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
234
+ except Exception as e:
235
+ print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
236
+ prefix = f'jobs/{JOB_ID}'
237
+ try:
238
+ upload_artifact(api, LOG_FILE, f'{prefix}/run_domain_expanded.log')
239
+ upload_artifact(api, CACHE_ROOT / 'latest.pt', f'{prefix}/latest.pt')
240
+ upload_artifact(api, CACHE_ROOT / 'pretrain_final.pt', f'{prefix}/pretrain_final.pt')
241
+ except Exception as e:
242
+ print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
243
+ else:
244
+ print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)
245
+
246
+ return proc.returncode
247
+
248
+
249
+ def run_space_mode() -> int:
250
+ server = _start_health_server()
251
+ print('[space] Feather runtime image ready', flush=True)
252
+ try:
253
+ while True:
254
+ time.sleep(3600)
255
+ finally:
256
+ server.shutdown()
257
+ server.server_close()
258
+
259
+
260
+ def main() -> int:
261
+ if RUNTIME_MODE == 'job':
262
+ return run_job_mode()
263
+ return run_space_mode()
264
+
265
+
266
+ if __name__ == '__main__':
267
+ raise SystemExit(main())
mamba_ssm_init.py CHANGED
@@ -1,101 +1,69 @@
1
- # mamba_ssm package init — minimal override to avoid broken selective_scan_cuda.so
2
- # ABI mismatch with the base image's libtorch.
3
- #
4
- # The upstream __init__.py eagerly imports selective_scan_cuda which fails on
5
- # pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
6
- # symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
7
- # all compiled-CUDA imports here and let Mamba3 load directly.
8
-
9
- __version__ = "2.3.1+feather-graft"
10
-
11
- # selective_scan_fn / mamba_inner_fn are shimmed to None — they are NOT used
12
- # by the Feather training path (which is Mamba3-only). If any import path
13
- # hits this, it will get a clear AttributeError instead of an obscure ImportError.
14
- selective_scan_fn = None
15
- mamba_inner_fn = None
16
-
17
- # --- triton API compatibility shims -----------------------------------------
18
- # Version matrix is hostile: torch 2.6 pins triton==3.2.0 because torch._inductor
19
- # imports AttrsDescriptor from triton.compiler.compiler — removed in triton 3.4+.
20
- # Grafted Mamba3 (from mamba-ssm main) needs triton.set_allocator and
21
- # tl.make_tensor_descriptor, both added in triton 3.3+. No single triton version
22
- # satisfies both simultaneously. We run on triton 3.5.1 (latest, has both mamba3
23
- # APIs) and shim AttrsDescriptor as a stub dataclass for torch._inductor. The
24
- # stub is never actually invoked at runtime because the codebase does not use
25
- # torch.compile — but importing torch._inductor.* still requires the symbol to
26
- # exist at module load time.
27
  import triton as _triton # noqa: E402
28
  if not hasattr(_triton, "set_allocator"):
29
- def _noop_set_allocator(_fn): # pragma: no cover
30
- return None
31
- _triton.set_allocator = _noop_set_allocator
32
-
33
- import triton.compiler.compiler as _tcc # noqa: E402
34
- if not hasattr(_tcc, "AttrsDescriptor"):
35
- class _AttrsDescriptorShim:
36
- """Stub for torch._inductor compatibility on triton >= 3.4.
37
- torch._inductor.runtime.hints imports this at module load but the
38
- constructor is only called inside torch.compile paths. Accept any
39
- args/kwargs so the import itself succeeds."""
40
- def __init__(self, *args, **kwargs):
41
- self.args = args
42
- self.kwargs = kwargs
43
-
44
- @classmethod
45
- def from_hints(cls, *args, **kwargs):
46
- return cls(*args, **kwargs)
47
-
48
- _tcc.AttrsDescriptor = _AttrsDescriptorShim
49
-
50
- # triton_key: removed in triton 3.5, used by torch._inductor.codecache for
51
- # FxGraphCache key derivation. Return a stable string so caching still works.
52
- if not hasattr(_tcc, "triton_key"):
53
- def _triton_key_shim():
54
- import triton as _t
55
- return f"triton-{_t.__version__}-shim"
56
- _tcc.triton_key = _triton_key_shim
57
 
58
- # Triton 3.5 wheels can occasionally load with an empty backend registry in
59
- # HF Jobs environments (driver.active -> "0 active drivers"), even though the
60
- # NVIDIA backend module is present and CudaDriver.is_active() is True.
61
- # Patch _create_driver to directly select CudaDriver when registry discovery
62
- # returns empty.
63
- import importlib as _importlib # noqa: E402
64
- _triton_driver_mod = _importlib.import_module("triton.runtime.driver")
65
- if getattr(_triton_driver_mod, "backends", None) == {}:
66
- from triton.backends.nvidia import driver as _nvidia_driver # noqa: E402
67
 
68
- def _create_driver_shim():
69
- if hasattr(_nvidia_driver, "CudaDriver") and _nvidia_driver.CudaDriver.is_active():
70
- return _nvidia_driver.CudaDriver()
71
- raise RuntimeError(
72
- "Triton backend registry is empty and NVIDIA CudaDriver is not active"
73
- )
74
 
75
- _triton_driver_mod._create_driver = _create_driver_shim
76
- if hasattr(_triton_driver_mod, "driver") and hasattr(_triton_driver_mod.driver, "reset_active"):
77
- _triton_driver_mod.driver.reset_active()
 
 
 
 
78
 
79
- _triton_compiler_mod = _importlib.import_module("triton.compiler.compiler")
80
- if getattr(_triton_compiler_mod, "backends", None) == {}:
81
- from triton.backends import Backend as _Backend # noqa: E402
82
- from triton.backends.nvidia.compiler import CUDABackend as _CUDABackend # noqa: E402
83
- from triton.backends.nvidia.driver import CudaDriver as _CudaDriver # noqa: E402
 
 
 
 
84
 
85
- _triton_compiler_mod.backends["nvidia"] = _Backend(
86
- compiler=_CUDABackend,
87
- driver=_CudaDriver,
88
- )
89
-
90
- # Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
91
- # for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
92
- # so fall back to eager on any dynamo failure rather than crashing. This is
93
- # defense-in-depth against further triton API drift.
94
- try:
95
- import torch._dynamo # noqa: F401 — triggers dynamo module init
96
- torch._dynamo.config.suppress_errors = True
97
- except Exception: # pragma: no cover
98
- pass
99
-
100
- # Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
101
- from mamba_ssm.modules.mamba3 import Mamba3 # noqa: E402
 
1
+ # mamba_ssm package init — minimal override to avoid broken selective_scan_cuda.so
2
+ # ABI mismatch with the base image's libtorch.
3
+ #
4
+ # The upstream __init__.py eagerly imports selective_scan_cuda which fails on
5
+ # pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
6
+ # symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
7
+ # all compiled-CUDA imports here and let Mamba3 load directly.
8
+
9
+ __version__ = "2.3.1+feather-graft"
10
+
11
+ # selective_scan_fn / mamba_inner_fn are shimmed to None — they are NOT used
12
+ # by the Feather training path (which is Mamba3-only). If any import path
13
+ # hits this, it will get a clear AttributeError instead of an obscure ImportError.
14
+ selective_scan_fn = None
15
+ mamba_inner_fn = None
16
+
17
+ # --- triton API compatibility shims -----------------------------------------
18
+ # Version matrix is hostile: torch 2.6 pins triton==3.2.0 because torch._inductor
19
+ # imports AttrsDescriptor from triton.compiler.compiler — removed in triton 3.4+.
20
+ # Grafted Mamba3 (from mamba-ssm main) needs triton.set_allocator and
21
+ # tl.make_tensor_descriptor, both added in triton 3.3+. No single triton version
22
+ # satisfies both simultaneously. We run on triton 3.5.1 (latest, has both mamba3
23
+ # APIs) and shim AttrsDescriptor as a stub dataclass for torch._inductor. The
24
+ # stub is never actually invoked at runtime because the codebase does not use
25
+ # torch.compile — but importing torch._inductor.* still requires the symbol to
26
+ # exist at module load time.
27
  import triton as _triton # noqa: E402
28
  if not hasattr(_triton, "set_allocator"):
29
+ def _noop_set_allocator(_fn): # pragma: no cover
30
+ return None
31
+ _triton.set_allocator = _noop_set_allocator
32
+
33
+ import triton.compiler.compiler as _tcc # noqa: E402
34
+ if not hasattr(_tcc, "AttrsDescriptor"):
35
+ class _AttrsDescriptorShim:
36
+ """Stub for torch._inductor compatibility on triton >= 3.4.
37
+ torch._inductor.runtime.hints imports this at module load but the
38
+ constructor is only called inside torch.compile paths. Accept any
39
+ args/kwargs so the import itself succeeds."""
40
+ def __init__(self, *args, **kwargs):
41
+ self.args = args
42
+ self.kwargs = kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ @classmethod
45
+ def from_hints(cls, *args, **kwargs):
46
+ return cls(*args, **kwargs)
 
 
 
 
 
 
47
 
48
+ _tcc.AttrsDescriptor = _AttrsDescriptorShim
 
 
 
 
 
49
 
50
+ # triton_key: removed in triton 3.5, used by torch._inductor.codecache for
51
+ # FxGraphCache key derivation. Return a stable string so caching still works.
52
+ if not hasattr(_tcc, "triton_key"):
53
+ def _triton_key_shim():
54
+ import triton as _t
55
+ return f"triton-{_t.__version__}-shim"
56
+ _tcc.triton_key = _triton_key_shim
57
 
58
+ # Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
59
+ # for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
60
+ # so fall back to eager on any dynamo failure rather than crashing. This is
61
+ # defense-in-depth against further triton API drift.
62
+ try:
63
+ import torch._dynamo # noqa: F401 — triggers dynamo module init
64
+ torch._dynamo.config.suppress_errors = True
65
+ except Exception: # pragma: no cover
66
+ pass
67
 
68
+ # Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
69
+ from mamba_ssm.modules.mamba3 import Mamba3 # noqa: E402
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
overlay/.dockerignore CHANGED
@@ -1,20 +1,20 @@
1
- .git
2
- .github
3
- .venv
4
- .remember
5
- .letta
6
- .claude
7
- __pycache__
8
- *.pyc
9
- *.pyo
10
- *.pyd
11
- *.log
12
- run_*.log
13
- run*.log
14
- *.txt
15
- WORKER_COMPLETE
16
- autoresearch_loop.log
17
- data/
18
- state_store/
19
- htm_rust/target/
20
- hydra-core/target/
 
1
+ .git
2
+ .github
3
+ .venv
4
+ .remember
5
+ .letta
6
+ .claude
7
+ __pycache__
8
+ *.pyc
9
+ *.pyo
10
+ *.pyd
11
+ *.log
12
+ run_*.log
13
+ run*.log
14
+ *.txt
15
+ WORKER_COMPLETE
16
+ autoresearch_loop.log
17
+ data/
18
+ state_store/
19
+ htm_rust/target/
20
+ hydra-core/target/
overlay/configs/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- from configs.hardware_config import HardwareConfig
2
- from configs.harness_config import HarnessConfig
3
- from configs.model_config import PostSemClawConfig
4
-
5
- __all__ = ["PostSemClawConfig", "HarnessConfig", "HardwareConfig"]
 
1
+ from configs.hardware_config import HardwareConfig
2
+ from configs.harness_config import HarnessConfig
3
+ from configs.model_config import PostSemClawConfig
4
+
5
+ __all__ = ["PostSemClawConfig", "HarnessConfig", "HardwareConfig"]
overlay/configs/hardware_config.py CHANGED
@@ -1,104 +1,104 @@
1
- """Hardware detection and memory budget configuration."""
2
- from __future__ import annotations
3
-
4
- import torch
5
- from pydantic import BaseModel, Field
6
-
7
-
8
- class HardwareConfig(BaseModel):
9
- """Auto-detected hardware configuration with memory budgets."""
10
-
11
- gpu_name: str = Field(default="unknown", description="GPU device name")
12
- gpu_memory_mb: int = Field(default=0, description="Total GPU memory in MB")
13
- gpu_vram_mb: int = Field(default=0, description="Alias for gpu_memory_mb (legacy compat)")
14
- compute_capability: tuple[int, int] = Field(
15
- default=(0, 0), description="CUDA compute capability"
16
- )
17
- peak_flops: float = Field(
18
- default=12.74e12, description="Peak FP32 FLOPS for MFU calculation"
19
- )
20
- bf16_peak_flops: float = Field(
21
- default=38.1e12, description="Peak BF16 FLOPS (RTX 3060 default)"
22
- )
23
-
24
- # Memory budget
25
- model_budget_mb: int = Field(
26
- default=1500, description="Max MB for model params + optimizer"
27
- )
28
- activation_budget_mb: int = Field(
29
- default=3000, description="Max MB for activations"
30
- )
31
- overhead_mb: int = Field(
32
- default=500, description="Reserved for CUDA context + PyTorch overhead"
33
- )
34
- max_vram_usage_pct: float = Field(
35
- default=90.0, description="Max VRAM usage as % of total"
36
- )
37
- gradient_checkpointing: bool = Field(
38
- default=False, description="Enable gradient checkpointing to save VRAM"
39
- )
40
-
41
- @classmethod
42
- def detect(cls) -> HardwareConfig:
43
- """Auto-detect hardware from current CUDA device."""
44
- if not torch.cuda.is_available():
45
- return cls()
46
-
47
- device = torch.cuda.current_device()
48
- props = torch.cuda.get_device_properties(device)
49
- cap = (props.major, props.minor)
50
- mem_mb = props.total_memory // (1024 * 1024)
51
- gpu_name = props.name
52
-
53
- # Peak FP32 FLOPS lookup by compute capability (approximate)
54
- fp32_flops_table: dict[tuple[int, int], float] = {
55
- (8, 6): 12.74e12, # RTX 3060
56
- (8, 9): 40.09e12, # RTX 4090
57
- (9, 0): 989.5e12, # H100 (BF16)
58
- }
59
- peak = fp32_flops_table.get(cap, 12.74e12)
60
-
61
- # BF16 peak FLOPS lookup by GPU name substring
62
- bf16_flops_table: dict[str, float] = {
63
- "3060": 38.1e12,
64
- "3090": 71.0e12,
65
- "4090": 165.2e12,
66
- "A100": 312e12,
67
- "H100": 989.5e12,
68
- "A10G": 70.0e12,
69
- }
70
- bf16_peak = 38.1e12 # default to RTX 3060
71
- for key, val in bf16_flops_table.items():
72
- if key in gpu_name:
73
- bf16_peak = val
74
- break
75
-
76
- # Memory budget: leave overhead_mb for CUDA context
77
- overhead = 500
78
- available = mem_mb - overhead
79
- model_budget = int(available * 0.3) # 30% for params + optimizer
80
- activation_budget = int(available * 0.7) # 70% for activations
81
-
82
- return cls(
83
- gpu_name=gpu_name,
84
- gpu_memory_mb=mem_mb,
85
- gpu_vram_mb=mem_mb,
86
- compute_capability=cap,
87
- peak_flops=peak,
88
- bf16_peak_flops=bf16_peak,
89
- model_budget_mb=model_budget,
90
- activation_budget_mb=activation_budget,
91
- )
92
-
93
- def suggest_batch_size(self, d_model: int, seq_len: int, n_layer: int) -> int:
94
- """Suggest batch size based on activation budget.
95
-
96
- Uses rough estimate: per-sample activation ~= n_layer * seq_len * d_model
97
- * 4 bytes * 2 (fwd + bwd).
98
- """
99
- per_sample_mb = n_layer * seq_len * d_model * 4 * 2 / (1024 * 1024)
100
- if per_sample_mb <= 0:
101
- return 1
102
- batch = max(1, int(self.activation_budget_mb / per_sample_mb))
103
- # Round down to power of 2
104
- return 2 ** (batch.bit_length() - 1) if batch > 1 else 1
 
1
+ """Hardware detection and memory budget configuration."""
2
+ from __future__ import annotations
3
+
4
+ import torch
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
+ class HardwareConfig(BaseModel):
9
+ """Auto-detected hardware configuration with memory budgets."""
10
+
11
+ gpu_name: str = Field(default="unknown", description="GPU device name")
12
+ gpu_memory_mb: int = Field(default=0, description="Total GPU memory in MB")
13
+ gpu_vram_mb: int = Field(default=0, description="Alias for gpu_memory_mb (legacy compat)")
14
+ compute_capability: tuple[int, int] = Field(
15
+ default=(0, 0), description="CUDA compute capability"
16
+ )
17
+ peak_flops: float = Field(
18
+ default=12.74e12, description="Peak FP32 FLOPS for MFU calculation"
19
+ )
20
+ bf16_peak_flops: float = Field(
21
+ default=38.1e12, description="Peak BF16 FLOPS (RTX 3060 default)"
22
+ )
23
+
24
+ # Memory budget
25
+ model_budget_mb: int = Field(
26
+ default=1500, description="Max MB for model params + optimizer"
27
+ )
28
+ activation_budget_mb: int = Field(
29
+ default=3000, description="Max MB for activations"
30
+ )
31
+ overhead_mb: int = Field(
32
+ default=500, description="Reserved for CUDA context + PyTorch overhead"
33
+ )
34
+ max_vram_usage_pct: float = Field(
35
+ default=90.0, description="Max VRAM usage as % of total"
36
+ )
37
+ gradient_checkpointing: bool = Field(
38
+ default=False, description="Enable gradient checkpointing to save VRAM"
39
+ )
40
+
41
+ @classmethod
42
+ def detect(cls) -> HardwareConfig:
43
+ """Auto-detect hardware from current CUDA device."""
44
+ if not torch.cuda.is_available():
45
+ return cls()
46
+
47
+ device = torch.cuda.current_device()
48
+ props = torch.cuda.get_device_properties(device)
49
+ cap = (props.major, props.minor)
50
+ mem_mb = props.total_memory // (1024 * 1024)
51
+ gpu_name = props.name
52
+
53
+ # Peak FP32 FLOPS lookup by compute capability (approximate)
54
+ fp32_flops_table: dict[tuple[int, int], float] = {
55
+ (8, 6): 12.74e12, # RTX 3060
56
+ (8, 9): 40.09e12, # RTX 4090
57
+ (9, 0): 989.5e12, # H100 (BF16)
58
+ }
59
+ peak = fp32_flops_table.get(cap, 12.74e12)
60
+
61
+ # BF16 peak FLOPS lookup by GPU name substring
62
+ bf16_flops_table: dict[str, float] = {
63
+ "3060": 38.1e12,
64
+ "3090": 71.0e12,
65
+ "4090": 165.2e12,
66
+ "A100": 312e12,
67
+ "H100": 989.5e12,
68
+ "A10G": 70.0e12,
69
+ }
70
+ bf16_peak = 38.1e12 # default to RTX 3060
71
+ for key, val in bf16_flops_table.items():
72
+ if key in gpu_name:
73
+ bf16_peak = val
74
+ break
75
+
76
+ # Memory budget: leave overhead_mb for CUDA context
77
+ overhead = 500
78
+ available = mem_mb - overhead
79
+ model_budget = int(available * 0.3) # 30% for params + optimizer
80
+ activation_budget = int(available * 0.7) # 70% for activations
81
+
82
+ return cls(
83
+ gpu_name=gpu_name,
84
+ gpu_memory_mb=mem_mb,
85
+ gpu_vram_mb=mem_mb,
86
+ compute_capability=cap,
87
+ peak_flops=peak,
88
+ bf16_peak_flops=bf16_peak,
89
+ model_budget_mb=model_budget,
90
+ activation_budget_mb=activation_budget,
91
+ )
92
+
93
+ def suggest_batch_size(self, d_model: int, seq_len: int, n_layer: int) -> int:
94
+ """Suggest batch size based on activation budget.
95
+
96
+ Uses rough estimate: per-sample activation ~= n_layer * seq_len * d_model
97
+ * 4 bytes * 2 (fwd + bwd).
98
+ """
99
+ per_sample_mb = n_layer * seq_len * d_model * 4 * 2 / (1024 * 1024)
100
+ if per_sample_mb <= 0:
101
+ return 1
102
+ batch = max(1, int(self.activation_budget_mb / per_sample_mb))
103
+ # Round down to power of 2
104
+ return 2 ** (batch.bit_length() - 1) if batch > 1 else 1
overlay/configs/harness_config.py CHANGED
@@ -3,53 +3,53 @@ from typing import Literal
3
 
4
  from pydantic import BaseModel, Field
5
 
6
- type GateThresholds = dict[str, float]
7
- type GateConfig = dict[str, GateThresholds]
8
-
9
-
10
  class HarnessConfig(BaseModel):
11
- """Configuration for the HYDRA harness behavior."""
12
-
13
- # Inner loop
14
- time_budget_seconds: int = Field(
15
- default=300, ge=60, description="Training time budget per experiment in seconds"
16
- )
17
- max_experiments: int = Field(
18
- default=1000, ge=0, description="Max experiments before stopping (0=infinite)"
19
- )
20
-
21
- # Meta-agent
22
- meta_interval: int = Field(
23
- default=20, ge=5, description="Run meta-agent every N experiments"
24
- )
25
- max_meta_changes: int = Field(
26
- default=3, ge=1, le=10, description="Max changes per meta-iteration"
27
- )
28
-
29
- # Search strategy
30
- exploration_mode: Literal["conservative", "balanced", "bold"] = "balanced"
31
- exploration_budget: int = Field(
32
- default=5, ge=1, description="Consecutive bold experiments when stuck"
33
- )
34
- stuck_threshold: int = Field(
35
- default=10, ge=3, description="No improvement for N experiments = stuck"
36
- )
37
- crash_threshold: float = Field(
38
- default=0.5,
39
- ge=0.1,
40
- le=1.0,
41
- description="Crash rate threshold for BROKEN state",
42
- )
43
- regression_tolerance: float = Field(
44
- default=0.05,
45
- ge=0,
46
- le=0.2,
47
- description="Max val_bpb regression from best (fraction)",
48
- )
49
- max_regression_pct: float = Field(
50
- default=5.0, description="Max % regression from best known val_bpb"
51
- )
52
-
53
  # Keep/discard criteria
54
  primary_metric: str = "val_bpb"
55
  secondary_metrics: GateConfig = Field(
@@ -63,23 +63,23 @@ class HarnessConfig(BaseModel):
63
  "hestia_quant_error": {"max": 0.05},
64
  }
65
  )
66
-
67
- # Experiment execution
68
- experiment_timeout: int = Field(
69
- default=600, ge=300, description="Kill experiment after N seconds"
70
- )
71
- warmup_steps: int = Field(
72
- default=10, ge=0, description="Steps to exclude from timing"
73
- )
74
-
75
- # Git
76
- branch_prefix: str = Field(default="autoresearch", description="Branch naming prefix")
77
- results_file: str = Field(default="results.tsv", description="Experiment log file")
78
-
79
- # Secondary metric gates (optional keep/discard criteria)
80
- gate_mhc_spectral_norm: float | None = Field(
81
- default=None, description="Max mhc_spectral_norm for keep (None=disabled)"
82
- )
83
  gate_engram_hit_rate: float | None = Field(
84
  default=None, description="Min engram_hit_rate for keep (None=disabled)"
85
  )
 
3
 
4
  from pydantic import BaseModel, Field
5
 
6
+ GateThresholds = dict[str, float]
7
+ GateConfig = dict[str, GateThresholds]
8
+
9
+
10
  class HarnessConfig(BaseModel):
11
+ """Configuration for the HYDRA harness behavior."""
12
+
13
+ # Inner loop
14
+ time_budget_seconds: int = Field(
15
+ default=300, ge=60, description="Training time budget per experiment in seconds"
16
+ )
17
+ max_experiments: int = Field(
18
+ default=1000, ge=0, description="Max experiments before stopping (0=infinite)"
19
+ )
20
+
21
+ # Meta-agent
22
+ meta_interval: int = Field(
23
+ default=20, ge=5, description="Run meta-agent every N experiments"
24
+ )
25
+ max_meta_changes: int = Field(
26
+ default=3, ge=1, le=10, description="Max changes per meta-iteration"
27
+ )
28
+
29
+ # Search strategy
30
+ exploration_mode: Literal["conservative", "balanced", "bold"] = "balanced"
31
+ exploration_budget: int = Field(
32
+ default=5, ge=1, description="Consecutive bold experiments when stuck"
33
+ )
34
+ stuck_threshold: int = Field(
35
+ default=10, ge=3, description="No improvement for N experiments = stuck"
36
+ )
37
+ crash_threshold: float = Field(
38
+ default=0.5,
39
+ ge=0.1,
40
+ le=1.0,
41
+ description="Crash rate threshold for BROKEN state",
42
+ )
43
+ regression_tolerance: float = Field(
44
+ default=0.05,
45
+ ge=0,
46
+ le=0.2,
47
+ description="Max val_bpb regression from best (fraction)",
48
+ )
49
+ max_regression_pct: float = Field(
50
+ default=5.0, description="Max % regression from best known val_bpb"
51
+ )
52
+
53
  # Keep/discard criteria
54
  primary_metric: str = "val_bpb"
55
  secondary_metrics: GateConfig = Field(
 
63
  "hestia_quant_error": {"max": 0.05},
64
  }
65
  )
66
+
67
+ # Experiment execution
68
+ experiment_timeout: int = Field(
69
+ default=600, ge=300, description="Kill experiment after N seconds"
70
+ )
71
+ warmup_steps: int = Field(
72
+ default=10, ge=0, description="Steps to exclude from timing"
73
+ )
74
+
75
+ # Git
76
+ branch_prefix: str = Field(default="autoresearch", description="Branch naming prefix")
77
+ results_file: str = Field(default="results.tsv", description="Experiment log file")
78
+
79
+ # Secondary metric gates (optional keep/discard criteria)
80
+ gate_mhc_spectral_norm: float | None = Field(
81
+ default=None, description="Max mhc_spectral_norm for keep (None=disabled)"
82
+ )
83
  gate_engram_hit_rate: float | None = Field(
84
  default=None, description="Min engram_hit_rate for keep (None=disabled)"
85
  )
overlay/configs/model_config.py CHANGED
@@ -1,80 +1,80 @@
1
- """Post-SEM-Claw model configuration with Pydantic validation."""
2
- from pydantic import BaseModel, Field, field_validator
3
-
4
-
5
- class PostSemClawConfig(BaseModel):
6
- """Configuration for the Post-SEM-Claw architecture.
7
-
8
- Default values mirror the @dataclass in train.py exactly.
9
- train.py is the source of truth — this file must stay in sync with it.
10
- """
11
-
12
- # Sequence
13
- sequence_len: int = Field(default=2048, description="Context length (from prepare.py MAX_SEQ_LEN)")
14
- vocab_size: int = Field(default=8192, description="Vocabulary size (from prepare.py VOCAB_SIZE)")
15
-
16
- # Mamba-3 SSM
17
- n_layer: int = Field(default=4, ge=1, le=48, description="Number of Mamba-3 blocks")
18
- d_model: int = Field(default=256, ge=64, description="Model embedding dimension")
19
- d_state: int = Field(default=64, ge=16, description="SSM state dimension")
20
- headdim: int = Field(default=32, ge=16, description="SSM head dimension")
21
- n_heads: int = Field(default=8, ge=1, description="Number of SSM heads (d_model // headdim)")
22
- expand: int = Field(default=2, ge=1, le=4, description="Inner dim multiplier (inner_dim = expand * d_model)")
23
-
24
- # mHC (Manifold Hyper-Connection)
25
- mhc_n_streams: int = Field(default=4, ge=2, le=8, description="Number of residual streams")
26
- mhc_sinkhorn_iters: int = Field(default=5, ge=1, le=100, description="Sinkhorn-Knopp iterations")
27
-
28
- # Engram (conditional memory)
29
- engram_n_columns: int = Field(default=4096, ge=256, description="Hash table columns")
30
- engram_key_dim: int = Field(default=64, ge=16, description="Engram key dimension")
31
- engram_layer_idx: int = Field(default=1, ge=0, description="Which layer gets engram (0-indexed)")
32
-
33
- # Hestia QAT (disabled Phase 1, skeleton only)
34
- hestia_enabled: bool = Field(default=False, description="Enable Hestia quantization")
35
- hestia_bits: float = Field(default=1.58, gt=0, description="Target quantization bits (1.58 = 1.58-bit ternary)")
36
-
37
- # SDR (bypass-only in Phase 1)
38
- sdr_enabled: bool = Field(default=False, description="Enable stochastic resonance")
39
- sdr_k: int = Field(default=64, ge=1, description="Top-K sparsification")
40
- sdr_noise_std: float = Field(default=0.1, ge=0.0, description="SR noise standard deviation")
41
-
42
- @field_validator("n_heads")
43
- @classmethod
44
- def validate_heads(cls, v: int, info: "FieldValidationInfo") -> int:
45
- """Ensure n_heads equals d_model // headdim."""
46
- d_model = info.data.get("d_model", 256)
47
- headdim = info.data.get("headdim", 32)
48
- expected = d_model // headdim
49
- if v != expected:
50
- raise ValueError(
51
- f"n_heads ({v}) must equal d_model // headdim ({expected})"
52
- )
53
- return v
54
-
55
- def estimate_params(self) -> int:
56
- """Rough parameter count estimate based on train.py architecture."""
57
- inner = self.expand * self.d_model
58
- # in_proj: d_model -> inner + inner + d_state + d_state + n_heads
59
- in_proj = self.d_model * (inner + inner + self.d_state + self.d_state + self.n_heads)
60
- out_proj = inner * self.d_model
61
- # conv1d (kernel=4, groups=inner_dim)
62
- conv = inner * 4
63
- # A_log, lambda_theta, D: n_heads each (3 vectors)
64
- ssm_params = self.n_heads * 3
65
- # bc_norm: d_state * 2 (weight + bias)
66
- bc_norm = self.d_state * 2
67
- per_block = in_proj + out_proj + conv + ssm_params + bc_norm
68
- blocks = per_block * self.n_layer
69
-
70
- # Embedding + lm_head (tied or untied)
71
- embed = self.vocab_size * self.d_model * 2
72
-
73
- # Engram: one instance at engram_layer_idx
74
- # columns * d_model keys + d_model * engram_key_dim projection
75
- engram = self.engram_n_columns * self.d_model + self.d_model * self.engram_key_dim
76
-
77
- # mHC mixing matrices: n_layer * mhc_n_streams^2
78
- mhc = self.n_layer * self.mhc_n_streams ** 2
79
-
80
- return embed + blocks + engram + mhc
 
1
+ """Post-SEM-Claw model configuration with Pydantic validation."""
2
+ from pydantic import BaseModel, Field, field_validator
3
+
4
+
5
+ class PostSemClawConfig(BaseModel):
6
+ """Configuration for the Post-SEM-Claw architecture.
7
+
8
+ Default values mirror the @dataclass in train.py exactly.
9
+ train.py is the source of truth — this file must stay in sync with it.
10
+ """
11
+
12
+ # Sequence
13
+ sequence_len: int = Field(default=2048, description="Context length (from prepare.py MAX_SEQ_LEN)")
14
+ vocab_size: int = Field(default=8192, description="Vocabulary size (from prepare.py VOCAB_SIZE)")
15
+
16
+ # Mamba-3 SSM
17
+ n_layer: int = Field(default=4, ge=1, le=48, description="Number of Mamba-3 blocks")
18
+ d_model: int = Field(default=256, ge=64, description="Model embedding dimension")
19
+ d_state: int = Field(default=64, ge=16, description="SSM state dimension")
20
+ headdim: int = Field(default=32, ge=16, description="SSM head dimension")
21
+ n_heads: int = Field(default=8, ge=1, description="Number of SSM heads (d_model // headdim)")
22
+ expand: int = Field(default=2, ge=1, le=4, description="Inner dim multiplier (inner_dim = expand * d_model)")
23
+
24
+ # mHC (Manifold Hyper-Connection)
25
+ mhc_n_streams: int = Field(default=4, ge=2, le=8, description="Number of residual streams")
26
+ mhc_sinkhorn_iters: int = Field(default=5, ge=1, le=100, description="Sinkhorn-Knopp iterations")
27
+
28
+ # Engram (conditional memory)
29
+ engram_n_columns: int = Field(default=4096, ge=256, description="Hash table columns")
30
+ engram_key_dim: int = Field(default=64, ge=16, description="Engram key dimension")
31
+ engram_layer_idx: int = Field(default=1, ge=0, description="Which layer gets engram (0-indexed)")
32
+
33
+ # Hestia QAT (disabled Phase 1, skeleton only)
34
+ hestia_enabled: bool = Field(default=False, description="Enable Hestia quantization")
35
+ hestia_bits: float = Field(default=1.58, gt=0, description="Target quantization bits (1.58 = 1.58-bit ternary)")
36
+
37
+ # SDR (bypass-only in Phase 1)
38
+ sdr_enabled: bool = Field(default=False, description="Enable stochastic resonance")
39
+ sdr_k: int = Field(default=64, ge=1, description="Top-K sparsification")
40
+ sdr_noise_std: float = Field(default=0.1, ge=0.0, description="SR noise standard deviation")
41
+
42
+ @field_validator("n_heads")
43
+ @classmethod
44
+ def validate_heads(cls, v: int, info: "FieldValidationInfo") -> int:
45
+ """Ensure n_heads equals d_model // headdim."""
46
+ d_model = info.data.get("d_model", 256)
47
+ headdim = info.data.get("headdim", 32)
48
+ expected = d_model // headdim
49
+ if v != expected:
50
+ raise ValueError(
51
+ f"n_heads ({v}) must equal d_model // headdim ({expected})"
52
+ )
53
+ return v
54
+
55
+ def estimate_params(self) -> int:
56
+ """Rough parameter count estimate based on train.py architecture."""
57
+ inner = self.expand * self.d_model
58
+ # in_proj: d_model -> inner + inner + d_state + d_state + n_heads
59
+ in_proj = self.d_model * (inner + inner + self.d_state + self.d_state + self.n_heads)
60
+ out_proj = inner * self.d_model
61
+ # conv1d (kernel=4, groups=inner_dim)
62
+ conv = inner * 4
63
+ # A_log, lambda_theta, D: n_heads each (3 vectors)
64
+ ssm_params = self.n_heads * 3
65
+ # bc_norm: d_state * 2 (weight + bias)
66
+ bc_norm = self.d_state * 2
67
+ per_block = in_proj + out_proj + conv + ssm_params + bc_norm
68
+ blocks = per_block * self.n_layer
69
+
70
+ # Embedding + lm_head (tied or untied)
71
+ embed = self.vocab_size * self.d_model * 2
72
+
73
+ # Engram: one instance at engram_layer_idx
74
+ # columns * d_model keys + d_model * engram_key_dim projection
75
+ engram = self.engram_n_columns * self.d_model + self.d_model * self.engram_key_dim
76
+
77
+ # mHC mixing matrices: n_layer * mhc_n_streams^2
78
+ mhc = self.n_layer * self.mhc_n_streams ** 2
79
+
80
+ return embed + blocks + engram + mhc
overlay/harness/__init__.py CHANGED
@@ -1,21 +1,21 @@
1
- """HYDRA harness package: orchestration infrastructure for autoresearch."""
2
- from harness.eval_agent import ExperimentResult, parse_run_log, should_keep
3
- from harness.git_utils import current_branch, current_commit_short
4
- from harness.health_monitor import check_health, get_gpu_stats
5
- from harness.meta_agent import run_meta_iteration
6
- from harness.orchestrator import run_loop
7
- from harness.search_strategy import ResearchState, diagnose
8
-
9
- __all__ = [
10
- "run_loop",
11
- "parse_run_log",
12
- "ExperimentResult",
13
- "should_keep",
14
- "run_meta_iteration",
15
- "diagnose",
16
- "ResearchState",
17
- "check_health",
18
- "get_gpu_stats",
19
- "current_branch",
20
- "current_commit_short",
21
- ]
 
1
+ """HYDRA harness package: orchestration infrastructure for autoresearch."""
2
+ from harness.eval_agent import ExperimentResult, parse_run_log, should_keep
3
+ from harness.git_utils import current_branch, current_commit_short
4
+ from harness.health_monitor import check_health, get_gpu_stats
5
+ from harness.meta_agent import run_meta_iteration
6
+ from harness.orchestrator import run_loop
7
+ from harness.search_strategy import ResearchState, diagnose
8
+
9
+ __all__ = [
10
+ "run_loop",
11
+ "parse_run_log",
12
+ "ExperimentResult",
13
+ "should_keep",
14
+ "run_meta_iteration",
15
+ "diagnose",
16
+ "ResearchState",
17
+ "check_health",
18
+ "get_gpu_stats",
19
+ "current_branch",
20
+ "current_commit_short",
21
+ ]
overlay/harness/eval_agent.py CHANGED
@@ -1,300 +1,172 @@
1
  """Eval agent: parse run.log and extract metrics from training runs."""
2
  import re
3
- import statistics
4
- from dataclasses import dataclass
5
 
6
 
7
- type GateThresholds = dict[str, float]
8
- type GateConfig = dict[str, GateThresholds]
9
-
10
-
11
- @dataclass
12
  class ExperimentResult:
13
- """Parsed result from a single experiment run.
14
-
15
- All float fields default to 0.0; integer fields default to 0.
16
- The ``crashed`` flag is set when the log indicates a failure or the
17
- log file is missing entirely.
18
- """
19
-
20
- # Primary metric
21
- val_bpb: float = 0.0
22
-
23
- # Timing
24
- training_seconds: float = 0.0
25
- total_seconds: float = 0.0
26
-
27
- # Hardware
28
- peak_vram_mb: float = 0.0
29
- mfu_percent: float = 0.0
30
-
31
  # Throughput
32
  total_tokens_m: float = 0.0
33
  num_steps: int = 0
34
- tps_median: float = 0.0
35
- tps_p10: float = 0.0
36
- tps_min: float = 0.0
37
- tps_max: float = 0.0
38
- tps_samples: int = 0
39
-
40
- # Model shape (echoed by train.py summary block)
41
- num_params_m: float = 0.0
42
- n_layer: int = 0
43
- d_model: int = 0
44
-
45
  # Secondary health metrics
46
  mhc_spectral_norm: float = 0.0
47
  engram_hit_rate: float = 0.0
48
  sr_bypass_rate: float = 0.0
49
 
50
- # Evaluation breadth metrics
51
- factual_english_score: float = 0.0
52
- instruction_following_score: float = 0.0
53
- distinct_1: float = 0.0
54
- distinct_2: float = 0.0
55
- repetition_rate: float = 0.0
56
- repetition_bigram_rate: float = 0.0
57
- calibration_ece: float = 0.0
58
- calibration_brier: float = 0.0
59
- calibration_accuracy: float = 0.0
60
- calibration_tokens: int = 0
61
- eval_seed: int = 0
62
- eval_seed_group: str = ""
63
-
64
- # Status
65
- crashed: bool = False
66
- error_message: str = ""
67
-
68
-
69
- # Regex patterns keyed by ExperimentResult attribute name.
70
- # Format must match the ``--- Summary ---`` block printed by train.py.
71
- _PATTERNS: dict[str, str] = {
72
- "val_bpb": r"^val_bpb:\s+([\d.]+)",
73
- "training_seconds": r"^training_seconds:\s+([\d.]+)",
74
- "total_seconds": r"^total_seconds:\s+([\d.]+)",
75
- "peak_vram_mb": r"^peak_vram_mb:\s+([\d.]+)",
76
- "mfu_percent": r"^mfu_percent:\s+([\d.]+)",
77
- "total_tokens_m": r"^total_tokens_M:\s+([\d.]+)",
78
- "num_steps": r"^num_steps:\s+(\d+)",
79
- "num_params_m": r"^num_params_M:\s+([\d.]+)",
80
- "n_layer": r"^n_layer:\s+(\d+)",
81
- "d_model": r"^d_model:\s+(\d+)",
82
- "mhc_spectral_norm": r"^mhc_spectral_norm:\s+([\d.]+)",
83
  "engram_hit_rate": r"^engram_hit_rate:\s+([\d.]+)",
84
  "sr_bypass_rate": r"^sr_bypass_rate:\s+([\d.]+)",
85
- "factual_english_score": r"^factual_english_score:\s+([\d.]+)",
86
- "instruction_following_score": r"^instruction_following_score:\s+([\d.]+)",
87
- "distinct_1": r"^distinct_1:\s+([\d.]+)",
88
- "distinct_2": r"^distinct_2:\s+([\d.]+)",
89
- "repetition_rate": r"^repetition_rate:\s+([\d.]+)",
90
- "repetition_bigram_rate": r"^repetition_bigram_rate:\s+([\d.]+)",
91
- "calibration_ece": r"^calibration_ece:\s+([\d.]+)",
92
- "calibration_brier": r"^calibration_brier:\s*([\d.]+)",
93
- "calibration_accuracy": r"^calibration_accuracy:\s+([\d.]+)",
94
- "calibration_tokens": r"^calibration_tokens:\s+(\d+)",
95
- "eval_seed": r"^eval_seed:\s+(\d+)",
96
- "eval_seed_group": r"^eval_seed_group:\s+(.+)",
97
  }
98
-
99
- # Attributes that should be parsed as int rather than float.
100
- _INT_ATTRS: frozenset[str] = frozenset(
101
- {
102
- "num_steps",
103
- "n_layer",
104
- "d_model",
105
- "calibration_tokens",
106
- "eval_seed",
107
- }
108
- )
109
- _STR_ATTRS: frozenset[str] = frozenset({"eval_seed_group"})
110
- _STEP_TPS_PATTERN = re.compile(r"step=(\d+).*?\btps=(\d+)\b")
111
- _TPS_PATTERN = re.compile(r"\btps=(\d+)\b")
112
-
113
-
114
- def _percentile_linear(sorted_values: list[float], pct: float) -> float:
115
- """Compute percentile via linear interpolation (0 <= pct <= 100)."""
116
- if not sorted_values:
117
- return 0.0
118
- if len(sorted_values) == 1:
119
- return sorted_values[0]
120
- rank = (len(sorted_values) - 1) * (pct / 100.0)
121
- lo = int(rank)
122
- hi = min(lo + 1, len(sorted_values) - 1)
123
- frac = rank - lo
124
- return sorted_values[lo] * (1.0 - frac) + sorted_values[hi] * frac
125
-
126
-
127
- def parse_run_log(log_path: str) -> ExperimentResult:
128
- """Parse a run.log file and extract all training metrics.
129
-
130
- Args:
131
- log_path: Absolute path to the run.log file.
132
-
133
- Returns:
134
- Populated ExperimentResult; sets ``crashed=True`` when the log
135
- contains a traceback or the file is missing.
136
- """
137
- result = ExperimentResult()
138
-
139
- try:
140
- with open(log_path) as fh:
141
- content = fh.read()
142
- except FileNotFoundError:
143
- result.crashed = True
144
- result.error_message = f"Log file not found: {log_path}"
145
- return result
146
-
147
- # Detect crash signals in output. Keep this strict to avoid false positives
148
- # from benign log lines that include "error" in a non-fatal context.
149
- if (
150
- "Traceback" in content
151
- or "\nFAIL\n" in content
152
- or "[TPS_GUARD] FAIL" in content
153
- or "raise SystemExit(1)" in content
154
- ):
155
  result.crashed = True
156
  lines = content.strip().splitlines()
157
  result.error_message = "\n".join(lines[-20:])
158
-
159
  for attr, pattern in _PATTERNS.items():
160
  match = re.search(pattern, content, re.MULTILINE)
161
  if match:
162
  raw = match.group(1)
163
- if attr in _INT_ATTRS:
164
- setattr(result, attr, int(raw))
165
- elif attr in _STR_ATTRS:
166
- setattr(result, attr, raw.strip())
167
- else:
168
- setattr(result, attr, float(raw))
169
-
170
- warmup_steps = 10
171
- warmup_match = re.search(r"\[TPS_GUARD\] enabled .*?warmup_steps=(\d+)", content)
172
- if warmup_match:
173
- warmup_steps = int(warmup_match.group(1))
174
-
175
- step_tps_samples: list[tuple[int, int]] = []
176
- for m in _STEP_TPS_PATTERN.finditer(content):
177
- step_tps_samples.append((int(m.group(1)), int(m.group(2))))
178
-
179
- tps_values: list[float] = []
180
- if step_tps_samples:
181
- for step, tps in step_tps_samples:
182
- if step >= warmup_steps:
183
- tps_values.append(float(tps))
184
- if not tps_values:
185
- tps_values = [float(tps) for _, tps in step_tps_samples]
186
- else:
187
- tps_values = [float(m.group(1)) for m in _TPS_PATTERN.finditer(content)]
188
-
189
- if tps_values:
190
- sorted_tps = sorted(tps_values)
191
- result.tps_samples = len(tps_values)
192
- result.tps_median = float(statistics.median(tps_values))
193
- result.tps_p10 = float(_percentile_linear(sorted_tps, 10.0))
194
- result.tps_min = float(sorted_tps[0])
195
- result.tps_max = float(sorted_tps[-1])
196
 
197
  return result
198
-
199
-
200
  def check_secondary_alarms(result: ExperimentResult) -> list[str]:
201
- """Check secondary metrics against fixed alarm thresholds.
202
-
203
- Args:
204
- result: Parsed experiment result.
205
-
206
- Returns:
207
- List of human-readable alarm strings (empty if all clear).
208
- """
209
- alarms: list[str] = []
210
-
211
- if result.mhc_spectral_norm > 2.0:
212
- alarms.append(
213
- f"mhc_spectral_norm={result.mhc_spectral_norm:.4f} > 2.0 (ALARM)"
214
- )
215
- if 0 < result.engram_hit_rate < 0.1:
216
- alarms.append(
217
- f"engram_hit_rate={result.engram_hit_rate:.4f} < 0.1 (memory underused)"
218
- )
219
- if 0 < result.mfu_percent < 10:
220
  alarms.append(
221
- f"mfu_percent={result.mfu_percent:.2f}% < 10% (GPU underutilized)"
222
  )
223
- if result.calibration_ece > 0.35:
224
  alarms.append(
225
- f"calibration_ece={result.calibration_ece:.4f} > 0.35 (poor calibration)"
226
  )
227
- if result.tps_median > 0 and result.tps_median < 50000:
228
  alarms.append(
229
- f"tps_median={result.tps_median:.0f} < 50000 (throughput below A10 objective)"
230
  )
231
-
232
  return alarms
233
 
234
 
235
- def _check_gate(
236
- result: ExperimentResult,
237
- gates: GateConfig,
238
- metric: str,
239
- ) -> tuple[bool, str] | None:
240
- """Evaluate a single min/max gate against an ExperimentResult metric."""
241
- gate = gates.get(metric, {})
242
- value = getattr(result, metric)
243
- max_value = gate.get("max")
244
- if max_value is not None and value > max_value:
245
- return False, f"{metric} {value:.4f} > gate {max_value}"
246
- min_value = gate.get("min")
247
- if min_value is not None and value < min_value:
248
- return False, f"{metric} {value:.4f} < gate {min_value}"
249
- return None
250
-
251
-
252
  def should_keep(
253
  result: ExperimentResult,
254
  best_bpb: float,
255
- gates: GateConfig | None = None,
256
  ) -> tuple[bool, str]:
257
- """Decide whether to keep or discard an experiment.
258
-
259
- The primary criterion is strictly lower val_bpb than the current best.
260
- Optional secondary gates (passed from HarnessConfig.secondary_metrics)
261
- can reject an otherwise-improving result.
262
-
263
- Args:
264
- result: Parsed experiment result.
265
- best_bpb: Current best val_bpb across all experiments.
266
- gates: Optional dict mapping metric name to threshold dict with
267
- ``"max"`` or ``"min"`` keys, e.g.
268
- ``{"mhc_spectral_norm": {"max": 2.0}}``.
269
-
270
- Returns:
271
- Tuple of (keep: bool, reason: str).
272
- """
273
- if result.crashed:
274
- return False, "crash"
275
- if result.val_bpb <= 0:
276
- return False, "invalid val_bpb"
277
- if result.val_bpb >= best_bpb:
278
- return False, "discard"
279
-
280
  # Secondary gate checks.
281
  if gates:
282
- gate_metrics = (
283
- "mhc_spectral_norm",
284
- "engram_hit_rate",
285
- "factual_english_score",
286
- "instruction_following_score",
287
- "distinct_1",
288
- "distinct_2",
289
- "repetition_rate",
290
- "repetition_bigram_rate",
291
- "calibration_ece",
292
- "tps_median",
293
- "tps_p10",
294
- )
295
- for metric in gate_metrics:
296
- gate_result = _check_gate(result, gates, metric)
297
- if gate_result is not None:
298
- return gate_result
299
 
300
  return True, "keep"
 
1
  """Eval agent: parse run.log and extract metrics from training runs."""
2
  import re
3
+ from dataclasses import dataclass, field
 
4
 
5
 
6
+ @dataclass
 
 
 
 
7
  class ExperimentResult:
8
+ """Parsed result from a single experiment run.
9
+
10
+ All float fields default to 0.0; integer fields default to 0.
11
+ The ``crashed`` flag is set when the log indicates a failure or the
12
+ log file is missing entirely.
13
+ """
14
+
15
+ # Primary metric
16
+ val_bpb: float = 0.0
17
+
18
+ # Timing
19
+ training_seconds: float = 0.0
20
+ total_seconds: float = 0.0
21
+
22
+ # Hardware
23
+ peak_vram_mb: float = 0.0
24
+ mfu_percent: float = 0.0
25
+
26
  # Throughput
27
  total_tokens_m: float = 0.0
28
  num_steps: int = 0
29
+
30
+ # Model shape (echoed by train.py summary block)
31
+ num_params_m: float = 0.0
32
+ n_layer: int = 0
33
+ d_model: int = 0
34
+
 
 
 
 
 
35
  # Secondary health metrics
36
  mhc_spectral_norm: float = 0.0
37
  engram_hit_rate: float = 0.0
38
  sr_bypass_rate: float = 0.0
39
 
40
+ # Status
41
+ crashed: bool = False
42
+ error_message: str = ""
43
+
44
+
45
+ # Regex patterns keyed by ExperimentResult attribute name.
46
+ # Format must match the ``--- Summary ---`` block printed by train.py.
47
+ _PATTERNS: dict[str, str] = {
48
+ "val_bpb": r"^val_bpb:\s+([\d.]+)",
49
+ "training_seconds": r"^training_seconds:\s+([\d.]+)",
50
+ "total_seconds": r"^total_seconds:\s+([\d.]+)",
51
+ "peak_vram_mb": r"^peak_vram_mb:\s+([\d.]+)",
52
+ "mfu_percent": r"^mfu_percent:\s+([\d.]+)",
53
+ "total_tokens_m": r"^total_tokens_M:\s+([\d.]+)",
54
+ "num_steps": r"^num_steps:\s+(\d+)",
55
+ "num_params_m": r"^num_params_M:\s+([\d.]+)",
56
+ "n_layer": r"^n_layer:\s+(\d+)",
57
+ "d_model": r"^d_model:\s+(\d+)",
58
+ "mhc_spectral_norm": r"^mhc_spectral_norm:\s+([\d.]+)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  "engram_hit_rate": r"^engram_hit_rate:\s+([\d.]+)",
60
  "sr_bypass_rate": r"^sr_bypass_rate:\s+([\d.]+)",
 
 
 
 
 
 
 
 
 
 
 
 
61
  }
62
+
63
+ # Attributes that should be parsed as int rather than float.
64
+ _INT_ATTRS: frozenset[str] = frozenset({"num_steps", "n_layer", "d_model"})
65
+
66
+
67
+ def parse_run_log(log_path: str) -> ExperimentResult:
68
+ """Parse a run.log file and extract all training metrics.
69
+
70
+ Args:
71
+ log_path: Absolute path to the run.log file.
72
+
73
+ Returns:
74
+ Populated ExperimentResult; sets ``crashed=True`` when the log
75
+ contains a traceback or the file is missing.
76
+ """
77
+ result = ExperimentResult()
78
+
79
+ try:
80
+ with open(log_path) as fh:
81
+ content = fh.read()
82
+ except FileNotFoundError:
83
+ result.crashed = True
84
+ result.error_message = f"Log file not found: {log_path}"
85
+ return result
86
+
87
+ # Detect crash signals in output.
88
+ if "Traceback" in content or "FAIL" in content or "Error" in content:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  result.crashed = True
90
  lines = content.strip().splitlines()
91
  result.error_message = "\n".join(lines[-20:])
92
+
93
  for attr, pattern in _PATTERNS.items():
94
  match = re.search(pattern, content, re.MULTILINE)
95
  if match:
96
  raw = match.group(1)
97
+ setattr(result, attr, int(raw) if attr in _INT_ATTRS else float(raw))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  return result
100
+
101
+
102
  def check_secondary_alarms(result: ExperimentResult) -> list[str]:
103
+ """Check secondary metrics against fixed alarm thresholds.
104
+
105
+ Args:
106
+ result: Parsed experiment result.
107
+
108
+ Returns:
109
+ List of human-readable alarm strings (empty if all clear).
110
+ """
111
+ alarms: list[str] = []
112
+
113
+ if result.mhc_spectral_norm > 2.0:
 
 
 
 
 
 
 
 
114
  alarms.append(
115
+ f"mhc_spectral_norm={result.mhc_spectral_norm:.4f} > 2.0 (ALARM)"
116
  )
117
+ if 0 < result.engram_hit_rate < 0.1:
118
  alarms.append(
119
+ f"engram_hit_rate={result.engram_hit_rate:.4f} < 0.1 (memory underused)"
120
  )
121
+ if 0 < result.mfu_percent < 10:
122
  alarms.append(
123
+ f"mfu_percent={result.mfu_percent:.2f}% < 10% (GPU underutilized)"
124
  )
125
+
126
  return alarms
127
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def should_keep(
130
  result: ExperimentResult,
131
  best_bpb: float,
132
+ gates: dict | None = None,
133
  ) -> tuple[bool, str]:
134
+ """Decide whether to keep or discard an experiment.
135
+
136
+ The primary criterion is strictly lower val_bpb than the current best.
137
+ Optional secondary gates (passed from HarnessConfig.secondary_metrics)
138
+ can reject an otherwise-improving result.
139
+
140
+ Args:
141
+ result: Parsed experiment result.
142
+ best_bpb: Current best val_bpb across all experiments.
143
+ gates: Optional dict mapping metric name to threshold dict with
144
+ ``"max"`` or ``"min"`` keys, e.g.
145
+ ``{"mhc_spectral_norm": {"max": 2.0}}``.
146
+
147
+ Returns:
148
+ Tuple of (keep: bool, reason: str).
149
+ """
150
+ if result.crashed:
151
+ return False, "crash"
152
+ if result.val_bpb <= 0:
153
+ return False, "invalid val_bpb"
154
+ if result.val_bpb >= best_bpb:
155
+ return False, "discard"
156
+
157
  # Secondary gate checks.
158
  if gates:
159
+ gate_mhc = gates.get("mhc_spectral_norm", {}).get("max")
160
+ if gate_mhc is not None and result.mhc_spectral_norm > gate_mhc:
161
+ return (
162
+ False,
163
+ f"mhc_spectral_norm {result.mhc_spectral_norm:.4f} > gate {gate_mhc}",
164
+ )
165
+ gate_engram = gates.get("engram_hit_rate", {}).get("min")
166
+ if gate_engram is not None and result.engram_hit_rate < gate_engram:
167
+ return (
168
+ False,
169
+ f"engram_hit_rate {result.engram_hit_rate:.4f} < gate {gate_engram}",
170
+ )
 
 
 
 
 
171
 
172
  return True, "keep"
overlay/harness/git_utils.py CHANGED
@@ -1,94 +1,94 @@
1
- """Git utilities for HYDRA autoresearch branch management."""
2
- import os
3
- import subprocess
4
-
5
- REPO_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6
-
7
-
8
- def run_git(*args: str, check: bool = True) -> subprocess.CompletedProcess:
9
- """Run a git command in the repo directory.
10
-
11
- Args:
12
- *args: Git command arguments.
13
- check: Whether to raise on non-zero exit code.
14
-
15
- Returns:
16
- Completed process with stdout/stderr captured.
17
- """
18
- return subprocess.run(
19
- ["git"] + list(args),
20
- cwd=REPO_DIR,
21
- capture_output=True,
22
- text=True,
23
- check=check,
24
- )
25
-
26
-
27
- def current_branch() -> str:
28
- """Return the current git branch name.
29
-
30
- Returns:
31
- Branch name string.
32
- """
33
- result = run_git("rev-parse", "--abbrev-ref", "HEAD")
34
- return result.stdout.strip()
35
-
36
-
37
- def current_commit_short() -> str:
38
- """Return the current HEAD commit short hash (7 chars).
39
-
40
- Returns:
41
- 7-character commit hash.
42
- """
43
- result = run_git("rev-parse", "--short=7", "HEAD")
44
- return result.stdout.strip()
45
-
46
-
47
- def create_branch(name: str) -> None:
48
- """Create and switch to a new branch.
49
-
50
- Args:
51
- name: Branch name to create.
52
- """
53
- run_git("checkout", "-b", name)
54
-
55
-
56
- def commit_all(message: str) -> str:
57
- """Stage all changes, commit, and return short hash.
58
-
59
- Args:
60
- message: Commit message.
61
-
62
- Returns:
63
- Short commit hash after committing.
64
- """
65
- run_git("add", "-A")
66
- run_git("commit", "-m", message, check=False)
67
- return current_commit_short()
68
-
69
-
70
- def reset_to(commit: str) -> None:
71
- """Hard reset to a specific commit, discarding all changes.
72
-
73
- Args:
74
- commit: Commit hash (short or full) to reset to.
75
- """
76
- run_git("reset", "--hard", commit)
77
-
78
-
79
- def get_last_n_diffs(n: int = 3) -> list[str]:
80
- """Get the last N commit diffs (--stat format) for meta-agent context.
81
-
82
- Args:
83
- n: Number of recent commits to retrieve.
84
-
85
- Returns:
86
- List of diff stat strings, one per commit (truncated to 500 chars).
87
- """
88
- result = run_git("log", f"-{n}", "--format=%H", check=False)
89
- hashes = [h for h in result.stdout.strip().split("\n") if h]
90
- diffs: list[str] = []
91
- for h in hashes:
92
- diff_result = run_git("show", "--stat", h, check=False)
93
- diffs.append(diff_result.stdout[:500])
94
- return diffs
 
1
+ """Git utilities for HYDRA autoresearch branch management."""
2
+ import os
3
+ import subprocess
4
+
5
+ REPO_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6
+
7
+
8
+ def run_git(*args: str, check: bool = True) -> subprocess.CompletedProcess:
9
+ """Run a git command in the repo directory.
10
+
11
+ Args:
12
+ *args: Git command arguments.
13
+ check: Whether to raise on non-zero exit code.
14
+
15
+ Returns:
16
+ Completed process with stdout/stderr captured.
17
+ """
18
+ return subprocess.run(
19
+ ["git"] + list(args),
20
+ cwd=REPO_DIR,
21
+ capture_output=True,
22
+ text=True,
23
+ check=check,
24
+ )
25
+
26
+
27
+ def current_branch() -> str:
28
+ """Return the current git branch name.
29
+
30
+ Returns:
31
+ Branch name string.
32
+ """
33
+ result = run_git("rev-parse", "--abbrev-ref", "HEAD")
34
+ return result.stdout.strip()
35
+
36
+
37
+ def current_commit_short() -> str:
38
+ """Return the current HEAD commit short hash (7 chars).
39
+
40
+ Returns:
41
+ 7-character commit hash.
42
+ """
43
+ result = run_git("rev-parse", "--short=7", "HEAD")
44
+ return result.stdout.strip()
45
+
46
+
47
+ def create_branch(name: str) -> None:
48
+ """Create and switch to a new branch.
49
+
50
+ Args:
51
+ name: Branch name to create.
52
+ """
53
+ run_git("checkout", "-b", name)
54
+
55
+
56
+ def commit_all(message: str) -> str:
57
+ """Stage all changes, commit, and return short hash.
58
+
59
+ Args:
60
+ message: Commit message.
61
+
62
+ Returns:
63
+ Short commit hash after committing.
64
+ """
65
+ run_git("add", "-A")
66
+ run_git("commit", "-m", message, check=False)
67
+ return current_commit_short()
68
+
69
+
70
+ def reset_to(commit: str) -> None:
71
+ """Hard reset to a specific commit, discarding all changes.
72
+
73
+ Args:
74
+ commit: Commit hash (short or full) to reset to.
75
+ """
76
+ run_git("reset", "--hard", commit)
77
+
78
+
79
+ def get_last_n_diffs(n: int = 3) -> list[str]:
80
+ """Get the last N commit diffs (--stat format) for meta-agent context.
81
+
82
+ Args:
83
+ n: Number of recent commits to retrieve.
84
+
85
+ Returns:
86
+ List of diff stat strings, one per commit (truncated to 500 chars).
87
+ """
88
+ result = run_git("log", f"-{n}", "--format=%H", check=False)
89
+ hashes = [h for h in result.stdout.strip().split("\n") if h]
90
+ diffs: list[str] = []
91
+ for h in hashes:
92
+ diff_result = run_git("show", "--stat", h, check=False)
93
+ diffs.append(diff_result.stdout[:500])
94
+ return diffs
overlay/harness/health_monitor.py CHANGED
@@ -1,86 +1,86 @@
1
- """Hardware health monitoring for HYDRA experiments.
2
-
3
- Provides lightweight checks that the orchestrator runs before each
4
- experiment to avoid launching training into a degraded GPU state.
5
- """
6
- import os
7
-
8
- import torch
9
-
10
-
11
- def get_gpu_stats() -> dict:
12
- """Return current GPU memory statistics.
13
-
14
- Returns:
15
- Dict with keys: available (bool), and when available:
16
- name, memory_allocated_mb, memory_reserved_mb,
17
- max_memory_allocated_mb, memory_total_mb.
18
- """
19
- if not torch.cuda.is_available():
20
- return {"available": False}
21
-
22
- props = torch.cuda.get_device_properties(0)
23
- return {
24
- "available": True,
25
- "name": torch.cuda.get_device_name(0),
26
- "memory_allocated_mb": torch.cuda.memory_allocated(0) / (1024 * 1024),
27
- "memory_reserved_mb": torch.cuda.memory_reserved(0) / (1024 * 1024),
28
- "max_memory_allocated_mb": torch.cuda.max_memory_allocated(0) / (1024 * 1024),
29
- "memory_total_mb": props.total_mem / (1024 * 1024),
30
- }
31
-
32
-
33
- def check_health(
34
- vram_pressure_pct: float = 90.0,
35
- min_free_disk_gb: float = 1.0,
36
- ) -> tuple[bool, list[str]]:
37
- """Check GPU and disk health before launching an experiment.
38
-
39
- Args:
40
- vram_pressure_pct: Warn when GPU memory allocation exceeds this
41
- percentage of total VRAM.
42
- min_free_disk_gb: Warn when free disk space falls below this.
43
-
44
- Returns:
45
- Tuple of (healthy: bool, warnings: list[str]).
46
- ``healthy`` is True when there are no warnings.
47
- """
48
- warnings: list[str] = []
49
- stats = get_gpu_stats()
50
-
51
- if not stats["available"]:
52
- return False, ["No CUDA GPU available"]
53
-
54
- # Memory pressure check.
55
- used_pct = (
56
- stats["memory_allocated_mb"] / stats["memory_total_mb"] * 100
57
- if stats["memory_total_mb"] > 0
58
- else 0.0
59
- )
60
- if used_pct > vram_pressure_pct:
61
- warnings.append(
62
- f"GPU memory pressure: {used_pct:.1f}% allocated "
63
- f"({stats['memory_allocated_mb']:.0f} / {stats['memory_total_mb']:.0f} MB)"
64
- )
65
-
66
- # Disk space check.
67
- try:
68
- statvfs = os.statvfs(os.path.dirname(os.path.abspath(__file__)))
69
- free_gb = (statvfs.f_bavail * statvfs.f_frsize) / (1024**3)
70
- if free_gb < min_free_disk_gb:
71
- warnings.append(f"Low disk space: {free_gb:.2f} GB free")
72
- except (AttributeError, OSError):
73
- # os.statvfs not available on all platforms (e.g. Windows).
74
- pass
75
-
76
- return len(warnings) == 0, warnings
77
-
78
-
79
- def reset_peak_stats() -> None:
80
- """Reset GPU peak memory tracking for the next experiment.
81
-
82
- Should be called immediately before launching each training run so
83
- that peak_vram_mb reported in run.log reflects only that experiment.
84
- """
85
- if torch.cuda.is_available():
86
- torch.cuda.reset_peak_memory_stats()
 
1
+ """Hardware health monitoring for HYDRA experiments.
2
+
3
+ Provides lightweight checks that the orchestrator runs before each
4
+ experiment to avoid launching training into a degraded GPU state.
5
+ """
6
+ import os
7
+
8
+ import torch
9
+
10
+
11
+ def get_gpu_stats() -> dict:
12
+ """Return current GPU memory statistics.
13
+
14
+ Returns:
15
+ Dict with keys: available (bool), and when available:
16
+ name, memory_allocated_mb, memory_reserved_mb,
17
+ max_memory_allocated_mb, memory_total_mb.
18
+ """
19
+ if not torch.cuda.is_available():
20
+ return {"available": False}
21
+
22
+ props = torch.cuda.get_device_properties(0)
23
+ return {
24
+ "available": True,
25
+ "name": torch.cuda.get_device_name(0),
26
+ "memory_allocated_mb": torch.cuda.memory_allocated(0) / (1024 * 1024),
27
+ "memory_reserved_mb": torch.cuda.memory_reserved(0) / (1024 * 1024),
28
+ "max_memory_allocated_mb": torch.cuda.max_memory_allocated(0) / (1024 * 1024),
29
+ "memory_total_mb": props.total_mem / (1024 * 1024),
30
+ }
31
+
32
+
33
+ def check_health(
34
+ vram_pressure_pct: float = 90.0,
35
+ min_free_disk_gb: float = 1.0,
36
+ ) -> tuple[bool, list[str]]:
37
+ """Check GPU and disk health before launching an experiment.
38
+
39
+ Args:
40
+ vram_pressure_pct: Warn when GPU memory allocation exceeds this
41
+ percentage of total VRAM.
42
+ min_free_disk_gb: Warn when free disk space falls below this.
43
+
44
+ Returns:
45
+ Tuple of (healthy: bool, warnings: list[str]).
46
+ ``healthy`` is True when there are no warnings.
47
+ """
48
+ warnings: list[str] = []
49
+ stats = get_gpu_stats()
50
+
51
+ if not stats["available"]:
52
+ return False, ["No CUDA GPU available"]
53
+
54
+ # Memory pressure check.
55
+ used_pct = (
56
+ stats["memory_allocated_mb"] / stats["memory_total_mb"] * 100
57
+ if stats["memory_total_mb"] > 0
58
+ else 0.0
59
+ )
60
+ if used_pct > vram_pressure_pct:
61
+ warnings.append(
62
+ f"GPU memory pressure: {used_pct:.1f}% allocated "
63
+ f"({stats['memory_allocated_mb']:.0f} / {stats['memory_total_mb']:.0f} MB)"
64
+ )
65
+
66
+ # Disk space check.
67
+ try:
68
+ statvfs = os.statvfs(os.path.dirname(os.path.abspath(__file__)))
69
+ free_gb = (statvfs.f_bavail * statvfs.f_frsize) / (1024**3)
70
+ if free_gb < min_free_disk_gb:
71
+ warnings.append(f"Low disk space: {free_gb:.2f} GB free")
72
+ except (AttributeError, OSError):
73
+ # os.statvfs not available on all platforms (e.g. Windows).
74
+ pass
75
+
76
+ return len(warnings) == 0, warnings
77
+
78
+
79
+ def reset_peak_stats() -> None:
80
+ """Reset GPU peak memory tracking for the next experiment.
81
+
82
+ Should be called immediately before launching each training run so
83
+ that peak_vram_mb reported in run.log reflects only that experiment.
84
+ """
85
+ if torch.cuda.is_available():
86
+ torch.cuda.reset_peak_memory_stats()
overlay/harness/meta_agent.py CHANGED
@@ -1,139 +1,139 @@
1
- """Meta-agent: evolves program.md based on experiment history.
2
-
3
- Runs every ``meta_interval`` inner-loop experiments (configured in
4
- HarnessConfig). Reads the current research state from results.tsv,
5
- decides whether guidance is needed, and appends a directive to
6
- program.md. Any previous auto-generated directive is replaced so
7
- the file stays clean.
8
- """
9
- import os
10
-
11
- from harness.git_utils import REPO_DIR
12
- from harness.search_strategy import ResearchState, diagnose
13
-
14
- PROGRAM_PATH = os.path.join(REPO_DIR, "program.md")
15
- RESULTS_PATH = os.path.join(REPO_DIR, "results.tsv")
16
-
17
- # Sentinel that marks auto-generated content so it can be cleanly replaced.
18
- _DIRECTIVE_MARKER = "## Meta-Agent Directive (auto-generated)"
19
-
20
-
21
- def generate_directive(state: ResearchState) -> str | None:
22
- """Generate a directive string to append to program.md, or None.
23
-
24
- A directive is only produced when the research state is not EXPLORING
25
- (i.e., something needs to change).
26
-
27
- Args:
28
- state: Current ResearchState diagnosis.
29
-
30
- Returns:
31
- Formatted directive string, or None when no change is needed.
32
- """
33
- if state.label == "EXPLORING":
34
- return None
35
-
36
- if state.label == "BROKEN":
37
- return (
38
- f"\n{_DIRECTIVE_MARKER}\n"
39
- f"ALERT: Crash rate is {state.crash_rate:.0%} in the recent window. "
40
- "Revert to the last stable commit. Reduce model complexity before "
41
- "proposing further changes. Suggested actions:\n"
42
- "- Reduce d_model or n_layer\n"
43
- "- Reduce batch_size\n"
44
- "- Disable experimental modules (Engram, mHC, Hestia) one at a time\n"
45
- )
46
-
47
- if state.label == "STUCK":
48
- stale = state.total_experiments - state.last_improvement_at
49
- return (
50
- f"\n{_DIRECTIVE_MARKER}\n"
51
- f"ALERT: No improvement for {stale} experiments "
52
- f"(best_bpb={state.best_bpb:.6f}). "
53
- "Apply BOLD changes for the next 5 experiments:\n"
54
- "- Dramatically change d_model or n_layer (2× or ½)\n"
55
- "- Toggle Engram or mHC on/off entirely\n"
56
- "- Change optimizer hyperparameters by 3–5×\n"
57
- "- Temporarily accept results within 0.5% of baseline\n"
58
- )
59
-
60
- if state.label == "EXPLOITING":
61
- return (
62
- f"\n{_DIRECTIVE_MARKER}\n"
63
- "Search is converging too early. Inject diversity:\n"
64
- "- If recent experiments tune LR, try architecture changes instead\n"
65
- "- If tuning architecture, try optimizer or regularisation changes\n"
66
- "- Try removing complexity (simplification wins are valuable)\n"
67
- "- Explore a subsystem not touched in the last 10 experiments\n"
68
- )
69
-
70
- return None
71
-
72
-
73
- def _strip_previous_directive(content: str) -> str:
74
- """Remove any prior auto-generated directive block from content.
75
-
76
- Args:
77
- content: Full text of program.md.
78
-
79
- Returns:
80
- Content with any previous directive stripped and trailing
81
- whitespace normalised.
82
- """
83
- if _DIRECTIVE_MARKER in content:
84
- content = content[: content.index(_DIRECTIVE_MARKER)].rstrip() + "\n"
85
- return content
86
-
87
-
88
- def run_meta_iteration(
89
- program_path: str = PROGRAM_PATH,
90
- results_path: str = RESULTS_PATH,
91
- ) -> dict:
92
- """Run one meta-agent iteration.
93
-
94
- Diagnoses the current research state and optionally rewrites
95
- program.md with a new directive.
96
-
97
- Args:
98
- program_path: Path to program.md.
99
- results_path: Path to results.tsv.
100
-
101
- Returns:
102
- Summary dict with keys: state, total_experiments, best_bpb,
103
- crash_rate, changed, and optionally directive.
104
- """
105
- state = diagnose(results_path)
106
-
107
- summary: dict = {
108
- "state": state.label,
109
- "total_experiments": state.total_experiments,
110
- "best_bpb": state.best_bpb,
111
- "crash_rate": state.crash_rate,
112
- "changed": False,
113
- }
114
-
115
- directive = generate_directive(state)
116
- if directive is None:
117
- return summary
118
-
119
- try:
120
- with open(program_path) as fh:
121
- content = fh.read()
122
- except FileNotFoundError:
123
- content = ""
124
-
125
- content = _strip_previous_directive(content)
126
- content = content + "\n" + directive
127
-
128
- tmp_path = program_path + ".tmp"
129
- try:
130
- with open(tmp_path, "w") as fh:
131
- fh.write(content)
132
- os.replace(tmp_path, program_path) # atomic on POSIX
133
- finally:
134
- if os.path.exists(tmp_path):
135
- os.unlink(tmp_path)
136
-
137
- summary["changed"] = True
138
- summary["directive"] = directive.strip()
139
- return summary
 
1
+ """Meta-agent: evolves program.md based on experiment history.
2
+
3
+ Runs every ``meta_interval`` inner-loop experiments (configured in
4
+ HarnessConfig). Reads the current research state from results.tsv,
5
+ decides whether guidance is needed, and appends a directive to
6
+ program.md. Any previous auto-generated directive is replaced so
7
+ the file stays clean.
8
+ """
9
+ import os
10
+
11
+ from harness.git_utils import REPO_DIR
12
+ from harness.search_strategy import ResearchState, diagnose
13
+
14
+ PROGRAM_PATH = os.path.join(REPO_DIR, "program.md")
15
+ RESULTS_PATH = os.path.join(REPO_DIR, "results.tsv")
16
+
17
+ # Sentinel that marks auto-generated content so it can be cleanly replaced.
18
+ _DIRECTIVE_MARKER = "## Meta-Agent Directive (auto-generated)"
19
+
20
+
21
+ def generate_directive(state: ResearchState) -> str | None:
22
+ """Generate a directive string to append to program.md, or None.
23
+
24
+ A directive is only produced when the research state is not EXPLORING
25
+ (i.e., something needs to change).
26
+
27
+ Args:
28
+ state: Current ResearchState diagnosis.
29
+
30
+ Returns:
31
+ Formatted directive string, or None when no change is needed.
32
+ """
33
+ if state.label == "EXPLORING":
34
+ return None
35
+
36
+ if state.label == "BROKEN":
37
+ return (
38
+ f"\n{_DIRECTIVE_MARKER}\n"
39
+ f"ALERT: Crash rate is {state.crash_rate:.0%} in the recent window. "
40
+ "Revert to the last stable commit. Reduce model complexity before "
41
+ "proposing further changes. Suggested actions:\n"
42
+ "- Reduce d_model or n_layer\n"
43
+ "- Reduce batch_size\n"
44
+ "- Disable experimental modules (Engram, mHC, Hestia) one at a time\n"
45
+ )
46
+
47
+ if state.label == "STUCK":
48
+ stale = state.total_experiments - state.last_improvement_at
49
+ return (
50
+ f"\n{_DIRECTIVE_MARKER}\n"
51
+ f"ALERT: No improvement for {stale} experiments "
52
+ f"(best_bpb={state.best_bpb:.6f}). "
53
+ "Apply BOLD changes for the next 5 experiments:\n"
54
+ "- Dramatically change d_model or n_layer (2× or ½)\n"
55
+ "- Toggle Engram or mHC on/off entirely\n"
56
+ "- Change optimizer hyperparameters by 3–5×\n"
57
+ "- Temporarily accept results within 0.5% of baseline\n"
58
+ )
59
+
60
+ if state.label == "EXPLOITING":
61
+ return (
62
+ f"\n{_DIRECTIVE_MARKER}\n"
63
+ "Search is converging too early. Inject diversity:\n"
64
+ "- If recent experiments tune LR, try architecture changes instead\n"
65
+ "- If tuning architecture, try optimizer or regularisation changes\n"
66
+ "- Try removing complexity (simplification wins are valuable)\n"
67
+ "- Explore a subsystem not touched in the last 10 experiments\n"
68
+ )
69
+
70
+ return None
71
+
72
+
73
+ def _strip_previous_directive(content: str) -> str:
74
+ """Remove any prior auto-generated directive block from content.
75
+
76
+ Args:
77
+ content: Full text of program.md.
78
+
79
+ Returns:
80
+ Content with any previous directive stripped and trailing
81
+ whitespace normalised.
82
+ """
83
+ if _DIRECTIVE_MARKER in content:
84
+ content = content[: content.index(_DIRECTIVE_MARKER)].rstrip() + "\n"
85
+ return content
86
+
87
+
88
+ def run_meta_iteration(
89
+ program_path: str = PROGRAM_PATH,
90
+ results_path: str = RESULTS_PATH,
91
+ ) -> dict:
92
+ """Run one meta-agent iteration.
93
+
94
+ Diagnoses the current research state and optionally rewrites
95
+ program.md with a new directive.
96
+
97
+ Args:
98
+ program_path: Path to program.md.
99
+ results_path: Path to results.tsv.
100
+
101
+ Returns:
102
+ Summary dict with keys: state, total_experiments, best_bpb,
103
+ crash_rate, changed, and optionally directive.
104
+ """
105
+ state = diagnose(results_path)
106
+
107
+ summary: dict = {
108
+ "state": state.label,
109
+ "total_experiments": state.total_experiments,
110
+ "best_bpb": state.best_bpb,
111
+ "crash_rate": state.crash_rate,
112
+ "changed": False,
113
+ }
114
+
115
+ directive = generate_directive(state)
116
+ if directive is None:
117
+ return summary
118
+
119
+ try:
120
+ with open(program_path) as fh:
121
+ content = fh.read()
122
+ except FileNotFoundError:
123
+ content = ""
124
+
125
+ content = _strip_previous_directive(content)
126
+ content = content + "\n" + directive
127
+
128
+ tmp_path = program_path + ".tmp"
129
+ try:
130
+ with open(tmp_path, "w") as fh:
131
+ fh.write(content)
132
+ os.replace(tmp_path, program_path) # atomic on POSIX
133
+ finally:
134
+ if os.path.exists(tmp_path):
135
+ os.unlink(tmp_path)
136
+
137
+ summary["changed"] = True
138
+ summary["directive"] = directive.strip()
139
+ return summary
overlay/harness/orchestrator.py CHANGED
@@ -1,296 +1,293 @@
1
- """HYDRA Orchestrator: main loop for autonomous research.
2
-
3
- Usage::
4
-
5
- python -m harness.orchestrator [--meta-interval N] [--max-experiments N]
6
-
7
- Loop:
8
- 1. Read current state (branch, results.tsv, program.md)
9
- 2. [Architect Agent] proposes and applies changes to train.py (external)
10
- 3. Git commit the changes
11
- 4. Run training: ``uv run train.py`` captured to run.log
12
- 5. [Eval Agent] extract metrics from run.log
13
- 6. Keep or discard based on val_bpb + secondary metric gates
14
- 7. Log to results.tsv
15
- 8. Every ``meta_interval`` experiments: [Meta Agent] evolves program.md
16
- 9. Repeat
17
-
18
- The orchestrator intentionally does NOT modify train.py itself -- it
19
- provides the infrastructure ("rails") that the autoresearch loop runs on.
20
- """
21
- import argparse
22
- import csv
23
  import os
24
  import subprocess
25
  import time
26
 
27
- from configs.harness_config import HarnessConfig
28
  from harness.eval_agent import ExperimentResult, check_secondary_alarms, parse_run_log, should_keep
29
- from harness.git_utils import REPO_DIR, commit_all, current_commit_short, reset_to
30
- from harness.health_monitor import check_health, reset_peak_stats
31
- from harness.meta_agent import run_meta_iteration
32
- from harness.search_strategy import diagnose
33
-
34
- # ---------------------------------------------------------------------------
35
- # Paths
36
- # ---------------------------------------------------------------------------
37
-
38
- RESULTS_FILE = os.path.join(REPO_DIR, "results.tsv")
39
- RUN_LOG = os.path.join(REPO_DIR, "run.log")
40
-
41
- _TSV_HEADER = "commit\tval_bpb\tmemory_gb\tstatus\tdescription\n"
42
-
43
-
44
- # ---------------------------------------------------------------------------
45
- # TSV helpers
46
- # ---------------------------------------------------------------------------
47
-
48
-
49
- def init_results_tsv() -> None:
50
- """Create results.tsv with header row if it does not yet exist."""
51
- if not os.path.exists(RESULTS_FILE):
52
- with open(RESULTS_FILE, "w") as fh:
53
- fh.write(_TSV_HEADER)
54
-
55
-
56
- def log_result(
57
- commit: str,
58
- val_bpb: float,
59
- memory_gb: float,
60
- status: str,
61
- description: str,
62
- ) -> None:
63
- """Append one row to results.tsv.
64
-
65
- Args:
66
- commit: Short git hash for this experiment.
67
- val_bpb: Validation bits-per-byte (0.0 for crashes).
68
- memory_gb: Peak VRAM usage in gigabytes.
69
- status: One of keep / discard / crash / timeout.
70
- description: Short human-readable description.
71
- """
72
- with open(RESULTS_FILE, "a") as fh:
73
- fh.write(
74
- f"{commit}\t{val_bpb:.6f}\t{memory_gb:.2f}\t{status}\t{description}\n"
75
- )
76
-
77
-
78
- def count_experiments() -> int:
79
- """Count the number of experiment rows in results.tsv.
80
-
81
- Returns:
82
- Row count excluding the header line (0 when file does not exist).
83
- """
84
- if not os.path.exists(RESULTS_FILE):
85
- return 0
86
- with open(RESULTS_FILE) as fh:
87
- return max(0, sum(1 for _ in fh) - 1)
88
-
89
-
90
- def _load_best_bpb() -> float:
91
- """Scan results.tsv for the best (lowest positive) val_bpb seen so far.
92
-
93
- Returns:
94
- Best val_bpb, or ``float("inf")`` when no valid result exists.
95
- """
96
- if not os.path.exists(RESULTS_FILE):
97
- return float("inf")
98
- best = float("inf")
99
- with open(RESULTS_FILE) as fh:
100
- reader = csv.DictReader(fh, delimiter="\t")
101
- for row in reader:
102
- try:
103
- bpb = float(row.get("val_bpb", "0") or "0")
104
- except ValueError:
105
- continue
106
- if 0 < bpb < best:
107
- best = bpb
108
- return best
109
-
110
-
111
- # ---------------------------------------------------------------------------
112
- # Experiment execution
113
- # ---------------------------------------------------------------------------
114
-
115
-
116
- def run_experiment(timeout: int = 600) -> str:
117
- """Launch ``uv run train.py`` and capture all output to run.log.
118
-
119
- Args:
120
- timeout: Kill the process after this many seconds.
121
-
122
- Returns:
123
- One of ``"ok"``, ``"timeout"``, or ``"error"``.
124
- """
125
- try:
126
- with open(RUN_LOG, "w") as log_file:
127
- proc = subprocess.run(
128
- ["uv", "run", "train.py"],
129
- cwd=REPO_DIR,
130
- stdout=log_file,
131
- stderr=subprocess.STDOUT,
132
- timeout=timeout,
133
- )
134
- return "ok" if proc.returncode == 0 else "error"
135
- except subprocess.TimeoutExpired:
136
- return "timeout"
137
- except Exception as exc: # noqa: BLE001
138
- with open(RUN_LOG, "a") as log_file:
139
- log_file.write(f"\nOrchestrator error: {exc}\n")
140
- return "error"
141
-
142
-
143
- # ---------------------------------------------------------------------------
144
- # Main loop
145
- # ---------------------------------------------------------------------------
146
-
147
-
148
  def run_loop(
149
  meta_interval: int = 20,
150
  max_experiments: int | None = None,
151
  experiment_timeout: int = 600,
152
- secondary_gates: dict[str, dict[str, float]] | None = None,
153
  ) -> None:
154
- """Run the HYDRA autoresearch loop.
155
-
156
- This function runs indefinitely (or until ``max_experiments`` is reached
157
- or the user interrupts with Ctrl-C).
158
-
159
- Args:
160
- meta_interval: Run the meta-agent every N experiments.
161
- max_experiments: Hard stop after this many experiments (None = infinite).
162
- experiment_timeout: Seconds before a training run is killed.
163
- secondary_gates: Optional gate thresholds forwarded to
164
- :func:`~harness.eval_agent.should_keep`.
165
- """
166
  init_results_tsv()
167
- if secondary_gates is None:
168
- secondary_gates = HarnessConfig().to_secondary_gates()
169
  best_bpb = _load_best_bpb()
170
- experiment_num = count_experiments()
171
-
172
- print(
173
- f"HYDRA Orchestrator starting. "
174
- f"Experiments so far: {experiment_num}, Best BPB: {best_bpb:.6f}"
175
- )
176
-
177
- while max_experiments is None or experiment_num < max_experiments:
178
- experiment_num += 1
179
-
180
- # ------------------------------------------------------------------
181
- # Pre-flight health check
182
- # ------------------------------------------------------------------
183
- healthy, hw_warnings = check_health()
184
- if hw_warnings:
185
- print(f" [health] {hw_warnings}")
186
-
187
- # ------------------------------------------------------------------
188
- # Periodic meta-agent update
189
- # ------------------------------------------------------------------
190
- if experiment_num > 1 and experiment_num % meta_interval == 0:
191
- print(f"\n=== Meta-agent iteration at experiment {experiment_num} ===")
192
- meta_result = run_meta_iteration()
193
- print(
194
- f" state={meta_result['state']} "
195
- f"best_bpb={meta_result['best_bpb']:.6f} "
196
- f"changed={meta_result['changed']}"
197
- )
198
- if meta_result.get("directive"):
199
- print(f" directive: {meta_result['directive'][:120]}")
200
-
201
- # ------------------------------------------------------------------
202
- # Record baseline commit so we can reset on failure / discard
203
- # ------------------------------------------------------------------
204
- pre_commit = current_commit_short()
205
-
206
- # ------------------------------------------------------------------
207
- # Run experiment
208
- # ------------------------------------------------------------------
209
- print(f"\n--- Experiment {experiment_num} ---")
210
- reset_peak_stats()
211
- t0 = time.time()
212
- run_status = run_experiment(timeout=experiment_timeout)
213
- elapsed = time.time() - t0
214
- print(f" run_status={run_status} elapsed={elapsed:.1f}s")
215
-
216
- # ------------------------------------------------------------------
217
- # Parse results
218
- # ------------------------------------------------------------------
219
- result: ExperimentResult = parse_run_log(RUN_LOG)
220
-
221
- if result.crashed or run_status != "ok":
222
- commit = current_commit_short()
223
- err_short = (
224
- "timeout"
225
- if run_status == "timeout"
226
- else result.error_message[:80].replace("\n", " ")
227
- )
228
- log_result(commit, 0.0, 0.0, "crash", err_short)
229
- print(f" CRASH: {err_short}")
230
- reset_to(pre_commit)
231
- continue
232
-
233
- # ------------------------------------------------------------------
234
- # Secondary alarms (non-blocking -- logged but do not abort)
235
- # ------------------------------------------------------------------
236
- alarms = check_secondary_alarms(result)
237
- if alarms:
238
- for alarm in alarms:
239
- print(f" [alarm] {alarm}")
240
-
241
- # ------------------------------------------------------------------
242
- # Keep / discard
243
- # ------------------------------------------------------------------
244
- keep, reason = should_keep(result, best_bpb, gates=secondary_gates)
245
- commit = current_commit_short()
246
- memory_gb = result.peak_vram_mb / 1024.0
247
-
248
- if keep:
249
- best_bpb = result.val_bpb
250
- description = f"val_bpb improved to {result.val_bpb:.6f}"
251
- log_result(commit, result.val_bpb, memory_gb, "keep", description)
252
- print(f" KEEP: val_bpb={result.val_bpb:.6f} (new best)")
253
- else:
254
- description = f"{reason} val_bpb={result.val_bpb:.6f}"
255
- log_result(commit, result.val_bpb, memory_gb, "discard", description)
256
- print(f" DISCARD: val_bpb={result.val_bpb:.6f} ({reason})")
257
- reset_to(pre_commit)
258
-
259
- print(f"\nHYDRA finished after {experiment_num} experiments. Best BPB: {best_bpb:.6f}")
260
-
261
-
262
- # ---------------------------------------------------------------------------
263
- # CLI entry point
264
- # ---------------------------------------------------------------------------
265
-
266
-
267
- if __name__ == "__main__":
268
- parser = argparse.ArgumentParser(description="HYDRA Autoresearch Orchestrator")
269
- parser.add_argument(
270
- "--meta-interval",
271
- type=int,
272
- default=20,
273
- help="Run meta-agent every N experiments (default: 20)",
274
- )
275
- parser.add_argument(
276
- "--max-experiments",
277
- type=int,
278
- default=None,
279
- help="Stop after N experiments; omit for infinite (default: infinite)",
280
- )
281
- parser.add_argument(
282
- "--experiment-timeout",
283
- type=int,
284
- default=600,
285
- help="Kill training run after N seconds (default: 600)",
286
- )
287
- args = parser.parse_args()
288
-
289
- try:
290
- run_loop(
291
- meta_interval=args.meta_interval,
292
- max_experiments=args.max_experiments,
293
- experiment_timeout=args.experiment_timeout,
294
- )
295
- except KeyboardInterrupt:
296
- print("\nOrchestrator stopped by user.")
 
1
+ """HYDRA Orchestrator: main loop for autonomous research.
2
+
3
+ Usage::
4
+
5
+ python -m harness.orchestrator [--meta-interval N] [--max-experiments N]
6
+
7
+ Loop:
8
+ 1. Read current state (branch, results.tsv, program.md)
9
+ 2. [Architect Agent] proposes and applies changes to train.py (external)
10
+ 3. Git commit the changes
11
+ 4. Run training: ``uv run train.py`` captured to run.log
12
+ 5. [Eval Agent] extract metrics from run.log
13
+ 6. Keep or discard based on val_bpb + secondary metric gates
14
+ 7. Log to results.tsv
15
+ 8. Every ``meta_interval`` experiments: [Meta Agent] evolves program.md
16
+ 9. Repeat
17
+
18
+ The orchestrator intentionally does NOT modify train.py itself -- it
19
+ provides the infrastructure ("rails") that the autoresearch loop runs on.
20
+ """
21
+ import argparse
22
+ import csv
23
  import os
24
  import subprocess
25
  import time
26
 
 
27
  from harness.eval_agent import ExperimentResult, check_secondary_alarms, parse_run_log, should_keep
28
+ from harness.git_utils import REPO_DIR, commit_all, current_commit_short, reset_to
29
+ from harness.health_monitor import check_health, reset_peak_stats
30
+ from harness.meta_agent import run_meta_iteration
31
+ from harness.search_strategy import diagnose
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Paths
35
+ # ---------------------------------------------------------------------------
36
+
37
+ RESULTS_FILE = os.path.join(REPO_DIR, "results.tsv")
38
+ RUN_LOG = os.path.join(REPO_DIR, "run.log")
39
+
40
+ _TSV_HEADER = "commit\tval_bpb\tmemory_gb\tstatus\tdescription\n"
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # TSV helpers
45
+ # ---------------------------------------------------------------------------
46
+
47
+
48
+ def init_results_tsv() -> None:
49
+ """Create results.tsv with header row if it does not yet exist."""
50
+ if not os.path.exists(RESULTS_FILE):
51
+ with open(RESULTS_FILE, "w") as fh:
52
+ fh.write(_TSV_HEADER)
53
+
54
+
55
+ def log_result(
56
+ commit: str,
57
+ val_bpb: float,
58
+ memory_gb: float,
59
+ status: str,
60
+ description: str,
61
+ ) -> None:
62
+ """Append one row to results.tsv.
63
+
64
+ Args:
65
+ commit: Short git hash for this experiment.
66
+ val_bpb: Validation bits-per-byte (0.0 for crashes).
67
+ memory_gb: Peak VRAM usage in gigabytes.
68
+ status: One of keep / discard / crash / timeout.
69
+ description: Short human-readable description.
70
+ """
71
+ with open(RESULTS_FILE, "a") as fh:
72
+ fh.write(
73
+ f"{commit}\t{val_bpb:.6f}\t{memory_gb:.2f}\t{status}\t{description}\n"
74
+ )
75
+
76
+
77
+ def count_experiments() -> int:
78
+ """Count the number of experiment rows in results.tsv.
79
+
80
+ Returns:
81
+ Row count excluding the header line (0 when file does not exist).
82
+ """
83
+ if not os.path.exists(RESULTS_FILE):
84
+ return 0
85
+ with open(RESULTS_FILE) as fh:
86
+ return max(0, sum(1 for _ in fh) - 1)
87
+
88
+
89
+ def _load_best_bpb() -> float:
90
+ """Scan results.tsv for the best (lowest positive) val_bpb seen so far.
91
+
92
+ Returns:
93
+ Best val_bpb, or ``float("inf")`` when no valid result exists.
94
+ """
95
+ if not os.path.exists(RESULTS_FILE):
96
+ return float("inf")
97
+ best = float("inf")
98
+ with open(RESULTS_FILE) as fh:
99
+ reader = csv.DictReader(fh, delimiter="\t")
100
+ for row in reader:
101
+ try:
102
+ bpb = float(row.get("val_bpb", "0") or "0")
103
+ except ValueError:
104
+ continue
105
+ if 0 < bpb < best:
106
+ best = bpb
107
+ return best
108
+
109
+
110
+ # ---------------------------------------------------------------------------
111
+ # Experiment execution
112
+ # ---------------------------------------------------------------------------
113
+
114
+
115
+ def run_experiment(timeout: int = 600) -> str:
116
+ """Launch ``uv run train.py`` and capture all output to run.log.
117
+
118
+ Args:
119
+ timeout: Kill the process after this many seconds.
120
+
121
+ Returns:
122
+ One of ``"ok"``, ``"timeout"``, or ``"error"``.
123
+ """
124
+ try:
125
+ with open(RUN_LOG, "w") as log_file:
126
+ proc = subprocess.run(
127
+ ["uv", "run", "train.py"],
128
+ cwd=REPO_DIR,
129
+ stdout=log_file,
130
+ stderr=subprocess.STDOUT,
131
+ timeout=timeout,
132
+ )
133
+ return "ok" if proc.returncode == 0 else "error"
134
+ except subprocess.TimeoutExpired:
135
+ return "timeout"
136
+ except Exception as exc: # noqa: BLE001
137
+ with open(RUN_LOG, "a") as log_file:
138
+ log_file.write(f"\nOrchestrator error: {exc}\n")
139
+ return "error"
140
+
141
+
142
+ # ---------------------------------------------------------------------------
143
+ # Main loop
144
+ # ---------------------------------------------------------------------------
145
+
146
+
147
  def run_loop(
148
  meta_interval: int = 20,
149
  max_experiments: int | None = None,
150
  experiment_timeout: int = 600,
151
+ secondary_gates: dict | None = None,
152
  ) -> None:
153
+ """Run the HYDRA autoresearch loop.
154
+
155
+ This function runs indefinitely (or until ``max_experiments`` is reached
156
+ or the user interrupts with Ctrl-C).
157
+
158
+ Args:
159
+ meta_interval: Run the meta-agent every N experiments.
160
+ max_experiments: Hard stop after this many experiments (None = infinite).
161
+ experiment_timeout: Seconds before a training run is killed.
162
+ secondary_gates: Optional gate thresholds forwarded to
163
+ :func:`~harness.eval_agent.should_keep`.
164
+ """
165
  init_results_tsv()
 
 
166
  best_bpb = _load_best_bpb()
167
+ experiment_num = count_experiments()
168
+
169
+ print(
170
+ f"HYDRA Orchestrator starting. "
171
+ f"Experiments so far: {experiment_num}, Best BPB: {best_bpb:.6f}"
172
+ )
173
+
174
+ while max_experiments is None or experiment_num < max_experiments:
175
+ experiment_num += 1
176
+
177
+ # ------------------------------------------------------------------
178
+ # Pre-flight health check
179
+ # ------------------------------------------------------------------
180
+ healthy, hw_warnings = check_health()
181
+ if hw_warnings:
182
+ print(f" [health] {hw_warnings}")
183
+
184
+ # ------------------------------------------------------------------
185
+ # Periodic meta-agent update
186
+ # ------------------------------------------------------------------
187
+ if experiment_num > 1 and experiment_num % meta_interval == 0:
188
+ print(f"\n=== Meta-agent iteration at experiment {experiment_num} ===")
189
+ meta_result = run_meta_iteration()
190
+ print(
191
+ f" state={meta_result['state']} "
192
+ f"best_bpb={meta_result['best_bpb']:.6f} "
193
+ f"changed={meta_result['changed']}"
194
+ )
195
+ if meta_result.get("directive"):
196
+ print(f" directive: {meta_result['directive'][:120]}")
197
+
198
+ # ------------------------------------------------------------------
199
+ # Record baseline commit so we can reset on failure / discard
200
+ # ------------------------------------------------------------------
201
+ pre_commit = current_commit_short()
202
+
203
+ # ------------------------------------------------------------------
204
+ # Run experiment
205
+ # ------------------------------------------------------------------
206
+ print(f"\n--- Experiment {experiment_num} ---")
207
+ reset_peak_stats()
208
+ t0 = time.time()
209
+ run_status = run_experiment(timeout=experiment_timeout)
210
+ elapsed = time.time() - t0
211
+ print(f" run_status={run_status} elapsed={elapsed:.1f}s")
212
+
213
+ # ------------------------------------------------------------------
214
+ # Parse results
215
+ # ------------------------------------------------------------------
216
+ result: ExperimentResult = parse_run_log(RUN_LOG)
217
+
218
+ if result.crashed or run_status != "ok":
219
+ commit = current_commit_short()
220
+ err_short = (
221
+ "timeout"
222
+ if run_status == "timeout"
223
+ else result.error_message[:80].replace("\n", " ")
224
+ )
225
+ log_result(commit, 0.0, 0.0, "crash", err_short)
226
+ print(f" CRASH: {err_short}")
227
+ reset_to(pre_commit)
228
+ continue
229
+
230
+ # ------------------------------------------------------------------
231
+ # Secondary alarms (non-blocking -- logged but do not abort)
232
+ # ------------------------------------------------------------------
233
+ alarms = check_secondary_alarms(result)
234
+ if alarms:
235
+ for alarm in alarms:
236
+ print(f" [alarm] {alarm}")
237
+
238
+ # ------------------------------------------------------------------
239
+ # Keep / discard
240
+ # ------------------------------------------------------------------
241
+ keep, reason = should_keep(result, best_bpb, gates=secondary_gates)
242
+ commit = current_commit_short()
243
+ memory_gb = result.peak_vram_mb / 1024.0
244
+
245
+ if keep:
246
+ best_bpb = result.val_bpb
247
+ description = f"val_bpb improved to {result.val_bpb:.6f}"
248
+ log_result(commit, result.val_bpb, memory_gb, "keep", description)
249
+ print(f" KEEP: val_bpb={result.val_bpb:.6f} (new best)")
250
+ else:
251
+ description = f"{reason} val_bpb={result.val_bpb:.6f}"
252
+ log_result(commit, result.val_bpb, memory_gb, "discard", description)
253
+ print(f" DISCARD: val_bpb={result.val_bpb:.6f} ({reason})")
254
+ reset_to(pre_commit)
255
+
256
+ print(f"\nHYDRA finished after {experiment_num} experiments. Best BPB: {best_bpb:.6f}")
257
+
258
+
259
+ # ---------------------------------------------------------------------------
260
+ # CLI entry point
261
+ # ---------------------------------------------------------------------------
262
+
263
+
264
+ if __name__ == "__main__":
265
+ parser = argparse.ArgumentParser(description="HYDRA Autoresearch Orchestrator")
266
+ parser.add_argument(
267
+ "--meta-interval",
268
+ type=int,
269
+ default=20,
270
+ help="Run meta-agent every N experiments (default: 20)",
271
+ )
272
+ parser.add_argument(
273
+ "--max-experiments",
274
+ type=int,
275
+ default=None,
276
+ help="Stop after N experiments; omit for infinite (default: infinite)",
277
+ )
278
+ parser.add_argument(
279
+ "--experiment-timeout",
280
+ type=int,
281
+ default=600,
282
+ help="Kill training run after N seconds (default: 600)",
283
+ )
284
+ args = parser.parse_args()
285
+
286
+ try:
287
+ run_loop(
288
+ meta_interval=args.meta_interval,
289
+ max_experiments=args.max_experiments,
290
+ experiment_timeout=args.experiment_timeout,
291
+ )
292
+ except KeyboardInterrupt:
293
+ print("\nOrchestrator stopped by user.")
overlay/harness/search_strategy.py CHANGED
@@ -1,153 +1,153 @@
1
- """Search strategy for HYDRA's meta-evolution loop.
2
-
3
- Reads results.tsv and diagnoses the current research state as one of:
4
- EXPLORING -- active improvement trend with diverse experiments
5
- EXPLOITING -- narrowing in on a local optimum (low diversity)
6
- STUCK -- no improvement for >= stuck_threshold experiments
7
- BROKEN -- crash rate exceeds crash_threshold
8
- """
9
- import csv
10
- import os
11
- from dataclasses import dataclass
12
-
13
-
14
- @dataclass
15
- class ResearchState:
16
- """Diagnosis of the current research trajectory.
17
-
18
- Attributes:
19
- label: One of EXPLORING, EXPLOITING, STUCK, BROKEN.
20
- trend_improving: True when the second half of the recent window is
21
- better (lower BPB) than the first half.
22
- experiment_diversity: Rough 0–1 score based on unique description
23
- prefixes in the recent window.
24
- crash_rate: Fraction of recent experiments that crashed.
25
- best_bpb: Lowest val_bpb seen across all experiments.
26
- last_improvement_at: Ordinal of the experiment that set best_bpb.
27
- total_experiments: Total rows in results.tsv (excluding header).
28
- """
29
-
30
- label: str
31
- trend_improving: bool
32
- experiment_diversity: float
33
- crash_rate: float
34
- best_bpb: float
35
- last_improvement_at: int
36
- total_experiments: int
37
-
38
-
39
- def diagnose(
40
- results_path: str,
41
- window: int = 20,
42
- stuck_threshold: int = 10,
43
- crash_threshold: float = 0.5,
44
- ) -> ResearchState:
45
- """Diagnose current research state from results.tsv.
46
-
47
- Args:
48
- results_path: Path to the tab-separated results file.
49
- window: Number of recent experiments to consider for trend/diversity.
50
- stuck_threshold: Experiments without improvement before labelling STUCK.
51
- crash_threshold: Crash fraction above which state becomes BROKEN.
52
-
53
- Returns:
54
- ResearchState with diagnosis label and supporting statistics.
55
- """
56
- if not os.path.exists(results_path):
57
- return ResearchState(
58
- label="EXPLORING",
59
- trend_improving=False,
60
- experiment_diversity=0.0,
61
- crash_rate=0.0,
62
- best_bpb=float("inf"),
63
- last_improvement_at=0,
64
- total_experiments=0,
65
- )
66
-
67
- rows: list[dict] = []
68
- with open(results_path) as fh:
69
- reader = csv.DictReader(fh, delimiter="\t")
70
- for row in reader:
71
- rows.append(row)
72
-
73
- if not rows:
74
- return ResearchState(
75
- label="EXPLORING",
76
- trend_improving=False,
77
- experiment_diversity=0.0,
78
- crash_rate=0.0,
79
- best_bpb=float("inf"),
80
- last_improvement_at=0,
81
- total_experiments=0,
82
- )
83
-
84
- total = len(rows)
85
- recent = rows[-window:]
86
-
87
- # Crash rate in the recent window.
88
- crashes = sum(1 for r in recent if r.get("status") == "crash")
89
- crash_rate = crashes / len(recent) if recent else 0.0
90
-
91
- # Best BPB overall and which experiment achieved it.
92
- best_bpb = float("inf")
93
- last_improvement_at = 0
94
- for i, row in enumerate(rows):
95
- try:
96
- bpb = float(row.get("val_bpb", "0") or "0")
97
- except ValueError:
98
- continue
99
- if bpb > 0 and bpb < best_bpb:
100
- best_bpb = bpb
101
- last_improvement_at = i + 1
102
-
103
- # Trend: is the second half of the recent window better than the first?
104
- valid_bpbs = [
105
- float(r.get("val_bpb", "0") or "0")
106
- for r in recent
107
- if float(r.get("val_bpb", "0") or "0") > 0
108
- ]
109
- trend_improving = False
110
- if len(valid_bpbs) >= 4:
111
- mid = len(valid_bpbs) // 2
112
- first_half_mean = sum(valid_bpbs[:mid]) / mid
113
- second_half_mean = sum(valid_bpbs[mid:]) / (len(valid_bpbs) - mid)
114
- trend_improving = second_half_mean < first_half_mean
115
-
116
- # Diversity: fraction of unique description prefixes (first 20 chars).
117
- descriptions = {r.get("description", "")[:20] for r in recent}
118
- diversity = min(1.0, len(descriptions) / max(1, len(recent)))
119
-
120
- # Classify state.
121
- stale = total - last_improvement_at
122
- if crash_rate > crash_threshold:
123
- label = "BROKEN"
124
- elif stale >= stuck_threshold:
125
- label = "STUCK"
126
- elif trend_improving and diversity > 0.3:
127
- label = "EXPLORING"
128
- else:
129
- label = "EXPLOITING"
130
-
131
- return ResearchState(
132
- label=label,
133
- trend_improving=trend_improving,
134
- experiment_diversity=diversity,
135
- crash_rate=crash_rate,
136
- best_bpb=best_bpb,
137
- last_improvement_at=last_improvement_at,
138
- total_experiments=total,
139
- )
140
-
141
-
142
- def should_explore(results_path: str, n: int = 10) -> bool:
143
- """Return True when no improvement has been seen in the last N experiments.
144
-
145
- Args:
146
- results_path: Path to results.tsv.
147
- n: Look-back window for improvement check.
148
-
149
- Returns:
150
- True if the research loop should try bolder mutations.
151
- """
152
- state = diagnose(results_path, window=n, stuck_threshold=n)
153
- return state.label in ("STUCK", "BROKEN")
 
1
+ """Search strategy for HYDRA's meta-evolution loop.
2
+
3
+ Reads results.tsv and diagnoses the current research state as one of:
4
+ EXPLORING -- active improvement trend with diverse experiments
5
+ EXPLOITING -- narrowing in on a local optimum (low diversity)
6
+ STUCK -- no improvement for >= stuck_threshold experiments
7
+ BROKEN -- crash rate exceeds crash_threshold
8
+ """
9
+ import csv
10
+ import os
11
+ from dataclasses import dataclass
12
+
13
+
14
+ @dataclass
15
+ class ResearchState:
16
+ """Diagnosis of the current research trajectory.
17
+
18
+ Attributes:
19
+ label: One of EXPLORING, EXPLOITING, STUCK, BROKEN.
20
+ trend_improving: True when the second half of the recent window is
21
+ better (lower BPB) than the first half.
22
+ experiment_diversity: Rough 0–1 score based on unique description
23
+ prefixes in the recent window.
24
+ crash_rate: Fraction of recent experiments that crashed.
25
+ best_bpb: Lowest val_bpb seen across all experiments.
26
+ last_improvement_at: Ordinal of the experiment that set best_bpb.
27
+ total_experiments: Total rows in results.tsv (excluding header).
28
+ """
29
+
30
+ label: str
31
+ trend_improving: bool
32
+ experiment_diversity: float
33
+ crash_rate: float
34
+ best_bpb: float
35
+ last_improvement_at: int
36
+ total_experiments: int
37
+
38
+
39
+ def diagnose(
40
+ results_path: str,
41
+ window: int = 20,
42
+ stuck_threshold: int = 10,
43
+ crash_threshold: float = 0.5,
44
+ ) -> ResearchState:
45
+ """Diagnose current research state from results.tsv.
46
+
47
+ Args:
48
+ results_path: Path to the tab-separated results file.
49
+ window: Number of recent experiments to consider for trend/diversity.
50
+ stuck_threshold: Experiments without improvement before labelling STUCK.
51
+ crash_threshold: Crash fraction above which state becomes BROKEN.
52
+
53
+ Returns:
54
+ ResearchState with diagnosis label and supporting statistics.
55
+ """
56
+ if not os.path.exists(results_path):
57
+ return ResearchState(
58
+ label="EXPLORING",
59
+ trend_improving=False,
60
+ experiment_diversity=0.0,
61
+ crash_rate=0.0,
62
+ best_bpb=float("inf"),
63
+ last_improvement_at=0,
64
+ total_experiments=0,
65
+ )
66
+
67
+ rows: list[dict] = []
68
+ with open(results_path) as fh:
69
+ reader = csv.DictReader(fh, delimiter="\t")
70
+ for row in reader:
71
+ rows.append(row)
72
+
73
+ if not rows:
74
+ return ResearchState(
75
+ label="EXPLORING",
76
+ trend_improving=False,
77
+ experiment_diversity=0.0,
78
+ crash_rate=0.0,
79
+ best_bpb=float("inf"),
80
+ last_improvement_at=0,
81
+ total_experiments=0,
82
+ )
83
+
84
+ total = len(rows)
85
+ recent = rows[-window:]
86
+
87
+ # Crash rate in the recent window.
88
+ crashes = sum(1 for r in recent if r.get("status") == "crash")
89
+ crash_rate = crashes / len(recent) if recent else 0.0
90
+
91
+ # Best BPB overall and which experiment achieved it.
92
+ best_bpb = float("inf")
93
+ last_improvement_at = 0
94
+ for i, row in enumerate(rows):
95
+ try:
96
+ bpb = float(row.get("val_bpb", "0") or "0")
97
+ except ValueError:
98
+ continue
99
+ if bpb > 0 and bpb < best_bpb:
100
+ best_bpb = bpb
101
+ last_improvement_at = i + 1
102
+
103
+ # Trend: is the second half of the recent window better than the first?
104
+ valid_bpbs = [
105
+ float(r.get("val_bpb", "0") or "0")
106
+ for r in recent
107
+ if float(r.get("val_bpb", "0") or "0") > 0
108
+ ]
109
+ trend_improving = False
110
+ if len(valid_bpbs) >= 4:
111
+ mid = len(valid_bpbs) // 2
112
+ first_half_mean = sum(valid_bpbs[:mid]) / mid
113
+ second_half_mean = sum(valid_bpbs[mid:]) / (len(valid_bpbs) - mid)
114
+ trend_improving = second_half_mean < first_half_mean
115
+
116
+ # Diversity: fraction of unique description prefixes (first 20 chars).
117
+ descriptions = {r.get("description", "")[:20] for r in recent}
118
+ diversity = min(1.0, len(descriptions) / max(1, len(recent)))
119
+
120
+ # Classify state.
121
+ stale = total - last_improvement_at
122
+ if crash_rate > crash_threshold:
123
+ label = "BROKEN"
124
+ elif stale >= stuck_threshold:
125
+ label = "STUCK"
126
+ elif trend_improving and diversity > 0.3:
127
+ label = "EXPLORING"
128
+ else:
129
+ label = "EXPLOITING"
130
+
131
+ return ResearchState(
132
+ label=label,
133
+ trend_improving=trend_improving,
134
+ experiment_diversity=diversity,
135
+ crash_rate=crash_rate,
136
+ best_bpb=best_bpb,
137
+ last_improvement_at=last_improvement_at,
138
+ total_experiments=total,
139
+ )
140
+
141
+
142
+ def should_explore(results_path: str, n: int = 10) -> bool:
143
+ """Return True when no improvement has been seen in the last N experiments.
144
+
145
+ Args:
146
+ results_path: Path to results.tsv.
147
+ n: Look-back window for improvement check.
148
+
149
+ Returns:
150
+ True if the research loop should try bolder mutations.
151
+ """
152
+ state = diagnose(results_path, window=n, stuck_threshold=n)
153
+ return state.label in ("STUCK", "BROKEN")
overlay/htm_rust/Cargo.lock CHANGED
@@ -1,383 +1,383 @@
1
- # This file is automatically @generated by Cargo.
2
- # It is not intended for manual editing.
3
- version = 4
4
-
5
- [[package]]
6
- name = "autocfg"
7
- version = "1.5.0"
8
- source = "registry+https://github.com/rust-lang/crates.io-index"
9
- checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
10
-
11
- [[package]]
12
- name = "cfg-if"
13
- version = "1.0.4"
14
- source = "registry+https://github.com/rust-lang/crates.io-index"
15
- checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
16
-
17
- [[package]]
18
- name = "cudarc"
19
- version = "0.12.1"
20
- source = "registry+https://github.com/rust-lang/crates.io-index"
21
- checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384"
22
- dependencies = [
23
- "libloading",
24
- ]
25
-
26
- [[package]]
27
- name = "getrandom"
28
- version = "0.2.17"
29
- source = "registry+https://github.com/rust-lang/crates.io-index"
30
- checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0"
31
- dependencies = [
32
- "cfg-if",
33
- "libc",
34
- "wasi",
35
- ]
36
-
37
- [[package]]
38
- name = "heck"
39
- version = "0.5.0"
40
- source = "registry+https://github.com/rust-lang/crates.io-index"
41
- checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
42
-
43
- [[package]]
44
- name = "htm_rust"
45
- version = "0.1.0"
46
- dependencies = [
47
- "cudarc",
48
- "ndarray",
49
- "numpy",
50
- "pyo3",
51
- "rand",
52
- "rand_xoshiro",
53
- ]
54
-
55
- [[package]]
56
- name = "indoc"
57
- version = "2.0.7"
58
- source = "registry+https://github.com/rust-lang/crates.io-index"
59
- checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
60
- dependencies = [
61
- "rustversion",
62
- ]
63
-
64
- [[package]]
65
- name = "libc"
66
- version = "0.2.185"
67
- source = "registry+https://github.com/rust-lang/crates.io-index"
68
- checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f"
69
-
70
- [[package]]
71
- name = "libloading"
72
- version = "0.8.9"
73
- source = "registry+https://github.com/rust-lang/crates.io-index"
74
- checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
75
- dependencies = [
76
- "cfg-if",
77
- "windows-link",
78
- ]
79
-
80
- [[package]]
81
- name = "matrixmultiply"
82
- version = "0.3.10"
83
- source = "registry+https://github.com/rust-lang/crates.io-index"
84
- checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
85
- dependencies = [
86
- "autocfg",
87
- "rawpointer",
88
- ]
89
-
90
- [[package]]
91
- name = "memoffset"
92
- version = "0.9.1"
93
- source = "registry+https://github.com/rust-lang/crates.io-index"
94
- checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
95
- dependencies = [
96
- "autocfg",
97
- ]
98
-
99
- [[package]]
100
- name = "ndarray"
101
- version = "0.16.1"
102
- source = "registry+https://github.com/rust-lang/crates.io-index"
103
- checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
104
- dependencies = [
105
- "matrixmultiply",
106
- "num-complex",
107
- "num-integer",
108
- "num-traits",
109
- "portable-atomic",
110
- "portable-atomic-util",
111
- "rawpointer",
112
- ]
113
-
114
- [[package]]
115
- name = "num-complex"
116
- version = "0.4.6"
117
- source = "registry+https://github.com/rust-lang/crates.io-index"
118
- checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
119
- dependencies = [
120
- "num-traits",
121
- ]
122
-
123
- [[package]]
124
- name = "num-integer"
125
- version = "0.1.46"
126
- source = "registry+https://github.com/rust-lang/crates.io-index"
127
- checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
128
- dependencies = [
129
- "num-traits",
130
- ]
131
-
132
- [[package]]
133
- name = "num-traits"
134
- version = "0.2.19"
135
- source = "registry+https://github.com/rust-lang/crates.io-index"
136
- checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
137
- dependencies = [
138
- "autocfg",
139
- ]
140
-
141
- [[package]]
142
- name = "numpy"
143
- version = "0.22.1"
144
- source = "registry+https://github.com/rust-lang/crates.io-index"
145
- checksum = "edb929bc0da91a4d85ed6c0a84deaa53d411abfb387fc271124f91bf6b89f14e"
146
- dependencies = [
147
- "libc",
148
- "ndarray",
149
- "num-complex",
150
- "num-integer",
151
- "num-traits",
152
- "pyo3",
153
- "rustc-hash",
154
- ]
155
-
156
- [[package]]
157
- name = "once_cell"
158
- version = "1.21.4"
159
- source = "registry+https://github.com/rust-lang/crates.io-index"
160
- checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
161
-
162
- [[package]]
163
- name = "portable-atomic"
164
- version = "1.13.1"
165
- source = "registry+https://github.com/rust-lang/crates.io-index"
166
- checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
167
-
168
- [[package]]
169
- name = "portable-atomic-util"
170
- version = "0.2.6"
171
- source = "registry+https://github.com/rust-lang/crates.io-index"
172
- checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3"
173
- dependencies = [
174
- "portable-atomic",
175
- ]
176
-
177
- [[package]]
178
- name = "ppv-lite86"
179
- version = "0.2.21"
180
- source = "registry+https://github.com/rust-lang/crates.io-index"
181
- checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
182
- dependencies = [
183
- "zerocopy",
184
- ]
185
-
186
- [[package]]
187
- name = "proc-macro2"
188
- version = "1.0.106"
189
- source = "registry+https://github.com/rust-lang/crates.io-index"
190
- checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
191
- dependencies = [
192
- "unicode-ident",
193
- ]
194
-
195
- [[package]]
196
- name = "pyo3"
197
- version = "0.22.6"
198
- source = "registry+https://github.com/rust-lang/crates.io-index"
199
- checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
200
- dependencies = [
201
- "cfg-if",
202
- "indoc",
203
- "libc",
204
- "memoffset",
205
- "once_cell",
206
- "portable-atomic",
207
- "pyo3-build-config",
208
- "pyo3-ffi",
209
- "pyo3-macros",
210
- "unindent",
211
- ]
212
-
213
- [[package]]
214
- name = "pyo3-build-config"
215
- version = "0.22.6"
216
- source = "registry+https://github.com/rust-lang/crates.io-index"
217
- checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
218
- dependencies = [
219
- "once_cell",
220
- "target-lexicon",
221
- ]
222
-
223
- [[package]]
224
- name = "pyo3-ffi"
225
- version = "0.22.6"
226
- source = "registry+https://github.com/rust-lang/crates.io-index"
227
- checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
228
- dependencies = [
229
- "libc",
230
- "pyo3-build-config",
231
- ]
232
-
233
- [[package]]
234
- name = "pyo3-macros"
235
- version = "0.22.6"
236
- source = "registry+https://github.com/rust-lang/crates.io-index"
237
- checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
238
- dependencies = [
239
- "proc-macro2",
240
- "pyo3-macros-backend",
241
- "quote",
242
- "syn",
243
- ]
244
-
245
- [[package]]
246
- name = "pyo3-macros-backend"
247
- version = "0.22.6"
248
- source = "registry+https://github.com/rust-lang/crates.io-index"
249
- checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
250
- dependencies = [
251
- "heck",
252
- "proc-macro2",
253
- "pyo3-build-config",
254
- "quote",
255
- "syn",
256
- ]
257
-
258
- [[package]]
259
- name = "quote"
260
- version = "1.0.45"
261
- source = "registry+https://github.com/rust-lang/crates.io-index"
262
- checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
263
- dependencies = [
264
- "proc-macro2",
265
- ]
266
-
267
- [[package]]
268
- name = "rand"
269
- version = "0.8.5"
270
- source = "registry+https://github.com/rust-lang/crates.io-index"
271
- checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
272
- dependencies = [
273
- "libc",
274
- "rand_chacha",
275
- "rand_core",
276
- ]
277
-
278
- [[package]]
279
- name = "rand_chacha"
280
- version = "0.3.1"
281
- source = "registry+https://github.com/rust-lang/crates.io-index"
282
- checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
283
- dependencies = [
284
- "ppv-lite86",
285
- "rand_core",
286
- ]
287
-
288
- [[package]]
289
- name = "rand_core"
290
- version = "0.6.4"
291
- source = "registry+https://github.com/rust-lang/crates.io-index"
292
- checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
293
- dependencies = [
294
- "getrandom",
295
- ]
296
-
297
- [[package]]
298
- name = "rand_xoshiro"
299
- version = "0.6.0"
300
- source = "registry+https://github.com/rust-lang/crates.io-index"
301
- checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
302
- dependencies = [
303
- "rand_core",
304
- ]
305
-
306
- [[package]]
307
- name = "rawpointer"
308
- version = "0.2.1"
309
- source = "registry+https://github.com/rust-lang/crates.io-index"
310
- checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
311
-
312
- [[package]]
313
- name = "rustc-hash"
314
- version = "1.1.0"
315
- source = "registry+https://github.com/rust-lang/crates.io-index"
316
- checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
317
-
318
- [[package]]
319
- name = "rustversion"
320
- version = "1.0.22"
321
- source = "registry+https://github.com/rust-lang/crates.io-index"
322
- checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
323
-
324
- [[package]]
325
- name = "syn"
326
- version = "2.0.117"
327
- source = "registry+https://github.com/rust-lang/crates.io-index"
328
- checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
329
- dependencies = [
330
- "proc-macro2",
331
- "quote",
332
- "unicode-ident",
333
- ]
334
-
335
- [[package]]
336
- name = "target-lexicon"
337
- version = "0.12.16"
338
- source = "registry+https://github.com/rust-lang/crates.io-index"
339
- checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
340
-
341
- [[package]]
342
- name = "unicode-ident"
343
- version = "1.0.24"
344
- source = "registry+https://github.com/rust-lang/crates.io-index"
345
- checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
346
-
347
- [[package]]
348
- name = "unindent"
349
- version = "0.2.4"
350
- source = "registry+https://github.com/rust-lang/crates.io-index"
351
- checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
352
-
353
- [[package]]
354
- name = "wasi"
355
- version = "0.11.1+wasi-snapshot-preview1"
356
- source = "registry+https://github.com/rust-lang/crates.io-index"
357
- checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
358
-
359
- [[package]]
360
- name = "windows-link"
361
- version = "0.2.1"
362
- source = "registry+https://github.com/rust-lang/crates.io-index"
363
- checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
364
-
365
- [[package]]
366
- name = "zerocopy"
367
- version = "0.8.48"
368
- source = "registry+https://github.com/rust-lang/crates.io-index"
369
- checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9"
370
- dependencies = [
371
- "zerocopy-derive",
372
- ]
373
-
374
- [[package]]
375
- name = "zerocopy-derive"
376
- version = "0.8.48"
377
- source = "registry+https://github.com/rust-lang/crates.io-index"
378
- checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4"
379
- dependencies = [
380
- "proc-macro2",
381
- "quote",
382
- "syn",
383
- ]
 
1
+ # This file is automatically @generated by Cargo.
2
+ # It is not intended for manual editing.
3
+ version = 4
4
+
5
+ [[package]]
6
+ name = "autocfg"
7
+ version = "1.5.0"
8
+ source = "registry+https://github.com/rust-lang/crates.io-index"
9
+ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
10
+
11
+ [[package]]
12
+ name = "cfg-if"
13
+ version = "1.0.4"
14
+ source = "registry+https://github.com/rust-lang/crates.io-index"
15
+ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
16
+
17
+ [[package]]
18
+ name = "cudarc"
19
+ version = "0.12.1"
20
+ source = "registry+https://github.com/rust-lang/crates.io-index"
21
+ checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384"
22
+ dependencies = [
23
+ "libloading",
24
+ ]
25
+
26
+ [[package]]
27
+ name = "getrandom"
28
+ version = "0.2.17"
29
+ source = "registry+https://github.com/rust-lang/crates.io-index"
30
+ checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0"
31
+ dependencies = [
32
+ "cfg-if",
33
+ "libc",
34
+ "wasi",
35
+ ]
36
+
37
+ [[package]]
38
+ name = "heck"
39
+ version = "0.5.0"
40
+ source = "registry+https://github.com/rust-lang/crates.io-index"
41
+ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
42
+
43
+ [[package]]
44
+ name = "htm_rust"
45
+ version = "0.1.0"
46
+ dependencies = [
47
+ "cudarc",
48
+ "ndarray",
49
+ "numpy",
50
+ "pyo3",
51
+ "rand",
52
+ "rand_xoshiro",
53
+ ]
54
+
55
+ [[package]]
56
+ name = "indoc"
57
+ version = "2.0.7"
58
+ source = "registry+https://github.com/rust-lang/crates.io-index"
59
+ checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
60
+ dependencies = [
61
+ "rustversion",
62
+ ]
63
+
64
+ [[package]]
65
+ name = "libc"
66
+ version = "0.2.185"
67
+ source = "registry+https://github.com/rust-lang/crates.io-index"
68
+ checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f"
69
+
70
+ [[package]]
71
+ name = "libloading"
72
+ version = "0.8.9"
73
+ source = "registry+https://github.com/rust-lang/crates.io-index"
74
+ checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
75
+ dependencies = [
76
+ "cfg-if",
77
+ "windows-link",
78
+ ]
79
+
80
+ [[package]]
81
+ name = "matrixmultiply"
82
+ version = "0.3.10"
83
+ source = "registry+https://github.com/rust-lang/crates.io-index"
84
+ checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
85
+ dependencies = [
86
+ "autocfg",
87
+ "rawpointer",
88
+ ]
89
+
90
+ [[package]]
91
+ name = "memoffset"
92
+ version = "0.9.1"
93
+ source = "registry+https://github.com/rust-lang/crates.io-index"
94
+ checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
95
+ dependencies = [
96
+ "autocfg",
97
+ ]
98
+
99
+ [[package]]
100
+ name = "ndarray"
101
+ version = "0.16.1"
102
+ source = "registry+https://github.com/rust-lang/crates.io-index"
103
+ checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
104
+ dependencies = [
105
+ "matrixmultiply",
106
+ "num-complex",
107
+ "num-integer",
108
+ "num-traits",
109
+ "portable-atomic",
110
+ "portable-atomic-util",
111
+ "rawpointer",
112
+ ]
113
+
114
+ [[package]]
115
+ name = "num-complex"
116
+ version = "0.4.6"
117
+ source = "registry+https://github.com/rust-lang/crates.io-index"
118
+ checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
119
+ dependencies = [
120
+ "num-traits",
121
+ ]
122
+
123
+ [[package]]
124
+ name = "num-integer"
125
+ version = "0.1.46"
126
+ source = "registry+https://github.com/rust-lang/crates.io-index"
127
+ checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
128
+ dependencies = [
129
+ "num-traits",
130
+ ]
131
+
132
+ [[package]]
133
+ name = "num-traits"
134
+ version = "0.2.19"
135
+ source = "registry+https://github.com/rust-lang/crates.io-index"
136
+ checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
137
+ dependencies = [
138
+ "autocfg",
139
+ ]
140
+
141
+ [[package]]
142
+ name = "numpy"
143
+ version = "0.22.1"
144
+ source = "registry+https://github.com/rust-lang/crates.io-index"
145
+ checksum = "edb929bc0da91a4d85ed6c0a84deaa53d411abfb387fc271124f91bf6b89f14e"
146
+ dependencies = [
147
+ "libc",
148
+ "ndarray",
149
+ "num-complex",
150
+ "num-integer",
151
+ "num-traits",
152
+ "pyo3",
153
+ "rustc-hash",
154
+ ]
155
+
156
+ [[package]]
157
+ name = "once_cell"
158
+ version = "1.21.4"
159
+ source = "registry+https://github.com/rust-lang/crates.io-index"
160
+ checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
161
+
162
+ [[package]]
163
+ name = "portable-atomic"
164
+ version = "1.13.1"
165
+ source = "registry+https://github.com/rust-lang/crates.io-index"
166
+ checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
167
+
168
+ [[package]]
169
+ name = "portable-atomic-util"
170
+ version = "0.2.6"
171
+ source = "registry+https://github.com/rust-lang/crates.io-index"
172
+ checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3"
173
+ dependencies = [
174
+ "portable-atomic",
175
+ ]
176
+
177
+ [[package]]
178
+ name = "ppv-lite86"
179
+ version = "0.2.21"
180
+ source = "registry+https://github.com/rust-lang/crates.io-index"
181
+ checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
182
+ dependencies = [
183
+ "zerocopy",
184
+ ]
185
+
186
+ [[package]]
187
+ name = "proc-macro2"
188
+ version = "1.0.106"
189
+ source = "registry+https://github.com/rust-lang/crates.io-index"
190
+ checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
191
+ dependencies = [
192
+ "unicode-ident",
193
+ ]
194
+
195
+ [[package]]
196
+ name = "pyo3"
197
+ version = "0.22.6"
198
+ source = "registry+https://github.com/rust-lang/crates.io-index"
199
+ checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
200
+ dependencies = [
201
+ "cfg-if",
202
+ "indoc",
203
+ "libc",
204
+ "memoffset",
205
+ "once_cell",
206
+ "portable-atomic",
207
+ "pyo3-build-config",
208
+ "pyo3-ffi",
209
+ "pyo3-macros",
210
+ "unindent",
211
+ ]
212
+
213
+ [[package]]
214
+ name = "pyo3-build-config"
215
+ version = "0.22.6"
216
+ source = "registry+https://github.com/rust-lang/crates.io-index"
217
+ checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
218
+ dependencies = [
219
+ "once_cell",
220
+ "target-lexicon",
221
+ ]
222
+
223
+ [[package]]
224
+ name = "pyo3-ffi"
225
+ version = "0.22.6"
226
+ source = "registry+https://github.com/rust-lang/crates.io-index"
227
+ checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
228
+ dependencies = [
229
+ "libc",
230
+ "pyo3-build-config",
231
+ ]
232
+
233
+ [[package]]
234
+ name = "pyo3-macros"
235
+ version = "0.22.6"
236
+ source = "registry+https://github.com/rust-lang/crates.io-index"
237
+ checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
238
+ dependencies = [
239
+ "proc-macro2",
240
+ "pyo3-macros-backend",
241
+ "quote",
242
+ "syn",
243
+ ]
244
+
245
+ [[package]]
246
+ name = "pyo3-macros-backend"
247
+ version = "0.22.6"
248
+ source = "registry+https://github.com/rust-lang/crates.io-index"
249
+ checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
250
+ dependencies = [
251
+ "heck",
252
+ "proc-macro2",
253
+ "pyo3-build-config",
254
+ "quote",
255
+ "syn",
256
+ ]
257
+
258
+ [[package]]
259
+ name = "quote"
260
+ version = "1.0.45"
261
+ source = "registry+https://github.com/rust-lang/crates.io-index"
262
+ checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
263
+ dependencies = [
264
+ "proc-macro2",
265
+ ]
266
+
267
+ [[package]]
268
+ name = "rand"
269
+ version = "0.8.5"
270
+ source = "registry+https://github.com/rust-lang/crates.io-index"
271
+ checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
272
+ dependencies = [
273
+ "libc",
274
+ "rand_chacha",
275
+ "rand_core",
276
+ ]
277
+
278
+ [[package]]
279
+ name = "rand_chacha"
280
+ version = "0.3.1"
281
+ source = "registry+https://github.com/rust-lang/crates.io-index"
282
+ checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
283
+ dependencies = [
284
+ "ppv-lite86",
285
+ "rand_core",
286
+ ]
287
+
288
+ [[package]]
289
+ name = "rand_core"
290
+ version = "0.6.4"
291
+ source = "registry+https://github.com/rust-lang/crates.io-index"
292
+ checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
293
+ dependencies = [
294
+ "getrandom",
295
+ ]
296
+
297
+ [[package]]
298
+ name = "rand_xoshiro"
299
+ version = "0.6.0"
300
+ source = "registry+https://github.com/rust-lang/crates.io-index"
301
+ checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
302
+ dependencies = [
303
+ "rand_core",
304
+ ]
305
+
306
+ [[package]]
307
+ name = "rawpointer"
308
+ version = "0.2.1"
309
+ source = "registry+https://github.com/rust-lang/crates.io-index"
310
+ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
311
+
312
+ [[package]]
313
+ name = "rustc-hash"
314
+ version = "1.1.0"
315
+ source = "registry+https://github.com/rust-lang/crates.io-index"
316
+ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
317
+
318
+ [[package]]
319
+ name = "rustversion"
320
+ version = "1.0.22"
321
+ source = "registry+https://github.com/rust-lang/crates.io-index"
322
+ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
323
+
324
+ [[package]]
325
+ name = "syn"
326
+ version = "2.0.117"
327
+ source = "registry+https://github.com/rust-lang/crates.io-index"
328
+ checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
329
+ dependencies = [
330
+ "proc-macro2",
331
+ "quote",
332
+ "unicode-ident",
333
+ ]
334
+
335
+ [[package]]
336
+ name = "target-lexicon"
337
+ version = "0.12.16"
338
+ source = "registry+https://github.com/rust-lang/crates.io-index"
339
+ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
340
+
341
+ [[package]]
342
+ name = "unicode-ident"
343
+ version = "1.0.24"
344
+ source = "registry+https://github.com/rust-lang/crates.io-index"
345
+ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
346
+
347
+ [[package]]
348
+ name = "unindent"
349
+ version = "0.2.4"
350
+ source = "registry+https://github.com/rust-lang/crates.io-index"
351
+ checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
352
+
353
+ [[package]]
354
+ name = "wasi"
355
+ version = "0.11.1+wasi-snapshot-preview1"
356
+ source = "registry+https://github.com/rust-lang/crates.io-index"
357
+ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
358
+
359
+ [[package]]
360
+ name = "windows-link"
361
+ version = "0.2.1"
362
+ source = "registry+https://github.com/rust-lang/crates.io-index"
363
+ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
364
+
365
+ [[package]]
366
+ name = "zerocopy"
367
+ version = "0.8.48"
368
+ source = "registry+https://github.com/rust-lang/crates.io-index"
369
+ checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9"
370
+ dependencies = [
371
+ "zerocopy-derive",
372
+ ]
373
+
374
+ [[package]]
375
+ name = "zerocopy-derive"
376
+ version = "0.8.48"
377
+ source = "registry+https://github.com/rust-lang/crates.io-index"
378
+ checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4"
379
+ dependencies = [
380
+ "proc-macro2",
381
+ "quote",
382
+ "syn",
383
+ ]
overlay/htm_rust/Cargo.toml CHANGED
@@ -1,37 +1,37 @@
1
- [package]
2
- name = "htm_rust"
3
- version = "0.1.0"
4
- edition = "2021"
5
- authors = ["Feather/HYDRA"]
6
- description = "Numenta BAMI-spec Hierarchical Temporal Memory (Spatial Pooler + Temporal Memory) with pyo3 bindings"
7
- license = "MIT"
8
-
9
- [lib]
10
- name = "htm_rust"
11
- crate-type = ["cdylib", "rlib"]
12
-
13
- [dependencies]
14
- pyo3 = { version = "0.22", features = ["extension-module"] }
15
- numpy = "0.22"
16
- ndarray = "0.16"
17
- rand = "0.8"
18
- rand_xoshiro = "0.6"
19
- # cudarc: CUDA Rust bindings with dynamic-loading (no link-time dep on libcuda).
20
- # Kernels are embedded as PTX and JIT-compiled at runtime.
21
- cudarc = { version = "0.12", default-features = false, features = ["dynamic-linking", "driver", "cuda-12010"], optional = true }
22
-
23
- [build-dependencies]
24
- # Only required when building with --features gpu. We shell to nvcc directly
25
- # so we don't need cc's cuda support (which drags in extra deps).
26
-
27
- [features]
28
- default = []
29
- # `gpu` adds the HTMRegionGPU class, compiles .cu kernels to PTX at build time,
30
- # and links cudarc. Without this feature the crate is pure-CPU and has no
31
- # CUDA dependency at build or run time.
32
- gpu = ["cudarc"]
33
-
34
- [profile.release]
35
- opt-level = 3
36
- lto = "thin"
37
- codegen-units = 1
 
1
+ [package]
2
+ name = "htm_rust"
3
+ version = "0.1.0"
4
+ edition = "2021"
5
+ authors = ["Feather/HYDRA"]
6
+ description = "Numenta BAMI-spec Hierarchical Temporal Memory (Spatial Pooler + Temporal Memory) with pyo3 bindings"
7
+ license = "MIT"
8
+
9
+ [lib]
10
+ name = "htm_rust"
11
+ crate-type = ["cdylib", "rlib"]
12
+
13
+ [dependencies]
14
+ pyo3 = { version = "0.22", features = ["extension-module"] }
15
+ numpy = "0.22"
16
+ ndarray = "0.16"
17
+ rand = "0.8"
18
+ rand_xoshiro = "0.6"
19
+ # cudarc: CUDA Rust bindings with dynamic-loading (no link-time dep on libcuda).
20
+ # Kernels are embedded as PTX and JIT-compiled at runtime.
21
+ cudarc = { version = "0.12", default-features = false, features = ["dynamic-linking", "driver", "cuda-12010"], optional = true }
22
+
23
+ [build-dependencies]
24
+ # Only required when building with --features gpu. We shell to nvcc directly
25
+ # so we don't need cc's cuda support (which drags in extra deps).
26
+
27
+ [features]
28
+ default = []
29
+ # `gpu` adds the HTMRegionGPU class, compiles .cu kernels to PTX at build time,
30
+ # and links cudarc. Without this feature the crate is pure-CPU and has no
31
+ # CUDA dependency at build or run time.
32
+ gpu = ["cudarc"]
33
+
34
+ [profile.release]
35
+ opt-level = 3
36
+ lto = "thin"
37
+ codegen-units = 1
overlay/htm_rust/build.rs CHANGED
@@ -1,160 +1,168 @@
1
- //! Build script: compiles `.cu` kernel files to PTX when the `gpu` feature
2
- //! is enabled. PTX files are embedded into the final Rust binary via
3
- //! `include_str!` / `OUT_DIR` constants and JIT-loaded at runtime by cudarc.
4
- //!
5
- //! No-op when `gpu` feature is off — CPU-only builds have zero CUDA
6
- //! toolchain dependency.
7
- //!
8
- //! nvcc lookup order:
9
- //! 1. $NVCC env var
10
- //! 2. `nvcc` on PATH
11
- //! 3. `/usr/local/cuda-12.1/bin/nvcc`
12
- //! 4. `/usr/local/cuda/bin/nvcc`
13
- //!
14
- //! Target: sm_90a (Hopper, H200 enables cluster::sync, TMA, wgmma). Override with $HTM_CUDA_ARCH.
15
-
16
- use std::env;
17
- use std::path::PathBuf;
18
- use std::process::Command;
19
-
20
- fn main() {
21
- // Re-run whenever we edit the build script or any kernel source.
22
- println!("cargo:rerun-if-changed=build.rs");
23
-
24
- let gpu = env::var_os("CARGO_FEATURE_GPU").is_some();
25
- if !gpu {
26
- return;
27
- }
28
-
29
- let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
30
- let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_90a".into());
31
-
32
- // Base kernels — compile for any sm_80+ GPU. Each .cu file → one .ptx file.
33
- let base_kernels: &[&str] = &[
34
- "sp_overlap",
35
- "sp_topk",
36
- "sp_learn",
37
- "sp_duty",
38
- "sp_boost_fused",
39
- "tm_predict",
40
- "tm_activate",
41
- "tm_learn",
42
- "tm_punish",
43
- "tm_grow",
44
- "tm_anomaly",
45
- "tm_reset",
46
- ];
47
-
48
- // htm_fused_step now compiles for ALL architectures (sm_80+).
49
- // On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
50
- // On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
51
- // with grid.sync() for cross-block synchronization (cooperative launch).
52
- let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect();
53
-
54
- let kernels_dir = PathBuf::from("src/gpu/kernels");
55
- for k in &kernels {
56
- let src = kernels_dir.join(format!("{k}.cu"));
57
- println!("cargo:rerun-if-changed={}", src.display());
58
- }
59
-
60
-
61
- let nvcc = find_nvcc();
62
- println!("cargo:warning=htm_rust: nvcc = {nvcc}");
63
- println!("cargo:warning=htm_rust: target arch = {arch}");
64
-
65
- // Prefer gcc-12 if present (CUDA 12.1 doesn't support gcc-13+ headers).
66
- let host_compiler = env::var("HTM_CUDA_CCBIN")
67
- .ok()
68
- .or_else(|| {
69
- for cand in ["/usr/bin/gcc-12", "/usr/bin/gcc-11"] {
70
- if std::path::Path::new(cand).exists() {
71
- return Some(cand.to_string());
72
- }
73
- }
74
- None
75
- });
76
-
77
- // Optionally patch the emitted PTX `.version` header down to match an
78
- // older driver. Useful when the system driver (e.g. on WSL2) is older
79
- // than the nvcc toolchain. Set HTM_PTX_VERSION to e.g. "7.8" or "8.0".
80
- let ptx_version_override = env::var("HTM_PTX_VERSION").ok();
81
-
82
- for k in kernels {
83
- let src = kernels_dir.join(format!("{k}.cu"));
84
- let ptx = out_dir.join(format!("{k}.ptx"));
85
- if !src.exists() {
86
- panic!("missing kernel source: {}", src.display());
87
- }
88
- let mut cmd = Command::new(&nvcc);
89
- // Note: `--use_fast_math` breaks bit-parity with host `expf`, which
90
- // in turn flips boost tie-breaks in SP learning. We accept the tiny
91
- // perf loss for correctness; the hot overlap kernel has no transcendentals.
92
- cmd.args([
93
- "--ptx",
94
- "-O3",
95
- "-rdc=true",
96
- "-arch",
97
- &arch,
98
- ]);
99
- if let Some(cc) = &host_compiler {
100
- cmd.args(["-ccbin", cc]);
101
- }
102
- cmd.arg("-o").arg(&ptx).arg(&src);
103
- let status = cmd
104
- .status()
105
- .unwrap_or_else(|e| panic!("failed to spawn nvcc: {e}"));
106
- if !status.success() {
107
- panic!("nvcc failed for {}", src.display());
108
- }
109
-
110
- if let Some(ver) = &ptx_version_override {
111
- // Read, patch, write.
112
- let text = std::fs::read_to_string(&ptx)
113
- .unwrap_or_else(|e| panic!("read {} failed: {e}", ptx.display()));
114
- // Match `.version X.Y` where X and Y are digits. Replace whole line.
115
- let patched: String = text
116
- .lines()
117
- .map(|line| {
118
- let t = line.trim_start();
119
- if t.starts_with(".version ") {
120
- format!(".version {ver}")
121
- } else {
122
- line.to_string()
123
- }
124
- })
125
- .collect::<Vec<_>>()
126
- .join("\n");
127
- std::fs::write(&ptx, patched)
128
- .unwrap_or_else(|e| panic!("write {} failed: {e}", ptx.display()));
129
- }
130
- }
131
-
132
- // Export OUT_DIR for include_str! in Rust.
133
- println!(
134
- "cargo:rustc-env=HTM_GPU_PTX_DIR={}",
135
- out_dir.display()
136
- );
137
- }
138
-
139
- fn find_nvcc() -> String {
140
- if let Ok(n) = env::var("NVCC") {
141
- return n;
142
- }
143
- // Try PATH.
144
- if Command::new("nvcc").arg("--version").output().is_ok() {
145
- return "nvcc".into();
146
- }
147
- for cand in [
148
- "/usr/local/cuda-12.1/bin/nvcc",
149
- "/usr/local/cuda/bin/nvcc",
150
- "/usr/local/cuda-12/bin/nvcc",
151
- ] {
152
- if std::path::Path::new(cand).exists() {
153
- return cand.into();
154
- }
155
- }
156
- panic!(
157
- "nvcc not found. Set $NVCC or install CUDA toolkit. \
158
- Tried PATH, /usr/local/cuda-12.1, /usr/local/cuda."
159
- );
160
- }
 
 
 
 
 
 
 
 
 
1
+ //! Build script: compiles `.cu` kernel files to PTX when the `gpu` feature
2
+ //! is enabled. PTX files are embedded into the final Rust binary via
3
+ //! `include_str!` / `OUT_DIR` constants and JIT-loaded at runtime by cudarc.
4
+ //!
5
+ //! No-op when `gpu` feature is off — CPU-only builds have zero CUDA
6
+ //! toolchain dependency.
7
+ //!
8
+ //! nvcc lookup order:
9
+ //! 1. $NVCC env var
10
+ //! 2. `nvcc` on PATH
11
+ //! 3. `/usr/local/cuda-12.1/bin/nvcc`
12
+ //! 4. `/usr/local/cuda/bin/nvcc`
13
+ //!
14
+ //! Default target: sm_86 (Ampere A10G / RTX 30xx). Override with $HTM_CUDA_ARCH (e.g. sm_90a for H200).
15
+
16
+ use std::env;
17
+ use std::path::PathBuf;
18
+ use std::process::Command;
19
+
20
+ fn main() {
21
+ // Re-run whenever we edit the build script or any kernel source.
22
+ println!("cargo:rerun-if-changed=build.rs");
23
+
24
+ let gpu = env::var_os("CARGO_FEATURE_GPU").is_some();
25
+ if !gpu {
26
+ return;
27
+ }
28
+
29
+ let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
30
+ let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into());
31
+
32
+ // Base kernels — compile for any sm_80+ GPU. Each .cu file → one .ptx file.
33
+ let base_kernels: &[&str] = &[
34
+ "sp_overlap",
35
+ "sp_topk",
36
+ "sp_learn",
37
+ "sp_duty",
38
+ "sp_boost_fused",
39
+ "tm_predict",
40
+ "tm_activate",
41
+ "tm_learn",
42
+ "tm_punish",
43
+ "tm_grow",
44
+ "tm_anomaly",
45
+ "tm_reset",
46
+ ];
47
+
48
+ // htm_fused_step now compiles for ALL architectures (sm_80+).
49
+ // On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
50
+ // On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
51
+ // with grid.sync() for cross-block synchronization (cooperative launch).
52
+ let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect();
53
+
54
+ let kernels_dir = PathBuf::from("src/gpu/kernels");
55
+ for k in &kernels {
56
+ let src = kernels_dir.join(format!("{k}.cu"));
57
+ println!("cargo:rerun-if-changed={}", src.display());
58
+ }
59
+
60
+
61
+ let nvcc = find_nvcc();
62
+ println!("cargo:warning=htm_rust: nvcc = {nvcc}");
63
+ println!("cargo:warning=htm_rust: target arch = {arch}");
64
+
65
+ // Prefer gcc-12 if present (CUDA 12.1 doesn't support gcc-13+ headers).
66
+ let host_compiler = env::var("HTM_CUDA_CCBIN")
67
+ .ok()
68
+ .or_else(|| {
69
+ for cand in ["/usr/bin/gcc-12", "/usr/bin/gcc-11"] {
70
+ if std::path::Path::new(cand).exists() {
71
+ return Some(cand.to_string());
72
+ }
73
+ }
74
+ None
75
+ });
76
+
77
+ // Optionally patch the emitted PTX `.version` header down to match an
78
+ // older driver. Useful when the system driver (e.g. on WSL2) is older
79
+ // than the nvcc toolchain. Set HTM_PTX_VERSION to e.g. "7.8" or "8.0".
80
+ let ptx_version_override = env::var("HTM_PTX_VERSION").ok();
81
+
82
+ for k in kernels {
83
+ let src = kernels_dir.join(format!("{k}.cu"));
84
+ let ptx = out_dir.join(format!("{k}.ptx"));
85
+ if !src.exists() {
86
+ panic!("missing kernel source: {}", src.display());
87
+ }
88
+ let mut cmd = Command::new(&nvcc);
89
+ // Note: `--use_fast_math` breaks bit-parity with host `expf`, which
90
+ // in turn flips boost tie-breaks in SP learning. We accept the tiny
91
+ // perf loss for correctness; the hot overlap kernel has no transcendentals.
92
+ cmd.args([
93
+ "--ptx",
94
+ "-O3",
95
+ "-rdc=true",
96
+ "-arch",
97
+ &arch,
98
+ ]);
99
+ // `cooperative_groups::this_cluster()` is not declared for Ampere
100
+ // device compiles in CUDA 12.x, even if guarded by __CUDA_ARCH__ in
101
+ // some nvcc front-end phases. Define an explicit build-time kill
102
+ // switch for all non-Hopper targets so sm_86/A10G only sees the
103
+ // cooperative-grid path.
104
+ if !arch.starts_with("sm_90") {
105
+ cmd.arg("-DHTM_DISABLE_CLUSTER=1");
106
+ }
107
+ if let Some(cc) = &host_compiler {
108
+ cmd.args(["-ccbin", cc]);
109
+ }
110
+ cmd.arg("-o").arg(&ptx).arg(&src);
111
+ let status = cmd
112
+ .status()
113
+ .unwrap_or_else(|e| panic!("failed to spawn nvcc: {e}"));
114
+ if !status.success() {
115
+ panic!("nvcc failed for {}", src.display());
116
+ }
117
+
118
+ if let Some(ver) = &ptx_version_override {
119
+ // Read, patch, write.
120
+ let text = std::fs::read_to_string(&ptx)
121
+ .unwrap_or_else(|e| panic!("read {} failed: {e}", ptx.display()));
122
+ // Match `.version X.Y` where X and Y are digits. Replace whole line.
123
+ let patched: String = text
124
+ .lines()
125
+ .map(|line| {
126
+ let t = line.trim_start();
127
+ if t.starts_with(".version ") {
128
+ format!(".version {ver}")
129
+ } else {
130
+ line.to_string()
131
+ }
132
+ })
133
+ .collect::<Vec<_>>()
134
+ .join("\n");
135
+ std::fs::write(&ptx, patched)
136
+ .unwrap_or_else(|e| panic!("write {} failed: {e}", ptx.display()));
137
+ }
138
+ }
139
+
140
+ // Export OUT_DIR for include_str! in Rust.
141
+ println!(
142
+ "cargo:rustc-env=HTM_GPU_PTX_DIR={}",
143
+ out_dir.display()
144
+ );
145
+ }
146
+
147
+ fn find_nvcc() -> String {
148
+ if let Ok(n) = env::var("NVCC") {
149
+ return n;
150
+ }
151
+ // Try PATH.
152
+ if Command::new("nvcc").arg("--version").output().is_ok() {
153
+ return "nvcc".into();
154
+ }
155
+ for cand in [
156
+ "/usr/local/cuda-12.1/bin/nvcc",
157
+ "/usr/local/cuda/bin/nvcc",
158
+ "/usr/local/cuda-12/bin/nvcc",
159
+ ] {
160
+ if std::path::Path::new(cand).exists() {
161
+ return cand.into();
162
+ }
163
+ }
164
+ panic!(
165
+ "nvcc not found. Set $NVCC or install CUDA toolkit. \
166
+ Tried PATH, /usr/local/cuda-12.1, /usr/local/cuda."
167
+ );
168
+ }
overlay/htm_rust/pyproject.toml CHANGED
@@ -1,17 +1,17 @@
1
- [build-system]
2
- requires = ["maturin>=1.4,<2.0"]
3
- build-backend = "maturin"
4
-
5
- [project]
6
- name = "htm_rust"
7
- version = "0.1.0"
8
- description = "Numenta BAMI-spec HTM (Spatial Pooler + Temporal Memory) in Rust with pyo3 bindings"
9
- requires-python = ">=3.11"
10
- classifiers = [
11
- "Programming Language :: Rust",
12
- "Programming Language :: Python :: Implementation :: CPython",
13
- ]
14
-
15
- [tool.maturin]
16
- features = ["pyo3/extension-module"]
17
- module-name = "htm_rust"
 
1
+ [build-system]
2
+ requires = ["maturin>=1.4,<2.0"]
3
+ build-backend = "maturin"
4
+
5
+ [project]
6
+ name = "htm_rust"
7
+ version = "0.1.0"
8
+ description = "Numenta BAMI-spec HTM (Spatial Pooler + Temporal Memory) in Rust with pyo3 bindings"
9
+ requires-python = ">=3.11"
10
+ classifiers = [
11
+ "Programming Language :: Rust",
12
+ "Programming Language :: Python :: Implementation :: CPython",
13
+ ]
14
+
15
+ [tool.maturin]
16
+ features = ["pyo3/extension-module"]
17
+ module-name = "htm_rust"
overlay/htm_rust/src/gpu/fused.rs CHANGED
@@ -1,663 +1,702 @@
1
- //! Fused HTM megakernel launcher.
2
- //!
3
- //! Collapses the 12-kernel per-timestep pipeline (and the outer T-loop) into
4
- //! a single kernel launch per forward. See `kernels/htm_fused_step.cu` for
5
- //! the kernel design and the cross-block coherence strategy (grid barrier
6
- //! via device counter with all blocks concurrently resident).
7
- //!
8
- //! Launch invariant: `grid_dim.x <= concurrent-block capacity`. Host code
9
- //! probes the device SM count at construction and caps grid_dim.x
10
- //! accordingly — otherwise the grid barrier deadlocks.
11
- //!
12
- //! Semantic change from the top-K pipeline: activation is per-column
13
- //! threshold-based (local lateral inhibition) instead of global top-K.
14
- //! A per-column `inhibition_threshold` is tracked and EMA-steered to hit
15
- //! the sparsity target. This is a real architectural change and is
16
- //! documented in `docs/GPU_HTM.md`.
17
-
18
- #![cfg(feature = "gpu")]
19
-
20
- use std::ffi::CString;
21
- use std::sync::Arc;
22
-
23
- use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError,
24
- LaunchConfig};
25
- use cudarc::nvrtc::Ptx;
26
-
27
- use super::sp_gpu::SpatialPoolerGpu;
28
- use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
29
-
30
- const PTX_HTM_FUSED: &str =
31
- include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
32
-
33
- /// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
34
- ///
35
- /// NOTE: `barrier_counters` is kept as an ABI-compat dummy (always 0). The
36
- /// C-side `FusedPtrs` still has the field at the same byte offset; removing
37
- /// it here would shift all subsequent fields and break the layout. Worker A
38
- /// will eventually delete the field from both sides once the kernel is
39
- /// updated; until then we zero it.
40
- #[repr(C)]
41
- #[derive(Clone, Copy)]
42
- pub struct FusedPtrs {
43
- pub syn_bit: u64,
44
- pub syn_perm: u64,
45
- pub boost: u64,
46
- pub active_duty: u64,
47
- pub inhibition_threshold: u64,
48
- pub seg_cell_id: u64,
49
- pub seg_syn_count: u64,
50
- pub syn_presyn: u64,
51
- pub tm_syn_perm: u64,
52
- pub cell_seg_count: u64,
53
- pub cell_active_a: u64,
54
- pub cell_active_b: u64,
55
- pub cell_winner_a: u64,
56
- pub cell_winner_b: u64,
57
- pub inputs: u64,
58
- pub cols_out: u64,
59
- pub anom_out: u64,
60
- /// ABI-compat dummy — always 0. No device memory is allocated for this
61
- /// field; the cluster barrier replaces the old software DLB barrier.
62
- pub barrier_counters: u64,
63
- pub step_scratch: u64,
64
- }
65
-
66
- unsafe impl DeviceRepr for FusedPtrs {}
67
-
68
- /// Launch-time config — matches C-side `FusedConfig` 1:1.
69
- #[repr(C)]
70
- #[derive(Clone, Copy)]
71
- pub struct FusedConfig {
72
- pub input_bits: u32,
73
- pub n_columns: u32,
74
- pub synapses_per_col: u32,
75
- pub conn_thr: f32,
76
- pub sp_inc: f32,
77
- pub sp_dec: f32,
78
- pub sparsity_target: f32,
79
- pub duty_alpha: f32,
80
- pub thr_adapt_rate: f32,
81
- pub cells_per_column: u32,
82
- pub n_cells: u32,
83
- pub bits_words: u32,
84
- pub max_segments_per_cell: u32,
85
- pub synapses_per_segment: u32,
86
- pub activation_threshold: u32,
87
- pub learning_threshold: u32,
88
- pub max_new_synapses: u32,
89
- pub conn_thr_i16: i32,
90
- pub perm_inc_i16: i32,
91
- pub perm_dec_i16: i32,
92
- pub predicted_seg_dec_i16: i32,
93
- pub initial_perm_i16: i32,
94
- pub t: u32,
95
- pub learn: u32,
96
- pub iter_seed: u32,
97
- pub cooperative_grid_sync: u32,
98
- }
99
-
100
- unsafe impl DeviceRepr for FusedConfig {}
101
-
102
- /// Cluster launch parameters probed at construction time.
103
- #[derive(Clone, Copy, Debug, PartialEq, Eq)]
104
- pub(crate) struct ClusterInfo {
105
- /// Maximum cluster size supported by this device (0 = cluster unsupported).
106
- pub max_cluster_size: u32,
107
- }
108
-
109
- // There is only ONE launch mode: non-cooperative launch with Hopper Thread
110
- // Block Cluster attribute (`CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`). The old
111
- // software DLB barrier and the cooperative-launch path are both removed.
112
- // Cluster barriers replace both.
113
- #[derive(Clone, Copy, Debug, PartialEq, Eq)]
114
- pub(crate) struct FusedLaunchPlan {
115
- pub grid_dim_x: u32,
116
- pub block_dim_x: u32,
117
- pub cooperative_grid_limit: u32,
118
- pub sm_count: u32,
119
- }
120
-
121
- fn fused_grid_cap_override() -> Option<u32> {
122
- std::env::var("HTM_FUSED_GRID_CAP")
123
- .ok()
124
- .and_then(|s| s.parse::<u32>().ok())
125
- .map(|v| v.max(1))
126
- }
127
-
128
- pub(crate) fn plan_fused_launch(
129
- sm_count: u32,
130
- cooperative_supported: bool,
131
- cooperative_grid_limit: u32,
132
- grid_cap_override: Option<u32>,
133
- ) -> Result<FusedLaunchPlan, String> {
134
- let sm_count = sm_count.max(1);
135
- // 1024 threads/block exceeds the register file on Ampere (sm_86: 65536
136
- // regs/SM ÷ 1024 = 64 regs/thread; fused kernel needs ~80+). 256 gives
137
- // 256 regs/thread which is ample. Compensate with more blocks via
138
- // cooperative launch. On Hopper (228 KB smem, 255 regs/thread baseline),
139
- // 1024 works fine, but 256 is safe everywhere.
140
- let block_dim_x = 256u32;
141
-
142
- // Cluster launch path: cooperative launch is not required. Keep the probe
143
- // result for residency estimation only.
144
- if !cooperative_supported {
145
- eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
146
- }
147
-
148
- // Tested grid_cap: 4 blocks = 30ms (too serial), 16 blocks = 10.8ms (parallel wins).
149
- // Parallelism in SP overlap + TM predict stages outweighs grid.sync() cost.
150
- let default_grid_cap = 16u32;
151
- let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
152
- let resident_bound = if cooperative_grid_limit > 0 {
153
- cooperative_grid_limit.max(sm_count * 2)
154
- } else {
155
- sm_count * 2
156
- };
157
- Ok(FusedLaunchPlan {
158
- grid_dim_x: resident_bound.min(grid_cap).max(1),
159
- block_dim_x,
160
- cooperative_grid_limit: resident_bound,
161
- sm_count,
162
- })
163
- }
164
-
165
- pub(super) struct RawFusedKernel {
166
- module: sys::CUmodule,
167
- pub(super) function: sys::CUfunction,
168
- pub(super) function_batched: sys::CUfunction,
169
- }
170
-
171
- unsafe impl Send for RawFusedKernel {}
172
- unsafe impl Sync for RawFusedKernel {}
173
-
174
- impl Drop for RawFusedKernel {
175
- fn drop(&mut self) {
176
- unsafe {
177
- let _ = result::module::unload(self.module);
178
- }
179
- }
180
- }
181
-
182
- /// Owns fused-path-only device state:
183
- /// - per-column inhibition threshold (replaces global top-K)
184
- /// - ping-pong cell_active/cell_winner bitsets
185
- /// - step_scratch (n_active, n_unpred per timestep)
186
- /// - cluster launch capability info
187
- pub struct FusedState {
188
- dev: Arc<CudaDevice>,
189
- pub(super) raw_kernel: RawFusedKernel,
190
-
191
- pub inhibition_threshold: CudaSlice<f32>,
192
- pub cell_active_bits_a: CudaSlice<u32>,
193
- pub cell_active_bits_b: CudaSlice<u32>,
194
- pub cell_winner_bits_a: CudaSlice<u32>,
195
- pub cell_winner_bits_b: CudaSlice<u32>,
196
- pub step_scratch: CudaSlice<u32>, // length 6
197
-
198
- pub grid_dim_x: u32,
199
- pub block_dim_x: u32,
200
- pub cooperative_grid_limit: u32,
201
- pub iter_counter: u32,
202
-
203
- /// Hopper cluster launch capability (0 = unsupported).
204
- pub cluster_info: ClusterInfo,
205
-
206
- // Config mirror (read-only after init).
207
- #[allow(dead_code)]
208
- pub initial_threshold: f32,
209
- }
210
-
211
- impl FusedState {
212
- pub fn new(
213
- dev: Arc<CudaDevice>,
214
- n_columns: usize,
215
- cells_per_column: usize,
216
- initial_threshold: f32,
217
- ) -> Result<Self, DriverError> {
218
- let n_cells = n_columns * cells_per_column;
219
- assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
220
- let bits_words = n_cells / 32;
221
-
222
- let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
223
- let init_vec = vec![initial_threshold; n_columns];
224
- dev.htod_sync_copy_into(&init_vec, &mut inhibition_threshold)?;
225
-
226
- let cell_active_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
227
- let cell_active_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
228
- let cell_winner_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
229
- let cell_winner_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
230
- let step_scratch = dev.alloc_zeros::<u32>(6)?;
231
-
232
- unsafe {
233
- result::ctx::set_current(*dev.cu_primary_ctx())?;
234
- }
235
- if dev.get_func("htm_fused", "htm_fused_step").is_none() {
236
- dev.load_ptx(
237
- Ptx::from_src(PTX_HTM_FUSED),
238
- "htm_fused",
239
- &["htm_fused_step", "htm_fused_step_batched"],
240
- )?;
241
- }
242
- let ptx = CString::new(PTX_HTM_FUSED).expect("PTX contains no interior nul bytes");
243
- let module = unsafe { result::module::load_data(ptx.as_ptr().cast()) }?;
244
- let function = unsafe {
245
- result::module::get_function(module, CString::new("htm_fused_step").unwrap())
246
- }?;
247
- let function_batched = unsafe {
248
- result::module::get_function(module, CString::new("htm_fused_step_batched").unwrap())
249
- }?;
250
-
251
- // Cluster size 16 on Hopper is "non-portable" (> 8 requires opt-in).
252
- // Must set CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED=1 on
253
- // every launched kernel function, otherwise cuLaunchKernelEx rejects
254
- // the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
255
- unsafe {
256
- let attr = sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
257
- // Ignore errors: older CUDA may lack the attribute, in which case
258
- // only portable sizes (<= 8) work — plan_fused_launch caps at 8.
259
- let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
260
- let _ = sys::lib().cuFuncSetAttribute(function_batched, attr, 1);
261
- }
262
-
263
- // Probe SM count.
264
- let sm_count = match dev.attribute(
265
- cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
266
- ) {
267
- Ok(v) => v as u32,
268
- Err(_) => 16u32,
269
- };
270
-
271
- // T1: Probe Hopper cluster launch capability.
272
- let max_cluster_size = match dev.attribute(
273
- cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH,
274
- ) {
275
- Ok(v) if v > 0 => {
276
- // H200/sm_90a supports up to 16 blocks per cluster.
277
- // There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
278
- // Hopper maximum which is 16 (8 SMs × 2 blocks/SM = 16 blocks/cluster).
279
- 16u32
280
- }
281
- _ => 0u32,
282
- };
283
- eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
284
- let cluster_info = ClusterInfo { max_cluster_size };
285
-
286
- let cooperative_supported = matches!(
287
- dev.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
288
- Ok(v) if v > 0
289
- );
290
- let cooperative_grid_limit = if cooperative_supported {
291
- let blocks_per_sm = unsafe {
292
- result::occupancy::max_active_block_per_multiprocessor(function, 1024, 0)
293
- }
294
- .ok()
295
- .map(|v| v.max(0) as u32)
296
- .unwrap_or(0);
297
- sm_count.saturating_mul(blocks_per_sm)
298
- } else {
299
- 0
300
- };
301
- let launch_plan = plan_fused_launch(
302
- sm_count,
303
- cooperative_supported,
304
- cooperative_grid_limit,
305
- fused_grid_cap_override(),
306
- )
307
- .map_err(|msg| {
308
- // Surface as a CUDA-ish error so callers can propagate.
309
- eprintln!("[htm_rust] FATAL: {msg}");
310
- DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
311
- })?;
312
-
313
- eprintln!(
314
- "[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}",
315
- launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit,
316
- cluster_info.max_cluster_size,
317
- );
318
-
319
- Ok(Self {
320
- dev,
321
- raw_kernel: RawFusedKernel { module, function, function_batched },
322
- inhibition_threshold,
323
- cell_active_bits_a,
324
- cell_active_bits_b,
325
- cell_winner_bits_a,
326
- cell_winner_bits_b,
327
- step_scratch,
328
- grid_dim_x: launch_plan.grid_dim_x,
329
- block_dim_x: launch_plan.block_dim_x,
330
- cooperative_grid_limit: launch_plan.cooperative_grid_limit,
331
- iter_counter: 0,
332
- cluster_info,
333
- initial_threshold,
334
- })
335
- }
336
-
337
- /// Reset fused state. Called at region.reset().
338
- pub fn reset(&mut self) -> Result<(), DriverError> {
339
- self.dev.memset_zeros(&mut self.cell_active_bits_a)?;
340
- self.dev.memset_zeros(&mut self.cell_active_bits_b)?;
341
- self.dev.memset_zeros(&mut self.cell_winner_bits_a)?;
342
- self.dev.memset_zeros(&mut self.cell_winner_bits_b)?;
343
- self.dev.memset_zeros(&mut self.step_scratch)?;
344
- // Do NOT reset inhibition_threshold — it's learned state. A hard
345
- // reset of TM state should NOT forget the sparsity calibration.
346
- Ok(())
347
- }
348
- }
349
-
350
- /// Launch the fused megakernel. Processes all T timesteps in one kernel.
351
- ///
352
- /// Uses `cuLaunchKernelEx` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION=(16,1,1)`
353
- /// when the device supports cluster launch, otherwise falls back to a plain
354
- /// `launch_kernel`. For single-region launches, grid_dim_x <= 16 ensures the
355
- /// entire grid fits in one cluster.
356
- #[allow(clippy::too_many_arguments)]
357
- pub fn launch_fused(
358
- sp: &mut SpatialPoolerGpu,
359
- tm: &mut TemporalMemoryGpu,
360
- fused: &mut FusedState,
361
- inputs_flat: &CudaSlice<u8>,
362
- cols_out: &mut CudaSlice<u8>,
363
- anom_out: &mut CudaSlice<f32>,
364
- t: usize,
365
- input_bits: usize,
366
- learn: bool,
367
- ) -> Result<(), DriverError> {
368
- // Reset step_scratch before each launch (safe re-entry).
369
- sp.dev_ref().memset_zeros(&mut fused.step_scratch)?;
370
-
371
- fused.iter_counter = fused.iter_counter.wrapping_add(1);
372
-
373
- let cfg = FusedConfig {
374
- input_bits: input_bits as u32,
375
- n_columns: sp.n_columns_accessor() as u32,
376
- synapses_per_col: sp.synapses_per_col_accessor() as u32,
377
- conn_thr: sp.conn_thr_accessor(),
378
- sp_inc: sp.inc_accessor(),
379
- sp_dec: sp.dec_accessor(),
380
- sparsity_target: sp.sparsity_accessor(),
381
- duty_alpha: 1.0f32 / sp.duty_period_accessor().max(1.0),
382
- thr_adapt_rate: 0.001f32,
383
- cells_per_column: tm.cells_per_column as u32,
384
- n_cells: tm.n_cells as u32,
385
- bits_words: tm.bits_words as u32,
386
- max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
387
- synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
388
- activation_threshold: tm.activation_threshold,
389
- learning_threshold: tm.learning_threshold,
390
- max_new_synapses: tm.max_new_synapse_count,
391
- conn_thr_i16: tm.conn_thr_i16 as i32,
392
- perm_inc_i16: tm.perm_inc_i16 as i32,
393
- perm_dec_i16: tm.perm_dec_i16 as i32,
394
- predicted_seg_dec_i16: tm.predicted_seg_dec_i16 as i32,
395
- initial_perm_i16: tm.initial_perm_i16 as i32,
396
- t: t as u32,
397
- learn: if learn { 1 } else { 0 },
398
- iter_seed: fused.iter_counter,
399
- cooperative_grid_sync: 1,
400
- };
401
-
402
- let ptrs = FusedPtrs {
403
- syn_bit: *sp.syn_bit_accessor().device_ptr(),
404
- syn_perm: *sp.syn_perm_accessor().device_ptr(),
405
- boost: *sp.boost_accessor().device_ptr(),
406
- active_duty: *sp.active_duty_accessor().device_ptr(),
407
- inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
408
- seg_cell_id: *tm.seg_cell_id_accessor().device_ptr(),
409
- seg_syn_count: *tm.seg_syn_count_accessor().device_ptr(),
410
- syn_presyn: *tm.syn_presyn_accessor().device_ptr(),
411
- tm_syn_perm: *tm.syn_perm_accessor().device_ptr(),
412
- cell_seg_count: *tm.cell_seg_count_accessor().device_ptr(),
413
- cell_active_a: *fused.cell_active_bits_a.device_ptr(),
414
- cell_active_b: *fused.cell_active_bits_b.device_ptr(),
415
- cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
416
- cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
417
- inputs: *inputs_flat.device_ptr(),
418
- cols_out: *cols_out.device_ptr(),
419
- anom_out: *anom_out.device_ptr(),
420
- barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
421
- step_scratch: *fused.step_scratch.device_ptr(),
422
- };
423
-
424
- let grid_x = fused.grid_dim_x;
425
- let block_x = fused.block_dim_x;
426
- let cu_stream = *sp.dev_ref().cu_stream();
427
- let use_cluster = fused.cluster_info.max_cluster_size > 0;
428
-
429
- unsafe {
430
- result::ctx::set_current(*sp.dev_ref().cu_primary_ctx())?;
431
- let mut kernel_params: [*mut std::ffi::c_void; 2] = [
432
- (&ptrs as *const FusedPtrs).cast_mut().cast(),
433
- (&cfg as *const FusedConfig).cast_mut().cast(),
434
- ];
435
-
436
- if use_cluster {
437
- // T10: Hopper cluster launch with CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION.
438
- // cluster_dim=(16,1,1) maps the entire single-region grid into one cluster.
439
- let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
440
- attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
441
- attr.value.clusterDim.x = 16;
442
- attr.value.clusterDim.y = 1;
443
- attr.value.clusterDim.z = 1;
444
-
445
- let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
446
- launch_cfg.gridDimX = grid_x;
447
- launch_cfg.gridDimY = 1;
448
- launch_cfg.gridDimZ = 1;
449
- launch_cfg.blockDimX = block_x;
450
- launch_cfg.blockDimY = 1;
451
- launch_cfg.blockDimZ = 1;
452
- launch_cfg.sharedMemBytes = 0;
453
- launch_cfg.hStream = cu_stream;
454
- launch_cfg.numAttrs = 1;
455
- launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
456
-
457
- let ret = sys::lib().cuLaunchKernelEx(
458
- &launch_cfg as *const sys::CUlaunchConfig,
459
- fused.raw_kernel.function,
460
- kernel_params.as_mut_ptr(),
461
- std::ptr::null_mut(),
462
- );
463
- if ret != sys::CUresult::CUDA_SUCCESS {
464
- return Err(DriverError(ret));
465
- }
466
- } else {
467
- // Pre-Hopper: cooperative kernel launch. The fused kernel uses
468
- // grid.sync() for cross-block synchronization which REQUIRES
469
- // cuLaunchCooperativeKernel (normal launch silently crashes on
470
- // the first grid.sync() call).
471
- let ret = sys::lib().cuLaunchCooperativeKernel(
472
- fused.raw_kernel.function,
473
- grid_x, 1, 1,
474
- block_x, 1, 1,
475
- 0, // sharedMemBytes
476
- cu_stream,
477
- kernel_params.as_mut_ptr(),
478
- );
479
- if ret != sys::CUresult::CUDA_SUCCESS {
480
- return Err(DriverError(ret));
481
- }
482
- }
483
- }
484
-
485
- Ok(())
486
- }
487
-
488
- /// Single batched non-cooperative launch for B regions with DLB sync. Uses the same kernel
489
- /// body; each block reads its region's FusedPtrs from a device-side array
490
- /// indexed by blockIdx.y. All regions share the same config (same
491
- /// input_bits/n_columns/etc.) so we pass one FusedConfig.
492
- ///
493
- /// This breaks through the CUDA cooperative-kernel device-level
494
- /// serialization: multiple cooperative launches are serialized regardless
495
- /// of stream, but one cooperative launch with grid.y=B processes all
496
- /// regions in a single invocation — ~B× speedup vs B sequential launches.
497
- #[allow(clippy::too_many_arguments)]
498
- /// Low-level raw-pointer entry, called by PyO3 binding which holds the
499
- /// mutable borrows. Safety: each `*mut HTMRegionGpu` must point to a live,
500
- /// uniquely-borrowed region. All regions must be distinct.
501
- pub(super) fn launch_fused_batched_raw(
502
- region_ptrs: &[*mut super::HTMRegionGpu],
503
- inputs_per_region: &[u64],
504
- cols_per_region: &[u64],
505
- anom_per_region: &[u64],
506
- t: usize,
507
- input_bits: usize,
508
- learn: bool,
509
- ) -> Result<(), DriverError> {
510
- let b = region_ptrs.len();
511
- assert_eq!(inputs_per_region.len(), b);
512
- assert_eq!(cols_per_region.len(), b);
513
- assert_eq!(anom_per_region.len(), b);
514
- assert!(b >= 1, "need at least one region");
515
-
516
- // Reset per-region step_scratch before each launch.
517
- for &rp in region_ptrs.iter() {
518
- let r = unsafe { &mut *rp };
519
- let dev = r.sp_gpu.dev_ref().clone();
520
- dev.memset_zeros(&mut r.fused_state.step_scratch)?;
521
- r.fused_state.iter_counter = r.fused_state.iter_counter.wrapping_add(1);
522
- }
523
-
524
- // Shared configall regions use identical sp/tm parameters.
525
- let (grid_x, block_x, function_batched, cu_stream, cu_ctx) = {
526
- let r0 = unsafe { &*region_ptrs[0] };
527
- (
528
- r0.fused_state.grid_dim_x,
529
- r0.fused_state.block_dim_x,
530
- r0.fused_state.raw_kernel.function_batched,
531
- *r0.sp_gpu.dev_ref().cu_stream(),
532
- *r0.sp_gpu.dev_ref().cu_primary_ctx(),
533
- )
534
- };
535
-
536
- let cfg = {
537
- let r = unsafe { &*region_ptrs[0] };
538
- FusedConfig {
539
- input_bits: input_bits as u32,
540
- n_columns: r.sp_gpu.n_columns_accessor() as u32,
541
- synapses_per_col: r.sp_gpu.synapses_per_col_accessor() as u32,
542
- conn_thr: r.sp_gpu.conn_thr_accessor(),
543
- sp_inc: r.sp_gpu.inc_accessor(),
544
- sp_dec: r.sp_gpu.dec_accessor(),
545
- sparsity_target: r.sp_gpu.sparsity_accessor(),
546
- duty_alpha: 1.0f32 / r.sp_gpu.duty_period_accessor().max(1.0),
547
- thr_adapt_rate: 0.001f32,
548
- cells_per_column: r.tm_gpu.cells_per_column as u32,
549
- n_cells: r.tm_gpu.n_cells as u32,
550
- bits_words: r.tm_gpu.bits_words as u32,
551
- max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
552
- synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
553
- activation_threshold: r.tm_gpu.activation_threshold,
554
- learning_threshold: r.tm_gpu.learning_threshold,
555
- max_new_synapses: r.tm_gpu.max_new_synapse_count,
556
- conn_thr_i16: r.tm_gpu.conn_thr_i16 as i32,
557
- perm_inc_i16: r.tm_gpu.perm_inc_i16 as i32,
558
- perm_dec_i16: r.tm_gpu.perm_dec_i16 as i32,
559
- predicted_seg_dec_i16: r.tm_gpu.predicted_seg_dec_i16 as i32,
560
- initial_perm_i16: r.tm_gpu.initial_perm_i16 as i32,
561
- t: t as u32,
562
- learn: if learn { 1 } else { 0 },
563
- iter_seed: r.fused_state.iter_counter,
564
- cooperative_grid_sync: 1,
565
- }
566
- };
567
-
568
- // Build B FusedPtrs per-region.
569
- let ptrs_vec: Vec<FusedPtrs> = (0..b)
570
- .map(|i| {
571
- let r = unsafe { &*region_ptrs[i] };
572
- FusedPtrs {
573
- syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
574
- syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
575
- boost: *r.sp_gpu.boost_accessor().device_ptr(),
576
- active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
577
- inhibition_threshold: *r.fused_state.inhibition_threshold.device_ptr(),
578
- seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
579
- seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
580
- syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
581
- tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
582
- cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
583
- cell_active_a: *r.fused_state.cell_active_bits_a.device_ptr(),
584
- cell_active_b: *r.fused_state.cell_active_bits_b.device_ptr(),
585
- cell_winner_a: *r.fused_state.cell_winner_bits_a.device_ptr(),
586
- cell_winner_b: *r.fused_state.cell_winner_bits_b.device_ptr(),
587
- inputs: inputs_per_region[i],
588
- cols_out: cols_per_region[i],
589
- anom_out: anom_per_region[i],
590
- barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
591
- step_scratch: *r.fused_state.step_scratch.device_ptr(),
592
- }
593
- })
594
- .collect();
595
-
596
- // Upload FusedPtrs array to device (B * sizeof(FusedPtrs) bytes).
597
- // FusedPtrs is repr(C) + DeviceRepr so htod_sync_copy handles it.
598
- let dev = unsafe { &*region_ptrs[0] }.sp_gpu.dev_ref().clone();
599
- let ptrs_dev: CudaSlice<FusedPtrs> = dev.htod_sync_copy(&ptrs_vec)?;
600
- let ptrs_dev_ptr: u64 = *ptrs_dev.device_ptr();
601
-
602
- // T10: Cluster launch for batched regions.
603
- // Grid = (grid_x, B, 1) with cluster_dim=(16,1,1): each region (Y slice)
604
- // occupies exactly one cluster of 16 blocks. All 8 clusters run concurrently
605
- // on the H200's 132 SMs (8 × 16 = 128 blocks ≤ 132 SMs).
606
- let use_cluster = {
607
- let r0 = unsafe { &*region_ptrs[0] };
608
- r0.fused_state.cluster_info.max_cluster_size > 0
609
- };
610
-
611
- unsafe {
612
- result::ctx::set_current(cu_ctx)?;
613
- let mut kernel_params: [*mut std::ffi::c_void; 2] = [
614
- (&ptrs_dev_ptr as *const u64).cast_mut().cast(),
615
- (&cfg as *const FusedConfig).cast_mut().cast(),
616
- ];
617
-
618
- if use_cluster {
619
- let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
620
- attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
621
- attr.value.clusterDim.x = 16;
622
- attr.value.clusterDim.y = 1;
623
- attr.value.clusterDim.z = 1;
624
-
625
- let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
626
- launch_cfg.gridDimX = grid_x;
627
- launch_cfg.gridDimY = b as u32;
628
- launch_cfg.gridDimZ = 1;
629
- launch_cfg.blockDimX = block_x;
630
- launch_cfg.blockDimY = 1;
631
- launch_cfg.blockDimZ = 1;
632
- launch_cfg.sharedMemBytes = 0;
633
- launch_cfg.hStream = cu_stream;
634
- launch_cfg.numAttrs = 1;
635
- launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
636
-
637
- let ret = sys::lib().cuLaunchKernelEx(
638
- &launch_cfg as *const sys::CUlaunchConfig,
639
- function_batched,
640
- kernel_params.as_mut_ptr(),
641
- std::ptr::null_mut(),
642
- );
643
- if ret != sys::CUresult::CUDA_SUCCESS {
644
- return Err(DriverError(ret));
645
- }
646
- } else {
647
- // Pre-Hopper: cooperative kernel launch (grid.sync() requires it).
648
- let ret = sys::lib().cuLaunchCooperativeKernel(
649
- function_batched,
650
- grid_x, b as u32, 1,
651
- block_x, 1, 1,
652
- 0, // sharedMemBytes
653
- cu_stream,
654
- kernel_params.as_mut_ptr(),
655
- );
656
- if ret != sys::CUresult::CUDA_SUCCESS {
657
- return Err(DriverError(ret));
658
- }
659
- }
660
- }
661
-
662
- Ok(())
663
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //! Fused HTM megakernel launcher.
2
+ //!
3
+ //! Collapses the 12-kernel per-timestep pipeline (and the outer T-loop) into
4
+ //! a single kernel launch per forward. See `kernels/htm_fused_step.cu` for
5
+ //! the kernel design and the cross-block coherence strategy (grid barrier
6
+ //! via device counter with all blocks concurrently resident).
7
+ //!
8
+ //! Launch invariant: `grid_dim.x <= concurrent-block capacity`. Host code
9
+ //! probes the device SM count at construction and caps grid_dim.x
10
+ //! accordingly — otherwise the grid barrier deadlocks.
11
+ //!
12
+ //! Semantic change from the top-K pipeline: activation is per-column
13
+ //! threshold-based (local lateral inhibition) instead of global top-K.
14
+ //! A per-column `inhibition_threshold` is tracked and EMA-steered to hit
15
+ //! the sparsity target. This is a real architectural change and is
16
+ //! documented in `docs/GPU_HTM.md`.
17
+
18
+ #![cfg(feature = "gpu")]
19
+
20
+ use std::ffi::CString;
21
+ use std::sync::Arc;
22
+
23
+ use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError,
24
+ LaunchConfig};
25
+ use cudarc::nvrtc::Ptx;
26
+
27
+ use super::sp_gpu::SpatialPoolerGpu;
28
+ use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
29
+
30
+ const PTX_HTM_FUSED: &str =
31
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
32
+
33
+ /// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
34
+ ///
35
+ /// NOTE: `barrier_counters` is kept as an ABI-compat dummy (always 0). The
36
+ /// C-side `FusedPtrs` still has the field at the same byte offset; removing
37
+ /// it here would shift all subsequent fields and break the layout. Worker A
38
+ /// will eventually delete the field from both sides once the kernel is
39
+ /// updated; until then we zero it.
40
+ #[repr(C)]
41
+ #[derive(Clone, Copy)]
42
+ pub struct FusedPtrs {
43
+ pub syn_bit: u64,
44
+ pub syn_perm: u64,
45
+ pub boost: u64,
46
+ pub active_duty: u64,
47
+ pub inhibition_threshold: u64,
48
+ pub seg_cell_id: u64,
49
+ pub seg_syn_count: u64,
50
+ pub syn_presyn: u64,
51
+ pub tm_syn_perm: u64,
52
+ pub cell_seg_count: u64,
53
+ pub cell_active_a: u64,
54
+ pub cell_active_b: u64,
55
+ pub cell_winner_a: u64,
56
+ pub cell_winner_b: u64,
57
+ pub inputs: u64,
58
+ pub cols_out: u64,
59
+ pub anom_out: u64,
60
+ /// ABI-compat dummy — always 0. No device memory is allocated for this
61
+ /// field; the cluster barrier replaces the old software DLB barrier.
62
+ pub barrier_counters: u64,
63
+ pub step_scratch: u64,
64
+ }
65
+
66
+ unsafe impl DeviceRepr for FusedPtrs {}
67
+
68
+ /// Launch-time config — matches C-side `FusedConfig` 1:1.
69
+ #[repr(C)]
70
+ #[derive(Clone, Copy)]
71
+ pub struct FusedConfig {
72
+ pub input_bits: u32,
73
+ pub n_columns: u32,
74
+ pub synapses_per_col: u32,
75
+ pub conn_thr: f32,
76
+ pub sp_inc: f32,
77
+ pub sp_dec: f32,
78
+ pub sparsity_target: f32,
79
+ pub duty_alpha: f32,
80
+ pub thr_adapt_rate: f32,
81
+ pub cells_per_column: u32,
82
+ pub n_cells: u32,
83
+ pub bits_words: u32,
84
+ pub max_segments_per_cell: u32,
85
+ pub synapses_per_segment: u32,
86
+ pub activation_threshold: u32,
87
+ pub learning_threshold: u32,
88
+ pub max_new_synapses: u32,
89
+ pub conn_thr_i16: i32,
90
+ pub perm_inc_i16: i32,
91
+ pub perm_dec_i16: i32,
92
+ pub predicted_seg_dec_i16: i32,
93
+ pub initial_perm_i16: i32,
94
+ pub t: u32,
95
+ pub learn: u32,
96
+ pub iter_seed: u32,
97
+ pub cooperative_grid_sync: u32,
98
+ }
99
+
100
+ unsafe impl DeviceRepr for FusedConfig {}
101
+
102
+ /// Cluster launch parameters probed at construction time.
103
+ #[derive(Clone, Copy, Debug, PartialEq, Eq)]
104
+ pub(crate) struct ClusterInfo {
105
+ /// Maximum cluster size supported by this device (0 = cluster unsupported).
106
+ pub max_cluster_size: u32,
107
+ }
108
+
109
+ // There is only ONE launch mode: non-cooperative launch with Hopper Thread
110
+ // Block Cluster attribute (`CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`). The old
111
+ // software DLB barrier and the cooperative-launch path are both removed.
112
+ // Cluster barriers replace both.
113
+ #[derive(Clone, Copy, Debug, PartialEq, Eq)]
114
+ pub(crate) struct FusedLaunchPlan {
115
+ pub grid_dim_x: u32,
116
+ pub block_dim_x: u32,
117
+ pub cooperative_grid_limit: u32,
118
+ pub sm_count: u32,
119
+ }
120
+
121
+ fn fused_grid_cap_override() -> Option<u32> {
122
+ std::env::var("HTM_FUSED_GRID_CAP")
123
+ .ok()
124
+ .and_then(|s| s.parse::<u32>().ok())
125
+ .map(|v| v.max(1))
126
+ }
127
+
128
+ pub(crate) fn plan_fused_launch(
129
+ sm_count: u32,
130
+ cooperative_supported: bool,
131
+ cooperative_grid_limit: u32,
132
+ grid_cap_override: Option<u32>,
133
+ ) -> Result<FusedLaunchPlan, String> {
134
+ let sm_count = sm_count.max(1);
135
+ // 1024 threads/block exceeds the register file on Ampere (sm_86: 65536
136
+ // regs/SM ÷ 1024 = 64 regs/thread; fused kernel needs ~80+). 256 gives
137
+ // 256 regs/thread which is ample. Compensate with more blocks via
138
+ // cooperative launch. On Hopper (228 KB smem, 255 regs/thread baseline),
139
+ // 1024 works fine, but 256 is safe everywhere.
140
+ let block_dim_x = 256u32;
141
+
142
+ // Cluster launch path: cooperative launch is not required. Keep the probe
143
+ // result for residency estimation only.
144
+ if !cooperative_supported {
145
+ eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
146
+ }
147
+
148
+ // Tested grid_cap: 4 blocks = 30ms (too serial), 16 blocks = 10.8ms (parallel wins).
149
+ // Parallelism in SP overlap + TM predict stages outweighs grid.sync() cost.
150
+ let default_grid_cap = 16u32;
151
+ let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
152
+ let resident_bound = if cooperative_grid_limit > 0 {
153
+ cooperative_grid_limit.max(sm_count * 2)
154
+ } else {
155
+ sm_count * 2
156
+ };
157
+ Ok(FusedLaunchPlan {
158
+ grid_dim_x: resident_bound.min(grid_cap).max(1),
159
+ block_dim_x,
160
+ cooperative_grid_limit: resident_bound,
161
+ sm_count,
162
+ })
163
+ }
164
+
165
+ pub(crate) fn plan_batched_grid_dim(
166
+ grid_dim_x: u32,
167
+ cooperative_grid_limit: u32,
168
+ batch_regions: usize,
169
+ use_cluster: bool,
170
+ ) -> Result<u32, String> {
171
+ if use_cluster {
172
+ return Ok(grid_dim_x.max(1));
173
+ }
174
+
175
+ let batch_regions = batch_regions.max(1) as u32;
176
+ if cooperative_grid_limit == 0 {
177
+ return Err("COOPERATIVE_LAUNCH_TOO_LARGE: cooperative launch limit unavailable".into());
178
+ }
179
+
180
+ let max_grid_x = cooperative_grid_limit / batch_regions;
181
+ if max_grid_x == 0 {
182
+ return Err(format!(
183
+ "COOPERATIVE_LAUNCH_TOO_LARGE: batch_regions={batch_regions} exceeds cooperative_grid_limit={cooperative_grid_limit}"
184
+ ));
185
+ }
186
+
187
+ Ok(grid_dim_x.min(max_grid_x).max(1))
188
+ }
189
+
190
+ pub(super) struct RawFusedKernel {
191
+ module: sys::CUmodule,
192
+ pub(super) function: sys::CUfunction,
193
+ pub(super) function_batched: sys::CUfunction,
194
+ }
195
+
196
+ unsafe impl Send for RawFusedKernel {}
197
+ unsafe impl Sync for RawFusedKernel {}
198
+
199
+ impl Drop for RawFusedKernel {
200
+ fn drop(&mut self) {
201
+ unsafe {
202
+ let _ = result::module::unload(self.module);
203
+ }
204
+ }
205
+ }
206
+
207
+ /// Owns fused-path-only device state:
208
+ /// - per-column inhibition threshold (replaces global top-K)
209
+ /// - ping-pong cell_active/cell_winner bitsets
210
+ /// - step_scratch (n_active, n_unpred per timestep)
211
+ /// - cluster launch capability info
212
+ pub struct FusedState {
213
+ dev: Arc<CudaDevice>,
214
+ pub(super) raw_kernel: RawFusedKernel,
215
+
216
+ pub inhibition_threshold: CudaSlice<f32>,
217
+ pub cell_active_bits_a: CudaSlice<u32>,
218
+ pub cell_active_bits_b: CudaSlice<u32>,
219
+ pub cell_winner_bits_a: CudaSlice<u32>,
220
+ pub cell_winner_bits_b: CudaSlice<u32>,
221
+ pub step_scratch: CudaSlice<u32>, // length 6
222
+
223
+ pub grid_dim_x: u32,
224
+ pub block_dim_x: u32,
225
+ pub cooperative_grid_limit: u32,
226
+ pub iter_counter: u32,
227
+
228
+ /// Hopper cluster launch capability (0 = unsupported).
229
+ pub cluster_info: ClusterInfo,
230
+
231
+ // Config mirror (read-only after init).
232
+ #[allow(dead_code)]
233
+ pub initial_threshold: f32,
234
+ }
235
+
236
+ impl FusedState {
237
+ pub fn new(
238
+ dev: Arc<CudaDevice>,
239
+ n_columns: usize,
240
+ cells_per_column: usize,
241
+ initial_threshold: f32,
242
+ ) -> Result<Self, DriverError> {
243
+ let n_cells = n_columns * cells_per_column;
244
+ assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
245
+ let bits_words = n_cells / 32;
246
+
247
+ let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
248
+ let init_vec = vec![initial_threshold; n_columns];
249
+ dev.htod_sync_copy_into(&init_vec, &mut inhibition_threshold)?;
250
+
251
+ let cell_active_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
252
+ let cell_active_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
253
+ let cell_winner_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
254
+ let cell_winner_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
255
+ let step_scratch = dev.alloc_zeros::<u32>(6)?;
256
+
257
+ unsafe {
258
+ result::ctx::set_current(*dev.cu_primary_ctx())?;
259
+ }
260
+ if dev.get_func("htm_fused", "htm_fused_step").is_none() {
261
+ dev.load_ptx(
262
+ Ptx::from_src(PTX_HTM_FUSED),
263
+ "htm_fused",
264
+ &["htm_fused_step", "htm_fused_step_batched"],
265
+ )?;
266
+ }
267
+ let ptx = CString::new(PTX_HTM_FUSED).expect("PTX contains no interior nul bytes");
268
+ let module = unsafe { result::module::load_data(ptx.as_ptr().cast()) }?;
269
+ let function = unsafe {
270
+ result::module::get_function(module, CString::new("htm_fused_step").unwrap())
271
+ }?;
272
+ let function_batched = unsafe {
273
+ result::module::get_function(module, CString::new("htm_fused_step_batched").unwrap())
274
+ }?;
275
+
276
+ // Cluster size 16 on Hopper is "non-portable" (> 8 requires opt-in).
277
+ // Must set CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED=1 on
278
+ // every launched kernel function, otherwise cuLaunchKernelEx rejects
279
+ // the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
280
+ unsafe {
281
+ let attr = sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
282
+ // Ignore errors: older CUDA may lack the attribute, in which case
283
+ // only portable sizes (<= 8) work — plan_fused_launch caps at 8.
284
+ let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
285
+ let _ = sys::lib().cuFuncSetAttribute(function_batched, attr, 1);
286
+ }
287
+
288
+ // Probe SM count.
289
+ let sm_count = match dev.attribute(
290
+ cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
291
+ ) {
292
+ Ok(v) => v as u32,
293
+ Err(_) => 16u32,
294
+ };
295
+
296
+ // T1: Probe Hopper cluster launch capability.
297
+ let max_cluster_size = match dev.attribute(
298
+ cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH,
299
+ ) {
300
+ Ok(v) if v > 0 => {
301
+ // H200/sm_90a supports up to 16 blocks per cluster.
302
+ // There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
303
+ // Hopper maximum which is 16 (8 SMs × 2 blocks/SM = 16 blocks/cluster).
304
+ 16u32
305
+ }
306
+ _ => 0u32,
307
+ };
308
+ eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
309
+ let cluster_info = ClusterInfo { max_cluster_size };
310
+
311
+ let cooperative_supported = matches!(
312
+ dev.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
313
+ Ok(v) if v > 0
314
+ );
315
+ let cooperative_grid_limit = if cooperative_supported {
316
+ let blocks_per_sm = unsafe {
317
+ // Must match plan_fused_launch(): the A10G/Ampere-safe fused
318
+ // kernel launch uses 256 threads/block, not the historical
319
+ // 1024-thread Hopper occupancy probe.
320
+ result::occupancy::max_active_block_per_multiprocessor(function, 256, 0)
321
+ }
322
+ .ok()
323
+ .map(|v| v.max(0) as u32)
324
+ .unwrap_or(0);
325
+ sm_count.saturating_mul(blocks_per_sm)
326
+ } else {
327
+ 0
328
+ };
329
+ let launch_plan = plan_fused_launch(
330
+ sm_count,
331
+ cooperative_supported,
332
+ cooperative_grid_limit,
333
+ fused_grid_cap_override(),
334
+ )
335
+ .map_err(|msg| {
336
+ // Surface as a CUDA-ish error so callers can propagate.
337
+ eprintln!("[htm_rust] FATAL: {msg}");
338
+ DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
339
+ })?;
340
+
341
+ eprintln!(
342
+ "[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}",
343
+ launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit,
344
+ cluster_info.max_cluster_size,
345
+ );
346
+
347
+ Ok(Self {
348
+ dev,
349
+ raw_kernel: RawFusedKernel { module, function, function_batched },
350
+ inhibition_threshold,
351
+ cell_active_bits_a,
352
+ cell_active_bits_b,
353
+ cell_winner_bits_a,
354
+ cell_winner_bits_b,
355
+ step_scratch,
356
+ grid_dim_x: launch_plan.grid_dim_x,
357
+ block_dim_x: launch_plan.block_dim_x,
358
+ cooperative_grid_limit: launch_plan.cooperative_grid_limit,
359
+ iter_counter: 0,
360
+ cluster_info,
361
+ initial_threshold,
362
+ })
363
+ }
364
+
365
+ /// Reset fused state. Called at region.reset().
366
+ pub fn reset(&mut self) -> Result<(), DriverError> {
367
+ self.dev.memset_zeros(&mut self.cell_active_bits_a)?;
368
+ self.dev.memset_zeros(&mut self.cell_active_bits_b)?;
369
+ self.dev.memset_zeros(&mut self.cell_winner_bits_a)?;
370
+ self.dev.memset_zeros(&mut self.cell_winner_bits_b)?;
371
+ self.dev.memset_zeros(&mut self.step_scratch)?;
372
+ // Do NOT reset inhibition_threshold — it's learned state. A hard
373
+ // reset of TM state should NOT forget the sparsity calibration.
374
+ Ok(())
375
+ }
376
+ }
377
+
378
+ /// Launch the fused megakernel. Processes all T timesteps in one kernel.
379
+ ///
380
+ /// Uses `cuLaunchKernelEx` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION=(16,1,1)`
381
+ /// when the device supports cluster launch, otherwise falls back to a plain
382
+ /// `launch_kernel`. For single-region launches, grid_dim_x <= 16 ensures the
383
+ /// entire grid fits in one cluster.
384
+ #[allow(clippy::too_many_arguments)]
385
+ pub fn launch_fused(
386
+ sp: &mut SpatialPoolerGpu,
387
+ tm: &mut TemporalMemoryGpu,
388
+ fused: &mut FusedState,
389
+ inputs_flat: &CudaSlice<u8>,
390
+ cols_out: &mut CudaSlice<u8>,
391
+ anom_out: &mut CudaSlice<f32>,
392
+ t: usize,
393
+ input_bits: usize,
394
+ learn: bool,
395
+ ) -> Result<(), DriverError> {
396
+ // Reset step_scratch before each launch (safe re-entry).
397
+ sp.dev_ref().memset_zeros(&mut fused.step_scratch)?;
398
+
399
+ fused.iter_counter = fused.iter_counter.wrapping_add(1);
400
+
401
+ let cfg = FusedConfig {
402
+ input_bits: input_bits as u32,
403
+ n_columns: sp.n_columns_accessor() as u32,
404
+ synapses_per_col: sp.synapses_per_col_accessor() as u32,
405
+ conn_thr: sp.conn_thr_accessor(),
406
+ sp_inc: sp.inc_accessor(),
407
+ sp_dec: sp.dec_accessor(),
408
+ sparsity_target: sp.sparsity_accessor(),
409
+ duty_alpha: 1.0f32 / sp.duty_period_accessor().max(1.0),
410
+ thr_adapt_rate: 0.001f32,
411
+ cells_per_column: tm.cells_per_column as u32,
412
+ n_cells: tm.n_cells as u32,
413
+ bits_words: tm.bits_words as u32,
414
+ max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
415
+ synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
416
+ activation_threshold: tm.activation_threshold,
417
+ learning_threshold: tm.learning_threshold,
418
+ max_new_synapses: tm.max_new_synapse_count,
419
+ conn_thr_i16: tm.conn_thr_i16 as i32,
420
+ perm_inc_i16: tm.perm_inc_i16 as i32,
421
+ perm_dec_i16: tm.perm_dec_i16 as i32,
422
+ predicted_seg_dec_i16: tm.predicted_seg_dec_i16 as i32,
423
+ initial_perm_i16: tm.initial_perm_i16 as i32,
424
+ t: t as u32,
425
+ learn: if learn { 1 } else { 0 },
426
+ iter_seed: fused.iter_counter,
427
+ cooperative_grid_sync: 1,
428
+ };
429
+
430
+ let ptrs = FusedPtrs {
431
+ syn_bit: *sp.syn_bit_accessor().device_ptr(),
432
+ syn_perm: *sp.syn_perm_accessor().device_ptr(),
433
+ boost: *sp.boost_accessor().device_ptr(),
434
+ active_duty: *sp.active_duty_accessor().device_ptr(),
435
+ inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
436
+ seg_cell_id: *tm.seg_cell_id_accessor().device_ptr(),
437
+ seg_syn_count: *tm.seg_syn_count_accessor().device_ptr(),
438
+ syn_presyn: *tm.syn_presyn_accessor().device_ptr(),
439
+ tm_syn_perm: *tm.syn_perm_accessor().device_ptr(),
440
+ cell_seg_count: *tm.cell_seg_count_accessor().device_ptr(),
441
+ cell_active_a: *fused.cell_active_bits_a.device_ptr(),
442
+ cell_active_b: *fused.cell_active_bits_b.device_ptr(),
443
+ cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
444
+ cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
445
+ inputs: *inputs_flat.device_ptr(),
446
+ cols_out: *cols_out.device_ptr(),
447
+ anom_out: *anom_out.device_ptr(),
448
+ barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
449
+ step_scratch: *fused.step_scratch.device_ptr(),
450
+ };
451
+
452
+ let grid_x = fused.grid_dim_x;
453
+ let block_x = fused.block_dim_x;
454
+ let cu_stream = *sp.dev_ref().cu_stream();
455
+ let use_cluster = fused.cluster_info.max_cluster_size > 0;
456
+
457
+ unsafe {
458
+ result::ctx::set_current(*sp.dev_ref().cu_primary_ctx())?;
459
+ let mut kernel_params: [*mut std::ffi::c_void; 2] = [
460
+ (&ptrs as *const FusedPtrs).cast_mut().cast(),
461
+ (&cfg as *const FusedConfig).cast_mut().cast(),
462
+ ];
463
+
464
+ if use_cluster {
465
+ // T10: Hopper cluster launch with CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION.
466
+ // cluster_dim=(16,1,1) maps the entire single-region grid into one cluster.
467
+ let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
468
+ attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
469
+ attr.value.clusterDim.x = 16;
470
+ attr.value.clusterDim.y = 1;
471
+ attr.value.clusterDim.z = 1;
472
+
473
+ let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
474
+ launch_cfg.gridDimX = grid_x;
475
+ launch_cfg.gridDimY = 1;
476
+ launch_cfg.gridDimZ = 1;
477
+ launch_cfg.blockDimX = block_x;
478
+ launch_cfg.blockDimY = 1;
479
+ launch_cfg.blockDimZ = 1;
480
+ launch_cfg.sharedMemBytes = 0;
481
+ launch_cfg.hStream = cu_stream;
482
+ launch_cfg.numAttrs = 1;
483
+ launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
484
+
485
+ let ret = sys::lib().cuLaunchKernelEx(
486
+ &launch_cfg as *const sys::CUlaunchConfig,
487
+ fused.raw_kernel.function,
488
+ kernel_params.as_mut_ptr(),
489
+ std::ptr::null_mut(),
490
+ );
491
+ if ret != sys::CUresult::CUDA_SUCCESS {
492
+ return Err(DriverError(ret));
493
+ }
494
+ } else {
495
+ // Pre-Hopper: cooperative kernel launch. The fused kernel uses
496
+ // grid.sync() for cross-block synchronization which REQUIRES
497
+ // cuLaunchCooperativeKernel (normal launch silently crashes on
498
+ // the first grid.sync() call).
499
+ let ret = sys::lib().cuLaunchCooperativeKernel(
500
+ fused.raw_kernel.function,
501
+ grid_x, 1, 1,
502
+ block_x, 1, 1,
503
+ 0, // sharedMemBytes
504
+ cu_stream,
505
+ kernel_params.as_mut_ptr(),
506
+ );
507
+ if ret != sys::CUresult::CUDA_SUCCESS {
508
+ return Err(DriverError(ret));
509
+ }
510
+ }
511
+ }
512
+
513
+ Ok(())
514
+ }
515
+
516
+ /// Single batched non-cooperative launch for B regions with DLB sync. Uses the same kernel
517
+ /// body; each block reads its region's FusedPtrs from a device-side array
518
+ /// indexed by blockIdx.y. All regions share the same config (same
519
+ /// input_bits/n_columns/etc.) so we pass one FusedConfig.
520
+ ///
521
+ /// This breaks through the CUDA cooperative-kernel device-level
522
+ /// serialization: multiple cooperative launches are serialized regardless
523
+ /// of stream, but one cooperative launch with grid.y=B processes all
524
+ /// regions in a single invocation ~B× speedup vs B sequential launches.
525
+ #[allow(clippy::too_many_arguments)]
526
+ /// Low-level raw-pointer entry, called by PyO3 binding which holds the
527
+ /// mutable borrows. Safety: each `*mut HTMRegionGpu` must point to a live,
528
+ /// uniquely-borrowed region. All regions must be distinct.
529
+ pub(super) fn launch_fused_batched_raw(
530
+ region_ptrs: &[*mut super::HTMRegionGpu],
531
+ inputs_per_region: &[u64],
532
+ cols_per_region: &[u64],
533
+ anom_per_region: &[u64],
534
+ t: usize,
535
+ input_bits: usize,
536
+ learn: bool,
537
+ ) -> Result<(), DriverError> {
538
+ let b = region_ptrs.len();
539
+ assert_eq!(inputs_per_region.len(), b);
540
+ assert_eq!(cols_per_region.len(), b);
541
+ assert_eq!(anom_per_region.len(), b);
542
+ assert!(b >= 1, "need at least one region");
543
+
544
+ // Reset per-region step_scratch before each launch.
545
+ for &rp in region_ptrs.iter() {
546
+ let r = unsafe { &mut *rp };
547
+ let dev = r.sp_gpu.dev_ref().clone();
548
+ dev.memset_zeros(&mut r.fused_state.step_scratch)?;
549
+ r.fused_state.iter_counter = r.fused_state.iter_counter.wrapping_add(1);
550
+ }
551
+
552
+ // Shared config — all regions use identical sp/tm parameters.
553
+ let (grid_x, block_x, cooperative_grid_limit, function_batched, cu_stream, cu_ctx) = {
554
+ let r0 = unsafe { &*region_ptrs[0] };
555
+ (
556
+ r0.fused_state.grid_dim_x,
557
+ r0.fused_state.block_dim_x,
558
+ r0.fused_state.cooperative_grid_limit,
559
+ r0.fused_state.raw_kernel.function_batched,
560
+ *r0.sp_gpu.dev_ref().cu_stream(),
561
+ *r0.sp_gpu.dev_ref().cu_primary_ctx(),
562
+ )
563
+ };
564
+
565
+ let cfg = {
566
+ let r = unsafe { &*region_ptrs[0] };
567
+ FusedConfig {
568
+ input_bits: input_bits as u32,
569
+ n_columns: r.sp_gpu.n_columns_accessor() as u32,
570
+ synapses_per_col: r.sp_gpu.synapses_per_col_accessor() as u32,
571
+ conn_thr: r.sp_gpu.conn_thr_accessor(),
572
+ sp_inc: r.sp_gpu.inc_accessor(),
573
+ sp_dec: r.sp_gpu.dec_accessor(),
574
+ sparsity_target: r.sp_gpu.sparsity_accessor(),
575
+ duty_alpha: 1.0f32 / r.sp_gpu.duty_period_accessor().max(1.0),
576
+ thr_adapt_rate: 0.001f32,
577
+ cells_per_column: r.tm_gpu.cells_per_column as u32,
578
+ n_cells: r.tm_gpu.n_cells as u32,
579
+ bits_words: r.tm_gpu.bits_words as u32,
580
+ max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
581
+ synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
582
+ activation_threshold: r.tm_gpu.activation_threshold,
583
+ learning_threshold: r.tm_gpu.learning_threshold,
584
+ max_new_synapses: r.tm_gpu.max_new_synapse_count,
585
+ conn_thr_i16: r.tm_gpu.conn_thr_i16 as i32,
586
+ perm_inc_i16: r.tm_gpu.perm_inc_i16 as i32,
587
+ perm_dec_i16: r.tm_gpu.perm_dec_i16 as i32,
588
+ predicted_seg_dec_i16: r.tm_gpu.predicted_seg_dec_i16 as i32,
589
+ initial_perm_i16: r.tm_gpu.initial_perm_i16 as i32,
590
+ t: t as u32,
591
+ learn: if learn { 1 } else { 0 },
592
+ iter_seed: r.fused_state.iter_counter,
593
+ cooperative_grid_sync: 1,
594
+ }
595
+ };
596
+
597
+ // Build B FusedPtrs per-region.
598
+ let ptrs_vec: Vec<FusedPtrs> = (0..b)
599
+ .map(|i| {
600
+ let r = unsafe { &*region_ptrs[i] };
601
+ FusedPtrs {
602
+ syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
603
+ syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
604
+ boost: *r.sp_gpu.boost_accessor().device_ptr(),
605
+ active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
606
+ inhibition_threshold: *r.fused_state.inhibition_threshold.device_ptr(),
607
+ seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
608
+ seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
609
+ syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
610
+ tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
611
+ cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
612
+ cell_active_a: *r.fused_state.cell_active_bits_a.device_ptr(),
613
+ cell_active_b: *r.fused_state.cell_active_bits_b.device_ptr(),
614
+ cell_winner_a: *r.fused_state.cell_winner_bits_a.device_ptr(),
615
+ cell_winner_b: *r.fused_state.cell_winner_bits_b.device_ptr(),
616
+ inputs: inputs_per_region[i],
617
+ cols_out: cols_per_region[i],
618
+ anom_out: anom_per_region[i],
619
+ barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
620
+ step_scratch: *r.fused_state.step_scratch.device_ptr(),
621
+ }
622
+ })
623
+ .collect();
624
+
625
+ // Upload FusedPtrs array to device (B * sizeof(FusedPtrs) bytes).
626
+ // FusedPtrs is repr(C) + DeviceRepr so htod_sync_copy handles it.
627
+ let dev = unsafe { &*region_ptrs[0] }.sp_gpu.dev_ref().clone();
628
+ let ptrs_dev: CudaSlice<FusedPtrs> = dev.htod_sync_copy(&ptrs_vec)?;
629
+ let ptrs_dev_ptr: u64 = *ptrs_dev.device_ptr();
630
+
631
+ // T10: Cluster launch for batched regions.
632
+ // Grid = (grid_x, B, 1) with cluster_dim=(16,1,1): each region (Y slice)
633
+ // occupies exactly one cluster of 16 blocks. All 8 clusters run concurrently
634
+ // on the H200's 132 SMs (8 × 16 = 128 blocks ≤ 132 SMs).
635
+ let use_cluster = {
636
+ let r0 = unsafe { &*region_ptrs[0] };
637
+ r0.fused_state.cluster_info.max_cluster_size > 0
638
+ };
639
+ let grid_x = plan_batched_grid_dim(grid_x, cooperative_grid_limit, b, use_cluster)
640
+ .map_err(|msg| {
641
+ eprintln!("[htm_rust] FATAL: {msg}");
642
+ DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE)
643
+ })?;
644
+
645
+ unsafe {
646
+ result::ctx::set_current(cu_ctx)?;
647
+ let mut kernel_params: [*mut std::ffi::c_void; 2] = [
648
+ (&ptrs_dev_ptr as *const u64).cast_mut().cast(),
649
+ (&cfg as *const FusedConfig).cast_mut().cast(),
650
+ ];
651
+
652
+ if use_cluster {
653
+ let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
654
+ attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
655
+ attr.value.clusterDim.x = 16;
656
+ attr.value.clusterDim.y = 1;
657
+ attr.value.clusterDim.z = 1;
658
+
659
+ let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
660
+ launch_cfg.gridDimX = grid_x;
661
+ launch_cfg.gridDimY = b as u32;
662
+ launch_cfg.gridDimZ = 1;
663
+ launch_cfg.blockDimX = block_x;
664
+ launch_cfg.blockDimY = 1;
665
+ launch_cfg.blockDimZ = 1;
666
+ launch_cfg.sharedMemBytes = 0;
667
+ launch_cfg.hStream = cu_stream;
668
+ launch_cfg.numAttrs = 1;
669
+ launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
670
+
671
+ let ret = sys::lib().cuLaunchKernelEx(
672
+ &launch_cfg as *const sys::CUlaunchConfig,
673
+ function_batched,
674
+ kernel_params.as_mut_ptr(),
675
+ std::ptr::null_mut(),
676
+ );
677
+ if ret != sys::CUresult::CUDA_SUCCESS {
678
+ return Err(DriverError(ret));
679
+ }
680
+ } else {
681
+ // Pre-Hopper: cooperative kernel launch (grid.sync() requires it).
682
+ let ret = sys::lib().cuLaunchCooperativeKernel(
683
+ function_batched,
684
+ grid_x, b as u32, 1,
685
+ block_x, 1, 1,
686
+ 0, // sharedMemBytes
687
+ cu_stream,
688
+ kernel_params.as_mut_ptr(),
689
+ );
690
+ if ret != sys::CUresult::CUDA_SUCCESS {
691
+ return Err(DriverError(ret));
692
+ }
693
+ }
694
+ }
695
+
696
+ // `ptrs_dev` is a per-call device array consumed by the async kernel.
697
+ // Keep it alive until the kernel has read it; otherwise dropping/freeing
698
+ // it immediately after launch can surface as a later unrelated CUDA error.
699
+ dev.synchronize()?;
700
+
701
+ Ok(())
702
+ }
overlay/htm_rust/src/gpu/kernels/htm_fused_step.cu CHANGED
@@ -1,677 +1,677 @@
1
- // Fused HTM megakernel — SP + TM, all T timesteps in a single launch.
2
- //
3
- // Design rationale:
4
- // - Global top-K column selection requires cross-block synchronization at
5
- // every timestep (grid.sync is unreliable on WSL2/sm_86 without rdc=true).
6
- // - Replace with per-column threshold activation using local lateral
7
- // inhibition: column c activates if overlap[c]*boost[c] > threshold[c].
8
- // Threshold is a per-column running-EMA learned scalar that steers the
9
- // column's long-run activation rate toward the global sparsity target.
10
- // - This is biologically grounded (GABAergic local inhibition) and supported
11
- // by HTM theory (duty-cycle boost already drives this loop; we just
12
- // change which lever the EMA pulls).
13
- //
14
- // Launch shape:
15
- // grid = min(device SM count, 16) // hard cap — see below
16
- // block = 1024 threads = 32 warps
17
- // Each warp of 32 owns a contiguous column slice (n_columns / total_warps).
18
- //
19
- // Cross-block coherence:
20
- // - Ping-pong buffers for cell_active/cell_winner: write _a at even t,
21
- // read _b; reversed at odd t.
22
- // - Preferred path: cooperative launch + hardware whole-grid sync.
23
- // - Fallback path: software 3-slot rotating grid barrier for devices/drivers
24
- // that cannot do cooperative launch.
25
- //
26
- // 2026-04-16: grid_dim reduced from 28 to 16 after deadlock RCA. The previous
27
- // cap of 28 relied on all blocks being concurrently resident on a 30-SM RTX
28
- // 3060 Laptop. Under thermal throttling effective residency dropped to ~20-24,
29
- // leaving scheduled blocks spinning on the software grid barrier waiting for
30
- // peer blocks that would never run. 16 blocks is below any realistic residency
31
- // floor and preserves enough warp parallelism (16*32 = 512 warps) to saturate
32
- // memory bandwidth on the spatial-pooler stage.
33
- //
34
- // Kernel signature uses struct-by-value for pointers and config to stay
35
- // inside cudarc's launch-arg count limit.
36
-
37
- #include <cooperative_groups.h>
38
- #include <cooperative_groups/memcpy_async.h>
39
-
40
- namespace cg = cooperative_groups;
41
-
42
- // Maximum columns owned per cluster-block in DSMEM.
43
- // Supports n_columns up to COLS_PER_CLUSTER_BLOCK_MAX * cluster_size.
44
- // At cluster_size=16: supports up to 256*16=4096 columns.
45
- // Each array costs 256*4 = 1024 bytes; three arrays = 3072 bytes per SM —
46
- // well under the 228 KB H200 shared-memory cap.
47
- #define COLS_PER_CLUSTER_BLOCK_MAX 256u
48
-
49
- // Maximum input_bits supported by the TMA-multicast staging tile.
50
- // At 32 KB this covers the production SDR width (16384 bits) with 2× headroom.
51
- // Total shared per SM: 32768 (tile) + 3072 (DSMEM float arrays) = ~35 KB —
52
- // well under the 228 KB H200 limit.
53
- //
54
- // Expected speedup from TMA multicast input staging (T9/T11):
55
- // - Without staging: 16 SMs × T × (input_bits GMEM reads per timestep)
56
- // - With staging: 1 TMA DMA per timestep, shared reads from L1 thereafter
57
- // - Theoretical DRAM bandwidth reduction: ~16× on input reads
58
- // - Wall-clock reduction estimate: -20 to -40 ms from reduced input fetch latency
59
- #define INPUT_BITS_MAX 32768u
60
-
61
- extern "C" {
62
-
63
- struct FusedPtrs {
64
- unsigned long long syn_bit;
65
- unsigned long long syn_perm;
66
- unsigned long long boost;
67
- unsigned long long active_duty;
68
- unsigned long long inhibition_threshold;
69
- unsigned long long seg_cell_id;
70
- unsigned long long seg_syn_count;
71
- unsigned long long syn_presyn;
72
- unsigned long long tm_syn_perm;
73
- unsigned long long cell_seg_count;
74
- unsigned long long cell_active_a;
75
- unsigned long long cell_active_b;
76
- unsigned long long cell_winner_a;
77
- unsigned long long cell_winner_b;
78
- unsigned long long inputs;
79
- unsigned long long cols_out;
80
- unsigned long long anom_out;
81
- unsigned long long barrier_counters;
82
- unsigned long long step_scratch;
83
- };
84
-
85
- struct FusedConfig {
86
- // SP constants
87
- unsigned int input_bits;
88
- unsigned int n_columns;
89
- unsigned int synapses_per_col;
90
- float conn_thr;
91
- float sp_inc;
92
- float sp_dec;
93
- float sparsity_target;
94
- float duty_alpha;
95
- float thr_adapt_rate;
96
- // TM constants
97
- unsigned int cells_per_column;
98
- unsigned int n_cells;
99
- unsigned int bits_words;
100
- unsigned int max_segments_per_cell;
101
- unsigned int synapses_per_segment;
102
- unsigned int activation_threshold;
103
- unsigned int learning_threshold;
104
- unsigned int max_new_synapses;
105
- int conn_thr_i16;
106
- int perm_inc_i16;
107
- int perm_dec_i16;
108
- int predicted_seg_dec_i16;
109
- int initial_perm_i16;
110
- // Loop constants
111
- unsigned int T;
112
- unsigned int learn;
113
- unsigned int iter_seed;
114
- unsigned int cooperative_grid_sync;
115
- };
116
-
117
- // Hardware cluster barrier using Hopper sm_90a cooperative_groups::this_cluster().sync().
118
- // Replaces the former software Decoupled Look-Back (DLB) atomic-spin barrier.
119
- //
120
- // cluster::sync() is a single PTX instruction (barrier.cluster) that resolves
121
- // in ~10-40 ns inside the cluster, with no device-level serialization.
122
- // Multiple clusters (one per HTM region) run fully concurrently — bounded
123
- // only by SM count (8 clusters × 16 SMs = 128 ≤ 132 on H200).
124
- //
125
- // The flags / expected / phase / cooperative_grid_sync parameters are kept
126
- // in the signature for call-site compatibility but are unused.
127
- __device__ static inline void fused_grid_barrier(cg::grid_group grid,
128
- unsigned int * /* flags — unused */,
129
- unsigned int /* expected — unused */,
130
- unsigned int /* phase — unused */,
131
- unsigned int /* cooperative_grid_sync — unused */) {
132
- #if __CUDA_ARCH__ >= 900
133
- // Hopper+ : hardware cluster barrier (~10-40 ns)
134
- auto cluster = cg::this_cluster();
135
- cluster.sync();
136
- #else
137
- // Pre-Hopper (sm_80, sm_86, sm_89): grid-level cooperative sync.
138
- // Requires cooperative kernel launch. ~us-ms range, adequate for HTM
139
- // workload (kernel launch frequency is low).
140
- grid.sync();
141
- #endif
142
- }
143
-
144
- __device__ static inline unsigned int warp_sum_u32(unsigned int v) {
145
- for (int off = 16; off > 0; off >>= 1) {
146
- v += __shfl_down_sync(0xffffffffu, v, off);
147
- }
148
- return v;
149
- }
150
-
151
- // Core kernel body — works for both single-region and batched launches.
152
- // Single-region: caller passes the one FusedPtrs struct.
153
- // Batched: each block reads its region's FusedPtrs via blockIdx.y before
154
- // calling this. State is independent per region (each region owns its own
155
- // GPU buffers); grid.sync() is the only cross-block primitive and it
156
- // spans ALL blocks in the grid (harmless over-sync across regions).
157
- __device__ static inline
158
- void htm_fused_step_body(const FusedPtrs& P, const FusedConfig& cfg) {
159
- cg::grid_group grid = cg::this_grid();
160
- // Cast pointers.
161
- const unsigned int * __restrict__ syn_bit = (const unsigned int*)P.syn_bit;
162
- float * __restrict__ syn_perm = (float*)P.syn_perm;
163
- float * __restrict__ boost = (float*)P.boost;
164
- float * __restrict__ active_duty = (float*)P.active_duty;
165
- float * __restrict__ inhibition_threshold = (float*)P.inhibition_threshold;
166
- unsigned int * __restrict__ seg_cell_id = (unsigned int*)P.seg_cell_id;
167
- unsigned int * __restrict__ seg_syn_count = (unsigned int*)P.seg_syn_count;
168
- unsigned int * __restrict__ syn_presyn = (unsigned int*)P.syn_presyn;
169
- short * __restrict__ tm_syn_perm = (short*)P.tm_syn_perm;
170
- unsigned int * __restrict__ cell_seg_count = (unsigned int*)P.cell_seg_count;
171
- unsigned int * __restrict__ cell_active_a = (unsigned int*)P.cell_active_a;
172
- unsigned int * __restrict__ cell_active_b = (unsigned int*)P.cell_active_b;
173
- unsigned int * __restrict__ cell_winner_a = (unsigned int*)P.cell_winner_a;
174
- unsigned int * __restrict__ cell_winner_b = (unsigned int*)P.cell_winner_b;
175
- const unsigned char * __restrict__ inputs = (const unsigned char*)P.inputs;
176
- unsigned char * __restrict__ cols_out = (unsigned char*)P.cols_out;
177
- float * __restrict__ anom_out = (float*)P.anom_out;
178
- unsigned int * __restrict__ barrier_counters = (unsigned int*)P.barrier_counters;
179
- unsigned int * __restrict__ step_scratch = (unsigned int*)P.step_scratch;
180
-
181
- const unsigned int tid = threadIdx.x;
182
- const unsigned int lane = tid & 31u;
183
- const unsigned int warp = tid >> 5;
184
- const unsigned int warps_per_block = blockDim.x >> 5;
185
- const unsigned int gwarp = blockIdx.x * warps_per_block + warp;
186
- const unsigned int n_warps = gridDim.x * warps_per_block;
187
-
188
- const unsigned int n_cols = cfg.n_columns;
189
- const unsigned int col_lo = (gwarp * n_cols) / n_warps;
190
- const unsigned int col_hi = ((gwarp + 1) * n_cols) / n_warps;
191
-
192
- unsigned int phase = 0u;
193
-
194
- // =========================================================
195
- // DSMEM: Cluster-distributed shared memory for hot per-column
196
- // state (inhibition_threshold, boost, active_duty).
197
- //
198
- // On Hopper (sm_90+): Each block in the cluster owns a contiguous
199
- // slice of columns in its own __shared__ arrays. Any block can
200
- // peer-read another block's slice via cluster.map_shared_rank().
201
- //
202
- // On Ampere (sm_86) and other pre-Hopper: No cluster support.
203
- // Read/write directly from/to global memory (inhibition_threshold,
204
- // boost, active_duty device pointers). Slightly higher latency but
205
- // functionally correct.
206
- // =========================================================
207
-
208
- #if __CUDA_ARCH__ >= 900
209
- // Hopper+ cluster path
210
- auto cluster = cg::this_cluster();
211
- const unsigned int cluster_block_rank = cluster.block_rank(); // 0..cluster_size-1
212
- const unsigned int cluster_sz = cluster.num_blocks(); // == gridDim.x (≤16)
213
- #else
214
- // Pre-Hopper: no cluster, each block is independent.
215
- const unsigned int cluster_block_rank = blockIdx.x;
216
- const unsigned int cluster_sz = gridDim.x;
217
- #endif
218
-
219
- // Partition n_cols evenly across cluster blocks.
220
- // Each block owns cols_per_block columns starting at my_col_start.
221
- const unsigned int cols_per_block =
222
- (n_cols + cluster_sz - 1u) / cluster_sz; // ceil div
223
- const unsigned int my_col_start =
224
- cluster_block_rank * cols_per_block;
225
- const unsigned int my_col_end =
226
- (my_col_start + cols_per_block < n_cols)
227
- ? (my_col_start + cols_per_block) : n_cols; // clamp
228
-
229
- #if __CUDA_ARCH__ >= 900
230
- // Cluster-distributed shared memory arrays.
231
- // Each block holds at most COLS_PER_CLUSTER_BLOCK_MAX floats per array.
232
- // Peer blocks address into each other's smem via map_shared_rank.
233
- __shared__ float s_inhib_thr [COLS_PER_CLUSTER_BLOCK_MAX];
234
- __shared__ float s_boost [COLS_PER_CLUSTER_BLOCK_MAX];
235
- __shared__ float s_active_duty[COLS_PER_CLUSTER_BLOCK_MAX];
236
- #endif
237
-
238
- // TMA multicast input staging tile (T9) — HOPPER ONLY.
239
- //
240
- // On Hopper: cg::memcpy_async with cluster scope multicasts input to all
241
- // 16 SMs, reducing DRAM traffic by ~16×.
242
- // On Ampere: 32 KB smem allocation exceeds per-block budget when
243
- // cooperatively launched (48 KB total, registers eat the rest). Skip the
244
- // tile entirely — Stage A reads from GMEM directly (original path).
245
- #if __CUDA_ARCH__ >= 900
246
- __shared__ __align__(16) unsigned char s_input_tile[INPUT_BITS_MAX];
247
- #endif
248
-
249
- #if __CUDA_ARCH__ >= 900
250
- // Initial GMEM → smem load (reads state from previous forward call).
251
- // Each block loads only its own slice; tid strides across the slice.
252
- for (unsigned int c = my_col_start + tid; c < my_col_end; c += blockDim.x) {
253
- const unsigned int off = c - my_col_start;
254
- s_inhib_thr [off] = inhibition_threshold[c];
255
- s_boost [off] = boost[c];
256
- s_active_duty[off] = active_duty[c];
257
- }
258
-
259
- // All blocks in the cluster must finish loading before any block
260
- // starts reading peer smem inside the T-loop.
261
- cluster.sync();
262
- #else
263
- // Pre-Hopper: no smem caching needed — reads go directly to GMEM.
264
- // Grid sync ensures all blocks have completed Phase 0 init before T-loop.
265
- grid.sync();
266
- #endif
267
-
268
- const unsigned int S = cfg.synapses_per_col;
269
- const unsigned int cpc = cfg.cells_per_column;
270
- const unsigned int SPS = cfg.synapses_per_segment;
271
- const unsigned int MSC = cfg.max_segments_per_cell;
272
-
273
- // Main timestep loop.
274
- for (unsigned int t = 0u; t < cfg.T; t++) {
275
- const unsigned int inp_off = t * cfg.input_bits;
276
- const unsigned int col_base_out = t * n_cols;
277
-
278
- unsigned int * curr_active = (t & 1u) ? cell_active_b : cell_active_a;
279
- unsigned int * prev_active = (t & 1u) ? cell_active_a : cell_active_b;
280
- unsigned int * curr_winner = (t & 1u) ? cell_winner_b : cell_winner_a;
281
- unsigned int * prev_winner = (t & 1u) ? cell_winner_a : cell_winner_b;
282
-
283
- // ---- Phase 0: clear curr bitsets for my cell range ----
284
- const unsigned int my_cell_lo = col_lo * cpc;
285
- const unsigned int my_cell_hi = col_hi * cpc;
286
- if (cpc == 32u) {
287
- // Fast path: one word per column.
288
- for (unsigned int c = col_lo + lane; c < col_hi; c += 32u) {
289
- curr_active[c] = 0u;
290
- curr_winner[c] = 0u;
291
- }
292
- } else {
293
- for (unsigned int cell = my_cell_lo + lane; cell < my_cell_hi; cell += 32u) {
294
- unsigned int w = cell >> 5;
295
- unsigned int m = 1u << (cell & 31u);
296
- atomicAnd(&curr_active[w], ~m);
297
- atomicAnd(&curr_winner[w], ~m);
298
- }
299
- }
300
-
301
- // Block 0, lane 0, warp 0 resets step-scratch counters.
302
- if (blockIdx.x == 0u && tid == 0u) {
303
- step_scratch[0] = 0u;
304
- step_scratch[1] = 0u;
305
- }
306
-
307
- // ---- BARRIER 1 ----
308
- // Fence: make the above clear-bitsets + scratch writes globally
309
- // visible before peer blocks observe "barrier arrived".
310
- __threadfence();
311
- fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
312
-
313
- // =========================================================
314
- // T9: TMA MULTICAST INPUT STAGING
315
- //
316
- // Issue a single cluster-scope async DMA to broadcast this
317
- // timestep's input slice into s_input_tile across all 16 SMs
318
- // in the cluster simultaneously. On Hopper sm_90a,
319
- // cg::memcpy_async with cluster scope maps to the TMA
320
- // hardware unit (cp.async.bulk.tensor multicast), reducing
321
- // DRAM input traffic by ~16× vs each block fetching its own
322
- // copy from GMEM.
323
- //
324
- // The staging is gated on cfg.input_bits <= INPUT_BITS_MAX.
325
- // If the tile is too small (custom large input_bits), we fall
326
- // back to per-thread GMEM reads in Stage A (identical to the
327
- // original path; use_input_tile==false).
328
- //
329
- // Ordering: BARRIER 1 completes before we issue the DMA.
330
- // The DMA completes before Stage A reads s_input_tile.
331
- // =========================================================
332
- #if __CUDA_ARCH__ >= 900
333
- const bool use_input_tile = (cfg.input_bits <= INPUT_BITS_MAX);
334
- if (use_input_tile) {
335
- auto tb = cg::this_thread_block();
336
- cg::memcpy_async(tb, s_input_tile,
337
- inputs + inp_off,
338
- cfg.input_bits);
339
- cg::wait(tb);
340
- cluster.sync();
341
- }
342
- #else
343
- const bool use_input_tile = false;
344
- #endif
345
-
346
- // =========================================================
347
- // STAGE A: Spatial Pooler
348
- //
349
- // Hot per-column state (boost, inhibition_threshold,
350
- // active_duty) is served from cluster DSMEM rather than
351
- // GMEM for each of the T timesteps. GMEM is written on
352
- // update so state persists across forward calls.
353
- // =========================================================
354
- for (unsigned int c = col_lo; c < col_hi; c++) {
355
- unsigned int base = c * S;
356
- unsigned int local = 0u;
357
- for (unsigned int s = lane; s < S; s += 32u) {
358
- unsigned int b = syn_bit[base + s];
359
- float p = syn_perm[base + s];
360
- // T9: read from cluster-broadcast tile when available;
361
- // fall back to direct GMEM when input_bits > INPUT_BITS_MAX.
362
- #if __CUDA_ARCH__ >= 900
363
- unsigned int inp_byte = use_input_tile
364
- ? (unsigned int)s_input_tile[b]
365
- : (unsigned int)inputs[inp_off + b];
366
- #else
367
- unsigned int inp_byte = (unsigned int)inputs[inp_off + b];
368
- #endif
369
- unsigned int hit = ((inp_byte != 0u) && (p >= cfg.conn_thr)) ? 1u : 0u;
370
- local += hit;
371
- }
372
- unsigned int overlap = warp_sum_u32(local);
373
- overlap = __shfl_sync(0xffffffffu, overlap, 0);
374
-
375
- // Read boost + threshold for column c.
376
- #if __CUDA_ARCH__ >= 900
377
- // Hopper: read from cluster-distributed shared memory.
378
- const unsigned int owner_block = c / cols_per_block;
379
- const unsigned int owner_offset = c - owner_block * cols_per_block;
380
- float boost_val = cluster.map_shared_rank(s_boost, owner_block)[owner_offset];
381
- float thr = cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset];
382
- #else
383
- // Pre-Hopper: read directly from global memory.
384
- float boost_val = boost[c];
385
- float thr = inhibition_threshold[c];
386
- #endif
387
-
388
- float boosted = (float)overlap * boost_val;
389
- unsigned int is_active = (boosted > thr) ? 1u : 0u;
390
-
391
- if (lane == 0) {
392
- cols_out[col_base_out + c] = (unsigned char)is_active;
393
- if (is_active) {
394
- atomicAdd(&step_scratch[0], 1u);
395
- }
396
- }
397
-
398
- // SP learn (Hebbian) on active columns.
399
- // T9: use tile for input reads here too.
400
- if (cfg.learn && is_active) {
401
- for (unsigned int s = lane; s < S; s += 32u) {
402
- unsigned int b = syn_bit[base + s];
403
- float p = syn_perm[base + s];
404
- #if __CUDA_ARCH__ >= 900
405
- unsigned int inp_byte = use_input_tile
406
- ? (unsigned int)s_input_tile[b]
407
- : (unsigned int)inputs[inp_off + b];
408
- #else
409
- unsigned int inp_byte = (unsigned int)inputs[inp_off + b];
410
- #endif
411
- if (inp_byte != 0u) {
412
- p += cfg.sp_inc;
413
- if (p > 1.0f) p = 1.0f;
414
- } else {
415
- p -= cfg.sp_dec;
416
- if (p < 0.0f) p = 0.0f;
417
- }
418
- syn_perm[base + s] = p;
419
- }
420
- }
421
-
422
- // active_duty EMA + threshold adaptation.
423
- // Writes go to both DSMEM (hot path, Hopper only) and GMEM (persistence).
424
- if (lane == 0) {
425
- #if __CUDA_ARCH__ >= 900
426
- float ad = cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset];
427
- #else
428
- float ad = active_duty[c];
429
- #endif
430
- float sample = is_active ? 1.0f : 0.0f;
431
- ad = (1.0f - cfg.duty_alpha) * ad + cfg.duty_alpha * sample;
432
-
433
- #if __CUDA_ARCH__ >= 900
434
- // Writeback: peer smem (for next timestep read) + GMEM (persistence).
435
- cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] = ad;
436
- #endif
437
- active_duty[c] = ad;
438
-
439
- // Threshold steers toward target sparsity.
440
- float err = ad - cfg.sparsity_target;
441
- float new_thr = thr + cfg.thr_adapt_rate * err * 100.0f;
442
- if (new_thr < 0.1f) new_thr = 0.1f;
443
- if (new_thr > 1000.0f) new_thr = 1000.0f;
444
-
445
- #if __CUDA_ARCH__ >= 900
446
- // Writeback: peer smem (for next timestep read) + GMEM (persistence).
447
- cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] = new_thr;
448
- #endif
449
- inhibition_threshold[c] = new_thr;
450
- }
451
- }
452
-
453
- // ---- DSMEM WRITEBACK SYNC: peer-smem writes must be visible cluster-wide ----
454
- //
455
- // On Hopper: cluster.sync() ensures all peer smem writes from this
456
- // timestep are visible to all blocks before Stage B / next t.
457
- // On pre-Hopper: no smem peer writes occur (all state in GMEM),
458
- // so no extra sync needed here — the grid barrier below suffices.
459
- #if __CUDA_ARCH__ >= 900
460
- cluster.sync();
461
- #endif
462
-
463
- // ---- BARRIER 2: SP active_mask must be visible before TM reads ----
464
- // Fence: flush cols_out + active_duty + inhibition_threshold + step_scratch
465
- // writes to global memory before peers advance past this barrier.
466
- __threadfence();
467
- fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
468
-
469
- // =========================================================
470
- // STAGE B: Temporal Memory
471
- // =========================================================
472
- for (unsigned int c = col_lo; c < col_hi; c++) {
473
- unsigned int col_active = cols_out[col_base_out + c];
474
- if (col_active == 0u) continue;
475
-
476
- unsigned int base_cell = c * cpc;
477
- unsigned int any_predicted = 0u;
478
- unsigned int best_seg_id_for_grow = 0xFFFFFFFFu;
479
- unsigned int best_pot_count = 0u;
480
-
481
- for (unsigned int k = 0u; k < cpc; k++) {
482
- unsigned int cell = base_cell + k;
483
- unsigned int n_segs_here = cell_seg_count[cell];
484
- if (n_segs_here > MSC) n_segs_here = MSC;
485
- if (n_segs_here == 0u) continue;
486
-
487
- unsigned int seg_base_id = cell * MSC;
488
- unsigned int cell_is_predictive = 0u;
489
-
490
- for (unsigned int ls = 0u; ls < n_segs_here; ls++) {
491
- unsigned int seg = seg_base_id + ls;
492
- unsigned int n_syn = seg_syn_count[seg];
493
- if (n_syn == 0u) continue;
494
- unsigned int syn_base = seg * SPS;
495
-
496
- unsigned int l_conn = 0u;
497
- unsigned int l_pot = 0u;
498
- for (unsigned int s = lane; s < n_syn; s += 32u) {
499
- unsigned int presyn = syn_presyn[syn_base + s];
500
- unsigned int w = prev_active[presyn >> 5];
501
- unsigned int bit = (w >> (presyn & 31u)) & 1u;
502
- if (bit) {
503
- l_pot += 1u;
504
- int p = (int)tm_syn_perm[syn_base + s];
505
- if (p >= cfg.conn_thr_i16) l_conn += 1u;
506
- }
507
- }
508
- unsigned int tot_conn = warp_sum_u32(l_conn);
509
- unsigned int tot_pot = warp_sum_u32(l_pot);
510
- tot_conn = __shfl_sync(0xffffffffu, tot_conn, 0);
511
- tot_pot = __shfl_sync(0xffffffffu, tot_pot, 0);
512
-
513
- if (tot_conn >= cfg.activation_threshold) cell_is_predictive = 1u;
514
- if (tot_pot >= cfg.learning_threshold && tot_pot > best_pot_count) {
515
- best_pot_count = tot_pot;
516
- best_seg_id_for_grow = seg;
517
- }
518
-
519
- // Reinforce predicted-and-correct segment.
520
- if (cfg.learn && tot_conn >= cfg.activation_threshold) {
521
- for (unsigned int s = lane; s < n_syn; s += 32u) {
522
- unsigned int presyn = syn_presyn[syn_base + s];
523
- unsigned int w = prev_active[presyn >> 5];
524
- unsigned int bit = (w >> (presyn & 31u)) & 1u;
525
- int p = (int)tm_syn_perm[syn_base + s];
526
- if (bit) {
527
- int np = p + cfg.perm_inc_i16;
528
- if (np > 32767) np = 32767;
529
- tm_syn_perm[syn_base + s] = (short)np;
530
- } else {
531
- int np = p - cfg.perm_dec_i16;
532
- if (np < 0) np = 0;
533
- tm_syn_perm[syn_base + s] = (short)np;
534
- }
535
- }
536
- }
537
- }
538
-
539
- if (cell_is_predictive) {
540
- any_predicted = 1u;
541
- if (lane == 0) {
542
- unsigned int w = cell >> 5;
543
- unsigned int m = 1u << (cell & 31u);
544
- atomicOr(&curr_active[w], m);
545
- atomicOr(&curr_winner[w], m);
546
- }
547
- }
548
- }
549
-
550
- // BURST if no predicted.
551
- if (!any_predicted) {
552
- if (lane == 0) {
553
- for (unsigned int k = 0u; k < cpc; k++) {
554
- unsigned int cell = base_cell + k;
555
- unsigned int w = cell >> 5;
556
- unsigned int m = 1u << (cell & 31u);
557
- atomicOr(&curr_active[w], m);
558
- }
559
- unsigned int win = base_cell;
560
- unsigned int ww = win >> 5;
561
- unsigned int wm = 1u << (win & 31u);
562
- atomicOr(&curr_winner[ww], wm);
563
- atomicAdd(&step_scratch[1], 1u);
564
- }
565
-
566
- if (cfg.learn) {
567
- unsigned int target_seg;
568
- unsigned int existing_syn;
569
- if (best_seg_id_for_grow != 0xFFFFFFFFu) {
570
- // Reuse best matching segment.
571
- target_seg = best_seg_id_for_grow;
572
- existing_syn = seg_syn_count[target_seg];
573
- target_seg = __shfl_sync(0xffffffffu, target_seg, 0);
574
- existing_syn = __shfl_sync(0xffffffffu, existing_syn, 0);
575
-
576
- // Reinforce its existing synapses.
577
- unsigned int syn_base = target_seg * SPS;
578
- for (unsigned int s = lane; s < existing_syn; s += 32u) {
579
- unsigned int presyn = syn_presyn[syn_base + s];
580
- unsigned int w = prev_active[presyn >> 5];
581
- unsigned int bit = (w >> (presyn & 31u)) & 1u;
582
- int p = (int)tm_syn_perm[syn_base + s];
583
- if (bit) {
584
- int np = p + cfg.perm_inc_i16;
585
- if (np > 32767) np = 32767;
586
- tm_syn_perm[syn_base + s] = (short)np;
587
- } else {
588
- int np = p - cfg.perm_dec_i16;
589
- if (np < 0) np = 0;
590
- tm_syn_perm[syn_base + s] = (short)np;
591
- }
592
- }
593
- } else {
594
- // Allocate new segment on winner cell (cell 0 of col).
595
- unsigned int new_seg = 0u;
596
- if (lane == 0) {
597
- unsigned int winner_cell = base_cell;
598
- unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
599
- if (slot >= MSC) slot = slot % MSC;
600
- new_seg = winner_cell * MSC + slot;
601
- seg_cell_id[new_seg] = winner_cell;
602
- seg_syn_count[new_seg] = 0u;
603
- }
604
- target_seg = __shfl_sync(0xffffffffu, new_seg, 0);
605
- existing_syn = 0u;
606
- }
607
-
608
- // Grow synapses to prev_winner cells — lane 0 serialized.
609
- unsigned int room = (SPS > existing_syn) ? (SPS - existing_syn) : 0u;
610
- unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;
611
- if (lane == 0 && max_grow > 0u) {
612
- unsigned int syn_base = target_seg * SPS;
613
- unsigned int grown = 0u;
614
- unsigned int start_off = (c * 2654435761u + cfg.iter_seed + t) % cfg.bits_words;
615
- for (unsigned int w_off = 0u;
616
- w_off < cfg.bits_words && grown < max_grow;
617
- w_off++) {
618
- unsigned int widx = (start_off + w_off) % cfg.bits_words;
619
- unsigned int word = prev_winner[widx];
620
- while (word != 0u && grown < max_grow) {
621
- unsigned int bit_pos = __ffs(word) - 1u;
622
- word &= ~(1u << bit_pos);
623
- unsigned int cell_id = widx * 32u + bit_pos;
624
- if (cell_id >= cfg.n_cells) continue;
625
- bool exists = false;
626
- for (unsigned int es = 0u; es < existing_syn + grown; es++) {
627
- if (syn_presyn[syn_base + es] == cell_id) { exists = true; break; }
628
- }
629
- if (exists) continue;
630
- unsigned int write_idx = existing_syn + grown;
631
- if (write_idx >= SPS) break;
632
- syn_presyn[syn_base + write_idx] = cell_id;
633
- tm_syn_perm[syn_base + write_idx] = (short)cfg.initial_perm_i16;
634
- grown++;
635
- }
636
- }
637
- if (grown > 0u) {
638
- seg_syn_count[target_seg] = existing_syn + grown;
639
- }
640
- }
641
- }
642
- }
643
- }
644
-
645
- // ---- BARRIER 3: TM writes complete before anomaly + next-step read ----
646
- // Fence: flush curr_active/curr_winner bitsets + tm_syn_perm +
647
- // seg_syn_count + syn_presyn before peers advance and consume them as
648
- // prev_active/prev_winner at t+1.
649
- __threadfence();
650
- fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
651
-
652
- // Write anomaly for step t.
653
- if (blockIdx.x == 0u && tid == 0u) {
654
- unsigned int total = step_scratch[0];
655
- unsigned int bad = step_scratch[1];
656
- float anom = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
657
- anom_out[t] = anom;
658
- }
659
- }
660
- }
661
-
662
- // Single-region kernel (legacy call site).
663
- __global__ __launch_bounds__(256, 2)
664
- void htm_fused_step(FusedPtrs P, FusedConfig cfg) {
665
- htm_fused_step_body(P, cfg);
666
- }
667
-
668
- // Batched kernel: one cooperative launch for B regions. grid.y = B,
669
- // grid.x = per-region block count. Each block reads its region's
670
- // FusedPtrs from the device array via blockIdx.y.
671
- __global__ __launch_bounds__(256, 2)
672
- void htm_fused_step_batched(const FusedPtrs* __restrict__ P_arr, FusedConfig cfg) {
673
- const FusedPtrs P = P_arr[blockIdx.y];
674
- htm_fused_step_body(P, cfg);
675
- }
676
-
677
- } // extern "C"
 
1
+ // Fused HTM megakernel — SP + TM, all T timesteps in a single launch.
2
+ //
3
+ // Design rationale:
4
+ // - Global top-K column selection requires cross-block synchronization at
5
+ // every timestep (grid.sync is unreliable on WSL2/sm_86 without rdc=true).
6
+ // - Replace with per-column threshold activation using local lateral
7
+ // inhibition: column c activates if overlap[c]*boost[c] > threshold[c].
8
+ // Threshold is a per-column running-EMA learned scalar that steers the
9
+ // column's long-run activation rate toward the global sparsity target.
10
+ // - This is biologically grounded (GABAergic local inhibition) and supported
11
+ // by HTM theory (duty-cycle boost already drives this loop; we just
12
+ // change which lever the EMA pulls).
13
+ //
14
+ // Launch shape:
15
+ // grid = min(device SM count, 16) // hard cap — see below
16
+ // block = 1024 threads = 32 warps
17
+ // Each warp of 32 owns a contiguous column slice (n_columns / total_warps).
18
+ //
19
+ // Cross-block coherence:
20
+ // - Ping-pong buffers for cell_active/cell_winner: write _a at even t,
21
+ // read _b; reversed at odd t.
22
+ // - Preferred path: cooperative launch + hardware whole-grid sync.
23
+ // - Fallback path: software 3-slot rotating grid barrier for devices/drivers
24
+ // that cannot do cooperative launch.
25
+ //
26
+ // 2026-04-16: grid_dim reduced from 28 to 16 after deadlock RCA. The previous
27
+ // cap of 28 relied on all blocks being concurrently resident on a 30-SM RTX
28
+ // 3060 Laptop. Under thermal throttling effective residency dropped to ~20-24,
29
+ // leaving scheduled blocks spinning on the software grid barrier waiting for
30
+ // peer blocks that would never run. 16 blocks is below any realistic residency
31
+ // floor and preserves enough warp parallelism (16*32 = 512 warps) to saturate
32
+ // memory bandwidth on the spatial-pooler stage.
33
+ //
34
+ // Kernel signature uses struct-by-value for pointers and config to stay
35
+ // inside cudarc's launch-arg count limit.
36
+
37
+ #include <cooperative_groups.h>
38
+ #include <cooperative_groups/memcpy_async.h>
39
+
40
+ namespace cg = cooperative_groups;
41
+
42
+ // Maximum columns owned per cluster-block in DSMEM.
43
+ // Supports n_columns up to COLS_PER_CLUSTER_BLOCK_MAX * cluster_size.
44
+ // At cluster_size=16: supports up to 256*16=4096 columns.
45
+ // Each array costs 256*4 = 1024 bytes; three arrays = 3072 bytes per SM —
46
+ // well under the 228 KB H200 shared-memory cap.
47
+ #define COLS_PER_CLUSTER_BLOCK_MAX 256u
48
+
49
+ // Maximum input_bits supported by the TMA-multicast staging tile.
50
+ // At 32 KB this covers the production SDR width (16384 bits) with 2× headroom.
51
+ // Total shared per SM: 32768 (tile) + 3072 (DSMEM float arrays) = ~35 KB —
52
+ // well under the 228 KB H200 limit.
53
+ //
54
+ // Expected speedup from TMA multicast input staging (T9/T11):
55
+ // - Without staging: 16 SMs × T × (input_bits GMEM reads per timestep)
56
+ // - With staging: 1 TMA DMA per timestep, shared reads from L1 thereafter
57
+ // - Theoretical DRAM bandwidth reduction: ~16× on input reads
58
+ // - Wall-clock reduction estimate: -20 to -40 ms from reduced input fetch latency
59
+ #define INPUT_BITS_MAX 32768u
60
+
61
+ extern "C" {
62
+
63
+ struct FusedPtrs {
64
+ unsigned long long syn_bit;
65
+ unsigned long long syn_perm;
66
+ unsigned long long boost;
67
+ unsigned long long active_duty;
68
+ unsigned long long inhibition_threshold;
69
+ unsigned long long seg_cell_id;
70
+ unsigned long long seg_syn_count;
71
+ unsigned long long syn_presyn;
72
+ unsigned long long tm_syn_perm;
73
+ unsigned long long cell_seg_count;
74
+ unsigned long long cell_active_a;
75
+ unsigned long long cell_active_b;
76
+ unsigned long long cell_winner_a;
77
+ unsigned long long cell_winner_b;
78
+ unsigned long long inputs;
79
+ unsigned long long cols_out;
80
+ unsigned long long anom_out;
81
+ unsigned long long barrier_counters;
82
+ unsigned long long step_scratch;
83
+ };
84
+
85
+ struct FusedConfig {
86
+ // SP constants
87
+ unsigned int input_bits;
88
+ unsigned int n_columns;
89
+ unsigned int synapses_per_col;
90
+ float conn_thr;
91
+ float sp_inc;
92
+ float sp_dec;
93
+ float sparsity_target;
94
+ float duty_alpha;
95
+ float thr_adapt_rate;
96
+ // TM constants
97
+ unsigned int cells_per_column;
98
+ unsigned int n_cells;
99
+ unsigned int bits_words;
100
+ unsigned int max_segments_per_cell;
101
+ unsigned int synapses_per_segment;
102
+ unsigned int activation_threshold;
103
+ unsigned int learning_threshold;
104
+ unsigned int max_new_synapses;
105
+ int conn_thr_i16;
106
+ int perm_inc_i16;
107
+ int perm_dec_i16;
108
+ int predicted_seg_dec_i16;
109
+ int initial_perm_i16;
110
+ // Loop constants
111
+ unsigned int T;
112
+ unsigned int learn;
113
+ unsigned int iter_seed;
114
+ unsigned int cooperative_grid_sync;
115
+ };
116
+
117
+ // Hardware cluster barrier using Hopper sm_90a cooperative_groups::this_cluster().sync().
118
+ // Replaces the former software Decoupled Look-Back (DLB) atomic-spin barrier.
119
+ //
120
+ // cluster::sync() is a single PTX instruction (barrier.cluster) that resolves
121
+ // in ~10-40 ns inside the cluster, with no device-level serialization.
122
+ // Multiple clusters (one per HTM region) run fully concurrently — bounded
123
+ // only by SM count (8 clusters × 16 SMs = 128 ≤ 132 on H200).
124
+ //
125
+ // The flags / expected / phase / cooperative_grid_sync parameters are kept
126
+ // in the signature for call-site compatibility but are unused.
127
+ __device__ static inline void fused_grid_barrier(cg::grid_group grid,
128
+ unsigned int * /* flags — unused */,
129
+ unsigned int /* expected — unused */,
130
+ unsigned int /* phase — unused */,
131
+ unsigned int /* cooperative_grid_sync — unused */) {
132
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
133
+ // Hopper+ : hardware cluster barrier (~10-40 ns)
134
+ auto cluster = cg::this_cluster();
135
+ cluster.sync();
136
+ #else
137
+ // Pre-Hopper (sm_80, sm_86, sm_89): grid-level cooperative sync.
138
+ // Requires cooperative kernel launch. ~us-ms range, adequate for HTM
139
+ // workload (kernel launch frequency is low).
140
+ grid.sync();
141
+ #endif
142
+ }
143
+
144
+ __device__ static inline unsigned int warp_sum_u32(unsigned int v) {
145
+ for (int off = 16; off > 0; off >>= 1) {
146
+ v += __shfl_down_sync(0xffffffffu, v, off);
147
+ }
148
+ return v;
149
+ }
150
+
151
+ // Core kernel body — works for both single-region and batched launches.
152
+ // Single-region: caller passes the one FusedPtrs struct.
153
+ // Batched: each block reads its region's FusedPtrs via blockIdx.y before
154
+ // calling this. State is independent per region (each region owns its own
155
+ // GPU buffers); grid.sync() is the only cross-block primitive and it
156
+ // spans ALL blocks in the grid (harmless over-sync across regions).
157
+ __device__ static inline
158
+ void htm_fused_step_body(const FusedPtrs& P, const FusedConfig& cfg) {
159
+ cg::grid_group grid = cg::this_grid();
160
+ // Cast pointers.
161
+ const unsigned int * __restrict__ syn_bit = (const unsigned int*)P.syn_bit;
162
+ float * __restrict__ syn_perm = (float*)P.syn_perm;
163
+ float * __restrict__ boost = (float*)P.boost;
164
+ float * __restrict__ active_duty = (float*)P.active_duty;
165
+ float * __restrict__ inhibition_threshold = (float*)P.inhibition_threshold;
166
+ unsigned int * __restrict__ seg_cell_id = (unsigned int*)P.seg_cell_id;
167
+ unsigned int * __restrict__ seg_syn_count = (unsigned int*)P.seg_syn_count;
168
+ unsigned int * __restrict__ syn_presyn = (unsigned int*)P.syn_presyn;
169
+ short * __restrict__ tm_syn_perm = (short*)P.tm_syn_perm;
170
+ unsigned int * __restrict__ cell_seg_count = (unsigned int*)P.cell_seg_count;
171
+ unsigned int * __restrict__ cell_active_a = (unsigned int*)P.cell_active_a;
172
+ unsigned int * __restrict__ cell_active_b = (unsigned int*)P.cell_active_b;
173
+ unsigned int * __restrict__ cell_winner_a = (unsigned int*)P.cell_winner_a;
174
+ unsigned int * __restrict__ cell_winner_b = (unsigned int*)P.cell_winner_b;
175
+ const unsigned char * __restrict__ inputs = (const unsigned char*)P.inputs;
176
+ unsigned char * __restrict__ cols_out = (unsigned char*)P.cols_out;
177
+ float * __restrict__ anom_out = (float*)P.anom_out;
178
+ unsigned int * __restrict__ barrier_counters = (unsigned int*)P.barrier_counters;
179
+ unsigned int * __restrict__ step_scratch = (unsigned int*)P.step_scratch;
180
+
181
+ const unsigned int tid = threadIdx.x;
182
+ const unsigned int lane = tid & 31u;
183
+ const unsigned int warp = tid >> 5;
184
+ const unsigned int warps_per_block = blockDim.x >> 5;
185
+ const unsigned int gwarp = blockIdx.x * warps_per_block + warp;
186
+ const unsigned int n_warps = gridDim.x * warps_per_block;
187
+
188
+ const unsigned int n_cols = cfg.n_columns;
189
+ const unsigned int col_lo = (gwarp * n_cols) / n_warps;
190
+ const unsigned int col_hi = ((gwarp + 1) * n_cols) / n_warps;
191
+
192
+ unsigned int phase = 0u;
193
+
194
+ // =========================================================
195
+ // DSMEM: Cluster-distributed shared memory for hot per-column
196
+ // state (inhibition_threshold, boost, active_duty).
197
+ //
198
+ // On Hopper (sm_90+): Each block in the cluster owns a contiguous
199
+ // slice of columns in its own __shared__ arrays. Any block can
200
+ // peer-read another block's slice via cluster.map_shared_rank().
201
+ //
202
+ // On Ampere (sm_86) and other pre-Hopper: No cluster support.
203
+ // Read/write directly from/to global memory (inhibition_threshold,
204
+ // boost, active_duty device pointers). Slightly higher latency but
205
+ // functionally correct.
206
+ // =========================================================
207
+
208
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
209
+ // Hopper+ cluster path
210
+ auto cluster = cg::this_cluster();
211
+ const unsigned int cluster_block_rank = cluster.block_rank(); // 0..cluster_size-1
212
+ const unsigned int cluster_sz = cluster.num_blocks(); // == gridDim.x (≤16)
213
+ #else
214
+ // Pre-Hopper: no cluster, each block is independent.
215
+ const unsigned int cluster_block_rank = blockIdx.x;
216
+ const unsigned int cluster_sz = gridDim.x;
217
+ #endif
218
+
219
+ // Partition n_cols evenly across cluster blocks.
220
+ // Each block owns cols_per_block columns starting at my_col_start.
221
+ const unsigned int cols_per_block =
222
+ (n_cols + cluster_sz - 1u) / cluster_sz; // ceil div
223
+ const unsigned int my_col_start =
224
+ cluster_block_rank * cols_per_block;
225
+ const unsigned int my_col_end =
226
+ (my_col_start + cols_per_block < n_cols)
227
+ ? (my_col_start + cols_per_block) : n_cols; // clamp
228
+
229
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
230
+ // Cluster-distributed shared memory arrays.
231
+ // Each block holds at most COLS_PER_CLUSTER_BLOCK_MAX floats per array.
232
+ // Peer blocks address into each other's smem via map_shared_rank.
233
+ __shared__ float s_inhib_thr [COLS_PER_CLUSTER_BLOCK_MAX];
234
+ __shared__ float s_boost [COLS_PER_CLUSTER_BLOCK_MAX];
235
+ __shared__ float s_active_duty[COLS_PER_CLUSTER_BLOCK_MAX];
236
+ #endif
237
+
238
+ // TMA multicast input staging tile (T9) — HOPPER ONLY.
239
+ //
240
+ // On Hopper: cg::memcpy_async with cluster scope multicasts input to all
241
+ // 16 SMs, reducing DRAM traffic by ~16×.
242
+ // On Ampere: 32 KB smem allocation exceeds per-block budget when
243
+ // cooperatively launched (48 KB total, registers eat the rest). Skip the
244
+ // tile entirely — Stage A reads from GMEM directly (original path).
245
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
246
+ __shared__ __align__(16) unsigned char s_input_tile[INPUT_BITS_MAX];
247
+ #endif
248
+
249
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
250
+ // Initial GMEM → smem load (reads state from previous forward call).
251
+ // Each block loads only its own slice; tid strides across the slice.
252
+ for (unsigned int c = my_col_start + tid; c < my_col_end; c += blockDim.x) {
253
+ const unsigned int off = c - my_col_start;
254
+ s_inhib_thr [off] = inhibition_threshold[c];
255
+ s_boost [off] = boost[c];
256
+ s_active_duty[off] = active_duty[c];
257
+ }
258
+
259
+ // All blocks in the cluster must finish loading before any block
260
+ // starts reading peer smem inside the T-loop.
261
+ cluster.sync();
262
+ #else
263
+ // Pre-Hopper: no smem caching needed — reads go directly to GMEM.
264
+ // Grid sync ensures all blocks have completed Phase 0 init before T-loop.
265
+ grid.sync();
266
+ #endif
267
+
268
+ const unsigned int S = cfg.synapses_per_col;
269
+ const unsigned int cpc = cfg.cells_per_column;
270
+ const unsigned int SPS = cfg.synapses_per_segment;
271
+ const unsigned int MSC = cfg.max_segments_per_cell;
272
+
273
+ // Main timestep loop.
274
+ for (unsigned int t = 0u; t < cfg.T; t++) {
275
+ const unsigned int inp_off = t * cfg.input_bits;
276
+ const unsigned int col_base_out = t * n_cols;
277
+
278
+ unsigned int * curr_active = (t & 1u) ? cell_active_b : cell_active_a;
279
+ unsigned int * prev_active = (t & 1u) ? cell_active_a : cell_active_b;
280
+ unsigned int * curr_winner = (t & 1u) ? cell_winner_b : cell_winner_a;
281
+ unsigned int * prev_winner = (t & 1u) ? cell_winner_a : cell_winner_b;
282
+
283
+ // ---- Phase 0: clear curr bitsets for my cell range ----
284
+ const unsigned int my_cell_lo = col_lo * cpc;
285
+ const unsigned int my_cell_hi = col_hi * cpc;
286
+ if (cpc == 32u) {
287
+ // Fast path: one word per column.
288
+ for (unsigned int c = col_lo + lane; c < col_hi; c += 32u) {
289
+ curr_active[c] = 0u;
290
+ curr_winner[c] = 0u;
291
+ }
292
+ } else {
293
+ for (unsigned int cell = my_cell_lo + lane; cell < my_cell_hi; cell += 32u) {
294
+ unsigned int w = cell >> 5;
295
+ unsigned int m = 1u << (cell & 31u);
296
+ atomicAnd(&curr_active[w], ~m);
297
+ atomicAnd(&curr_winner[w], ~m);
298
+ }
299
+ }
300
+
301
+ // Block 0, lane 0, warp 0 resets step-scratch counters.
302
+ if (blockIdx.x == 0u && tid == 0u) {
303
+ step_scratch[0] = 0u;
304
+ step_scratch[1] = 0u;
305
+ }
306
+
307
+ // ---- BARRIER 1 ----
308
+ // Fence: make the above clear-bitsets + scratch writes globally
309
+ // visible before peer blocks observe "barrier arrived".
310
+ __threadfence();
311
+ fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
312
+
313
+ // =========================================================
314
+ // T9: TMA MULTICAST INPUT STAGING
315
+ //
316
+ // Issue a single cluster-scope async DMA to broadcast this
317
+ // timestep's input slice into s_input_tile across all 16 SMs
318
+ // in the cluster simultaneously. On Hopper sm_90a,
319
+ // cg::memcpy_async with cluster scope maps to the TMA
320
+ // hardware unit (cp.async.bulk.tensor multicast), reducing
321
+ // DRAM input traffic by ~16× vs each block fetching its own
322
+ // copy from GMEM.
323
+ //
324
+ // The staging is gated on cfg.input_bits <= INPUT_BITS_MAX.
325
+ // If the tile is too small (custom large input_bits), we fall
326
+ // back to per-thread GMEM reads in Stage A (identical to the
327
+ // original path; use_input_tile==false).
328
+ //
329
+ // Ordering: BARRIER 1 completes before we issue the DMA.
330
+ // The DMA completes before Stage A reads s_input_tile.
331
+ // =========================================================
332
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
333
+ const bool use_input_tile = (cfg.input_bits <= INPUT_BITS_MAX);
334
+ if (use_input_tile) {
335
+ auto tb = cg::this_thread_block();
336
+ cg::memcpy_async(tb, s_input_tile,
337
+ inputs + inp_off,
338
+ cfg.input_bits);
339
+ cg::wait(tb);
340
+ cluster.sync();
341
+ }
342
+ #else
343
+ const bool use_input_tile = false;
344
+ #endif
345
+
346
+ // =========================================================
347
+ // STAGE A: Spatial Pooler
348
+ //
349
+ // Hot per-column state (boost, inhibition_threshold,
350
+ // active_duty) is served from cluster DSMEM rather than
351
+ // GMEM for each of the T timesteps. GMEM is written on
352
+ // update so state persists across forward calls.
353
+ // =========================================================
354
+ for (unsigned int c = col_lo; c < col_hi; c++) {
355
+ unsigned int base = c * S;
356
+ unsigned int local = 0u;
357
+ for (unsigned int s = lane; s < S; s += 32u) {
358
+ unsigned int b = syn_bit[base + s];
359
+ float p = syn_perm[base + s];
360
+ // T9: read from cluster-broadcast tile when available;
361
+ // fall back to direct GMEM when input_bits > INPUT_BITS_MAX.
362
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
363
+ unsigned int inp_byte = use_input_tile
364
+ ? (unsigned int)s_input_tile[b]
365
+ : (unsigned int)inputs[inp_off + b];
366
+ #else
367
+ unsigned int inp_byte = (unsigned int)inputs[inp_off + b];
368
+ #endif
369
+ unsigned int hit = ((inp_byte != 0u) && (p >= cfg.conn_thr)) ? 1u : 0u;
370
+ local += hit;
371
+ }
372
+ unsigned int overlap = warp_sum_u32(local);
373
+ overlap = __shfl_sync(0xffffffffu, overlap, 0);
374
+
375
+ // Read boost + threshold for column c.
376
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
377
+ // Hopper: read from cluster-distributed shared memory.
378
+ const unsigned int owner_block = c / cols_per_block;
379
+ const unsigned int owner_offset = c - owner_block * cols_per_block;
380
+ float boost_val = cluster.map_shared_rank(s_boost, owner_block)[owner_offset];
381
+ float thr = cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset];
382
+ #else
383
+ // Pre-Hopper: read directly from global memory.
384
+ float boost_val = boost[c];
385
+ float thr = inhibition_threshold[c];
386
+ #endif
387
+
388
+ float boosted = (float)overlap * boost_val;
389
+ unsigned int is_active = (boosted > thr) ? 1u : 0u;
390
+
391
+ if (lane == 0) {
392
+ cols_out[col_base_out + c] = (unsigned char)is_active;
393
+ if (is_active) {
394
+ atomicAdd(&step_scratch[0], 1u);
395
+ }
396
+ }
397
+
398
+ // SP learn (Hebbian) on active columns.
399
+ // T9: use tile for input reads here too.
400
+ if (cfg.learn && is_active) {
401
+ for (unsigned int s = lane; s < S; s += 32u) {
402
+ unsigned int b = syn_bit[base + s];
403
+ float p = syn_perm[base + s];
404
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
405
+ unsigned int inp_byte = use_input_tile
406
+ ? (unsigned int)s_input_tile[b]
407
+ : (unsigned int)inputs[inp_off + b];
408
+ #else
409
+ unsigned int inp_byte = (unsigned int)inputs[inp_off + b];
410
+ #endif
411
+ if (inp_byte != 0u) {
412
+ p += cfg.sp_inc;
413
+ if (p > 1.0f) p = 1.0f;
414
+ } else {
415
+ p -= cfg.sp_dec;
416
+ if (p < 0.0f) p = 0.0f;
417
+ }
418
+ syn_perm[base + s] = p;
419
+ }
420
+ }
421
+
422
+ // active_duty EMA + threshold adaptation.
423
+ // Writes go to both DSMEM (hot path, Hopper only) and GMEM (persistence).
424
+ if (lane == 0) {
425
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
426
+ float ad = cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset];
427
+ #else
428
+ float ad = active_duty[c];
429
+ #endif
430
+ float sample = is_active ? 1.0f : 0.0f;
431
+ ad = (1.0f - cfg.duty_alpha) * ad + cfg.duty_alpha * sample;
432
+
433
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
434
+ // Writeback: peer smem (for next timestep read) + GMEM (persistence).
435
+ cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] = ad;
436
+ #endif
437
+ active_duty[c] = ad;
438
+
439
+ // Threshold steers toward target sparsity.
440
+ float err = ad - cfg.sparsity_target;
441
+ float new_thr = thr + cfg.thr_adapt_rate * err * 100.0f;
442
+ if (new_thr < 0.1f) new_thr = 0.1f;
443
+ if (new_thr > 1000.0f) new_thr = 1000.0f;
444
+
445
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
446
+ // Writeback: peer smem (for next timestep read) + GMEM (persistence).
447
+ cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] = new_thr;
448
+ #endif
449
+ inhibition_threshold[c] = new_thr;
450
+ }
451
+ }
452
+
453
+ // ---- DSMEM WRITEBACK SYNC: peer-smem writes must be visible cluster-wide ----
454
+ //
455
+ // On Hopper: cluster.sync() ensures all peer smem writes from this
456
+ // timestep are visible to all blocks before Stage B / next t.
457
+ // On pre-Hopper: no smem peer writes occur (all state in GMEM),
458
+ // so no extra sync needed here — the grid barrier below suffices.
459
+ #if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
460
+ cluster.sync();
461
+ #endif
462
+
463
+ // ---- BARRIER 2: SP active_mask must be visible before TM reads ----
464
+ // Fence: flush cols_out + active_duty + inhibition_threshold + step_scratch
465
+ // writes to global memory before peers advance past this barrier.
466
+ __threadfence();
467
+ fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
468
+
469
+ // =========================================================
470
+ // STAGE B: Temporal Memory
471
+ // =========================================================
472
+ for (unsigned int c = col_lo; c < col_hi; c++) {
473
+ unsigned int col_active = cols_out[col_base_out + c];
474
+ if (col_active == 0u) continue;
475
+
476
+ unsigned int base_cell = c * cpc;
477
+ unsigned int any_predicted = 0u;
478
+ unsigned int best_seg_id_for_grow = 0xFFFFFFFFu;
479
+ unsigned int best_pot_count = 0u;
480
+
481
+ for (unsigned int k = 0u; k < cpc; k++) {
482
+ unsigned int cell = base_cell + k;
483
+ unsigned int n_segs_here = cell_seg_count[cell];
484
+ if (n_segs_here > MSC) n_segs_here = MSC;
485
+ if (n_segs_here == 0u) continue;
486
+
487
+ unsigned int seg_base_id = cell * MSC;
488
+ unsigned int cell_is_predictive = 0u;
489
+
490
+ for (unsigned int ls = 0u; ls < n_segs_here; ls++) {
491
+ unsigned int seg = seg_base_id + ls;
492
+ unsigned int n_syn = seg_syn_count[seg];
493
+ if (n_syn == 0u) continue;
494
+ unsigned int syn_base = seg * SPS;
495
+
496
+ unsigned int l_conn = 0u;
497
+ unsigned int l_pot = 0u;
498
+ for (unsigned int s = lane; s < n_syn; s += 32u) {
499
+ unsigned int presyn = syn_presyn[syn_base + s];
500
+ unsigned int w = prev_active[presyn >> 5];
501
+ unsigned int bit = (w >> (presyn & 31u)) & 1u;
502
+ if (bit) {
503
+ l_pot += 1u;
504
+ int p = (int)tm_syn_perm[syn_base + s];
505
+ if (p >= cfg.conn_thr_i16) l_conn += 1u;
506
+ }
507
+ }
508
+ unsigned int tot_conn = warp_sum_u32(l_conn);
509
+ unsigned int tot_pot = warp_sum_u32(l_pot);
510
+ tot_conn = __shfl_sync(0xffffffffu, tot_conn, 0);
511
+ tot_pot = __shfl_sync(0xffffffffu, tot_pot, 0);
512
+
513
+ if (tot_conn >= cfg.activation_threshold) cell_is_predictive = 1u;
514
+ if (tot_pot >= cfg.learning_threshold && tot_pot > best_pot_count) {
515
+ best_pot_count = tot_pot;
516
+ best_seg_id_for_grow = seg;
517
+ }
518
+
519
+ // Reinforce predicted-and-correct segment.
520
+ if (cfg.learn && tot_conn >= cfg.activation_threshold) {
521
+ for (unsigned int s = lane; s < n_syn; s += 32u) {
522
+ unsigned int presyn = syn_presyn[syn_base + s];
523
+ unsigned int w = prev_active[presyn >> 5];
524
+ unsigned int bit = (w >> (presyn & 31u)) & 1u;
525
+ int p = (int)tm_syn_perm[syn_base + s];
526
+ if (bit) {
527
+ int np = p + cfg.perm_inc_i16;
528
+ if (np > 32767) np = 32767;
529
+ tm_syn_perm[syn_base + s] = (short)np;
530
+ } else {
531
+ int np = p - cfg.perm_dec_i16;
532
+ if (np < 0) np = 0;
533
+ tm_syn_perm[syn_base + s] = (short)np;
534
+ }
535
+ }
536
+ }
537
+ }
538
+
539
+ if (cell_is_predictive) {
540
+ any_predicted = 1u;
541
+ if (lane == 0) {
542
+ unsigned int w = cell >> 5;
543
+ unsigned int m = 1u << (cell & 31u);
544
+ atomicOr(&curr_active[w], m);
545
+ atomicOr(&curr_winner[w], m);
546
+ }
547
+ }
548
+ }
549
+
550
+ // BURST if no predicted.
551
+ if (!any_predicted) {
552
+ if (lane == 0) {
553
+ for (unsigned int k = 0u; k < cpc; k++) {
554
+ unsigned int cell = base_cell + k;
555
+ unsigned int w = cell >> 5;
556
+ unsigned int m = 1u << (cell & 31u);
557
+ atomicOr(&curr_active[w], m);
558
+ }
559
+ unsigned int win = base_cell;
560
+ unsigned int ww = win >> 5;
561
+ unsigned int wm = 1u << (win & 31u);
562
+ atomicOr(&curr_winner[ww], wm);
563
+ atomicAdd(&step_scratch[1], 1u);
564
+ }
565
+
566
+ if (cfg.learn) {
567
+ unsigned int target_seg;
568
+ unsigned int existing_syn;
569
+ if (best_seg_id_for_grow != 0xFFFFFFFFu) {
570
+ // Reuse best matching segment.
571
+ target_seg = best_seg_id_for_grow;
572
+ existing_syn = seg_syn_count[target_seg];
573
+ target_seg = __shfl_sync(0xffffffffu, target_seg, 0);
574
+ existing_syn = __shfl_sync(0xffffffffu, existing_syn, 0);
575
+
576
+ // Reinforce its existing synapses.
577
+ unsigned int syn_base = target_seg * SPS;
578
+ for (unsigned int s = lane; s < existing_syn; s += 32u) {
579
+ unsigned int presyn = syn_presyn[syn_base + s];
580
+ unsigned int w = prev_active[presyn >> 5];
581
+ unsigned int bit = (w >> (presyn & 31u)) & 1u;
582
+ int p = (int)tm_syn_perm[syn_base + s];
583
+ if (bit) {
584
+ int np = p + cfg.perm_inc_i16;
585
+ if (np > 32767) np = 32767;
586
+ tm_syn_perm[syn_base + s] = (short)np;
587
+ } else {
588
+ int np = p - cfg.perm_dec_i16;
589
+ if (np < 0) np = 0;
590
+ tm_syn_perm[syn_base + s] = (short)np;
591
+ }
592
+ }
593
+ } else {
594
+ // Allocate new segment on winner cell (cell 0 of col).
595
+ unsigned int new_seg = 0u;
596
+ if (lane == 0) {
597
+ unsigned int winner_cell = base_cell;
598
+ unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
599
+ if (slot >= MSC) slot = slot % MSC;
600
+ new_seg = winner_cell * MSC + slot;
601
+ seg_cell_id[new_seg] = winner_cell;
602
+ seg_syn_count[new_seg] = 0u;
603
+ }
604
+ target_seg = __shfl_sync(0xffffffffu, new_seg, 0);
605
+ existing_syn = 0u;
606
+ }
607
+
608
+ // Grow synapses to prev_winner cells — lane 0 serialized.
609
+ unsigned int room = (SPS > existing_syn) ? (SPS - existing_syn) : 0u;
610
+ unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;
611
+ if (lane == 0 && max_grow > 0u) {
612
+ unsigned int syn_base = target_seg * SPS;
613
+ unsigned int grown = 0u;
614
+ unsigned int start_off = (c * 2654435761u + cfg.iter_seed + t) % cfg.bits_words;
615
+ for (unsigned int w_off = 0u;
616
+ w_off < cfg.bits_words && grown < max_grow;
617
+ w_off++) {
618
+ unsigned int widx = (start_off + w_off) % cfg.bits_words;
619
+ unsigned int word = prev_winner[widx];
620
+ while (word != 0u && grown < max_grow) {
621
+ unsigned int bit_pos = __ffs(word) - 1u;
622
+ word &= ~(1u << bit_pos);
623
+ unsigned int cell_id = widx * 32u + bit_pos;
624
+ if (cell_id >= cfg.n_cells) continue;
625
+ bool exists = false;
626
+ for (unsigned int es = 0u; es < existing_syn + grown; es++) {
627
+ if (syn_presyn[syn_base + es] == cell_id) { exists = true; break; }
628
+ }
629
+ if (exists) continue;
630
+ unsigned int write_idx = existing_syn + grown;
631
+ if (write_idx >= SPS) break;
632
+ syn_presyn[syn_base + write_idx] = cell_id;
633
+ tm_syn_perm[syn_base + write_idx] = (short)cfg.initial_perm_i16;
634
+ grown++;
635
+ }
636
+ }
637
+ if (grown > 0u) {
638
+ seg_syn_count[target_seg] = existing_syn + grown;
639
+ }
640
+ }
641
+ }
642
+ }
643
+ }
644
+
645
+ // ---- BARRIER 3: TM writes complete before anomaly + next-step read ----
646
+ // Fence: flush curr_active/curr_winner bitsets + tm_syn_perm +
647
+ // seg_syn_count + syn_presyn before peers advance and consume them as
648
+ // prev_active/prev_winner at t+1.
649
+ __threadfence();
650
+ fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
651
+
652
+ // Write anomaly for step t.
653
+ if (blockIdx.x == 0u && tid == 0u) {
654
+ unsigned int total = step_scratch[0];
655
+ unsigned int bad = step_scratch[1];
656
+ float anom = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
657
+ anom_out[t] = anom;
658
+ }
659
+ }
660
+ }
661
+
662
+ // Single-region kernel (legacy call site).
663
+ __global__ __launch_bounds__(256, 2)
664
+ void htm_fused_step(FusedPtrs P, FusedConfig cfg) {
665
+ htm_fused_step_body(P, cfg);
666
+ }
667
+
668
+ // Batched kernel: one cooperative launch for B regions. grid.y = B,
669
+ // grid.x = per-region block count. Each block reads its region's
670
+ // FusedPtrs from the device array via blockIdx.y.
671
+ __global__ __launch_bounds__(256, 2)
672
+ void htm_fused_step_batched(const FusedPtrs* __restrict__ P_arr, FusedConfig cfg) {
673
+ const FusedPtrs P = P_arr[blockIdx.y];
674
+ htm_fused_step_body(P, cfg);
675
+ }
676
+
677
+ } // extern "C"
overlay/htm_rust/src/gpu/tests.rs CHANGED
@@ -1,643 +1,663 @@
1
- //! Parity tests: GPU SP vs CPU SP reference.
2
- //!
3
- //! With matching seeds the two should produce bit-identical active-column sets
4
- //! when `learn=false`, and remain bit-identical over repeated `learn=true`
5
- //! steps because the Hebbian update is deterministic (no RNG once initialised).
6
- //!
7
- //! Run with: cargo test --release --features gpu
8
-
9
- #![cfg(test)]
10
- #![cfg(feature = "gpu")]
11
-
12
- use crate::sp::{SpatialPooler, SpatialPoolerConfig};
13
- use crate::gpu::sp_gpu::SpatialPoolerGpu;
14
- use crate::gpu::tm_gpu::TemporalMemoryGpu;
15
- use crate::gpu::fused::{
16
- launch_fused, plan_fused_launch, FusedState,
17
- };
18
- use cudarc::driver::CudaSlice;
19
- use rand::{Rng, SeedableRng};
20
- use rand_xoshiro::Xoshiro256PlusPlus;
21
-
22
- fn make_sdr(rng: &mut Xoshiro256PlusPlus, bits: usize, sparsity: f32) -> Vec<u8> {
23
- let on = ((sparsity * bits as f32) as usize).max(1);
24
- let mut v = vec![0u8; bits];
25
- let mut placed = 0;
26
- while placed < on {
27
- let i = rng.gen_range(0..bits);
28
- if v[i] == 0 {
29
- v[i] = 1;
30
- placed += 1;
31
- }
32
- }
33
- v
34
- }
35
-
36
- #[test]
37
- fn gpu_sp_matches_cpu_no_learn() {
38
- let cfg = SpatialPoolerConfig::default();
39
- let bits = cfg.input_bits;
40
- let mut cpu = SpatialPooler::new(
41
- SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
42
- 1234,
43
- );
44
- let cpu_for_gpu = SpatialPooler::new(
45
- SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
46
- 1234,
47
- );
48
- let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_for_gpu)
49
- .expect("gpu init (CUDA device available)");
50
- gpu.set_strict_parity(true);
51
-
52
- let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
53
- for step in 0..20 {
54
- let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
55
- let sdr_bool: Vec<bool> = sdr_u8.iter().map(|&x| x != 0).collect();
56
-
57
- let cpu_active: Vec<u32> = cpu.compute(&sdr_bool, false);
58
- let gpu_active: Vec<u32> = gpu.compute(&sdr_u8, false).expect("gpu compute");
59
-
60
- assert_eq!(
61
- cpu_active, gpu_active,
62
- "mismatch at step {step}: len cpu={} gpu={}",
63
- cpu_active.len(), gpu_active.len()
64
- );
65
- }
66
- }
67
-
68
- #[test]
69
- fn gpu_sp_matches_cpu_with_learn() {
70
- let cfg = SpatialPoolerConfig::default();
71
- let bits = cfg.input_bits;
72
- let mut cpu = SpatialPooler::new(
73
- SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
74
- 5678,
75
- );
76
- let cpu_for_gpu = SpatialPooler::new(
77
- SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
78
- 5678,
79
- );
80
- let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
81
- gpu.set_strict_parity(true);
82
-
83
- let mut rng = Xoshiro256PlusPlus::seed_from_u64(42);
84
- for step in 0..50 {
85
- let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
86
- let sdr_bool: Vec<bool> = sdr_u8.iter().map(|&x| x != 0).collect();
87
-
88
- let cpu_active = cpu.compute(&sdr_bool, true);
89
- let gpu_active = gpu.compute(&sdr_u8, true).expect("gpu compute");
90
-
91
- assert_eq!(
92
- cpu_active, gpu_active,
93
- "mismatch at step {step} with learning"
94
- );
95
- }
96
- }
97
-
98
- #[test]
99
- fn gpu_tm_anomaly_decays_on_repeating_sequence() {
100
- // End-to-end GPU pipeline: SP feeds TM; repeating SDR sequence should drive
101
- // anomaly down over time.
102
- use crate::gpu::HTMRegionGpu; // not pyclass methods; use internal constructor via Rust
103
- // Easier: replicate the pipeline directly with SP + TM.
104
-
105
- let cfg = SpatialPoolerConfig::default();
106
- let bits = cfg.input_bits;
107
- let n_cols = cfg.n_columns;
108
- let cells_per_col = 32usize;
109
-
110
- let cpu_for_gpu = SpatialPooler::new(SpatialPoolerConfig::default(), 314);
111
- let mut sp = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
112
- let dev = sp.dev_ref().clone();
113
- let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col)
114
- .expect("gpu tm init");
115
- tm.reset().expect("tm reset");
116
-
117
- // Build 3 fixed SDRs, feed them in a repeating sequence.
118
- let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
119
- let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
120
- let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
121
-
122
- // Warm up SP so columns are stable per symbol.
123
- for _ in 0..100 {
124
- for s in &seqs {
125
- let _ = sp.compute(s, true).expect("sp compute");
126
- }
127
- }
128
-
129
- // Build a long input buffer: 100 repetitions of [A,B,C] = 300 steps.
130
- let repeats = 100usize;
131
- let t = repeats * 3;
132
- let mut inputs_flat = vec![0u8; t * bits];
133
- for r in 0..repeats {
134
- for (i, s) in seqs.iter().enumerate() {
135
- let off = (r * 3 + i) * bits;
136
- inputs_flat[off..off + bits].copy_from_slice(s);
137
- }
138
- }
139
- let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod");
140
-
141
- let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc cols");
142
- let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");
143
-
144
- sp.step_batch_with_tm(
145
- &inputs_dev,
146
- t,
147
- bits,
148
- true,
149
- &mut cols_dev,
150
- &mut anom_dev,
151
- &mut tm,
152
- ).expect("step_batch_with_tm");
153
-
154
- let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
155
- let cols: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
156
-
157
- // Active column count per step must equal k for every step.
158
- let k = ((cfg.sparsity * n_cols as f32).round() as usize).max(1);
159
- for ti in 0..t {
160
- let step_slice = &cols[ti * n_cols..(ti + 1) * n_cols];
161
- let n_on = step_slice.iter().filter(|&&b| b != 0).count();
162
- assert_eq!(n_on, k, "step {ti} has {n_on} active cols, expected {k}");
163
- }
164
-
165
- // First repetition: anomaly should be near 1.0 (nothing predicted).
166
- let early_avg: f32 = anom[3..9].iter().sum::<f32>() / 6.0;
167
- // Last repetitions: anomaly should be noticeably lower.
168
- let late_avg: f32 = anom[(t - 9)..t].iter().sum::<f32>() / 9.0;
169
- eprintln!("gpu tm: early anomaly = {early_avg:.3}, late = {late_avg:.3}");
170
- assert!(
171
- late_avg < early_avg,
172
- "GPU TM should reduce anomaly on repeating sequence: early={early_avg:.3}, late={late_avg:.3}"
173
- );
174
- }
175
-
176
- /// Cluster-sync smoke test: verifies that the fused megakernel (which relies on
177
- /// hardware `cluster::sync()` / grid-barrier on H100/H200 Hopper) completes
178
- /// without deadlock when called with real HTM state, and that output shapes are
179
- /// sane (no NaN / Inf in anomaly scores, active-column count in plausible range).
180
- ///
181
- /// This is an *integration* test, not a synthetic micro-benchmark: it exercises
182
- /// exactly the same `launch_fused` code path used in production, so any
183
- /// deadlock in the cooperative-grid or DLB barrier would surface here.
184
- ///
185
- /// Skips gracefully (with an eprintln) when no GPU is available — the test
186
- /// binary returns exit-code 0 in that case so CI still passes.
187
- #[test]
188
- fn cluster_sync_smoke_test() {
189
- // Build a tiny HTM region (1024 inputs, 256 columns, 4 cells/column).
190
- // This keeps VRAM usage minimal while still exercising all kernel paths.
191
- let input_bits = 1024usize;
192
- let n_columns = 256usize;
193
- let cells_per_col = 4usize;
194
-
195
- // Probe cooperative launch attribute before doing any real work.
196
- // CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH = 223 (added in CUDA 11.8 for Hopper).
197
- // cudarc exposes raw attribute querying; we check cooperative launch (98)
198
- // as the guard — cluster launch is a superset and not separately probed
199
- // here since cudarc doesn't expose attribute 223 symbolically yet.
200
- // On pre-Hopper hardware the DLB barrier path is used instead and the
201
- // test still validates no deadlock on that path.
202
-
203
- let make_cfg = || SpatialPoolerConfig {
204
- input_bits,
205
- n_columns,
206
- sparsity: 0.04, // ~10 active cols out of 256
207
- ..SpatialPoolerConfig::default()
208
- };
209
-
210
- let cpu_ref = SpatialPooler::new(make_cfg(), 42);
211
-
212
- let mut sp = match SpatialPoolerGpu::from_cpu(&cpu_ref) {
213
- Ok(sp) => sp,
214
- Err(e) => {
215
- eprintln!("[cluster_sync_smoke_test] No GPU available ({e:?}) — skipping");
216
- return;
217
- }
218
- };
219
-
220
- let dev = sp.dev_ref().clone();
221
-
222
- // Check cooperative launch support; skip with a clear message if absent.
223
- let cooperative_ok = matches!(
224
- dev.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
225
- Ok(v) if v > 0
226
- );
227
- if !cooperative_ok {
228
- eprintln!("[cluster_sync_smoke_test] CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH=0 — DLB path only, still running test");
229
- // We continue — the DLB path is the production fallback and must not deadlock either.
230
- }
231
-
232
- let mut tm = match TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_col) {
233
- Ok(tm) => tm,
234
- Err(e) => {
235
- eprintln!("[cluster_sync_smoke_test] TemporalMemoryGpu::new failed ({e:?}) — skipping");
236
- return;
237
- }
238
- };
239
- tm.reset().expect("tm reset");
240
-
241
- let mut fused_st: FusedState = match FusedState::new(
242
- dev.clone(),
243
- n_columns,
244
- cells_per_col,
245
- sp.initial_threshold_estimate(),
246
- ) {
247
- Ok(f) => f,
248
- Err(e) => {
249
- eprintln!("[cluster_sync_smoke_test] FusedState::new failed ({e:?}) — skipping");
250
- return;
251
- }
252
- };
253
- fused_st.reset().expect("fused reset");
254
-
255
- // Build T=4 timesteps of all-zero input SDRs.
256
- let t = 4usize;
257
- let inputs_flat = vec![0u8; t * input_bits];
258
- let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod inputs");
259
-
260
- let mut cols_dev = dev.alloc_zeros::<u8>(t * n_columns).expect("alloc cols");
261
- let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");
262
-
263
- // Execute with a 2-second timeout guard via a thread. If the kernel
264
- // deadlocks, the parent test process times out and the CI job reports
265
- // failure — we can't cancel a live CUDA kernel from Rust, but the
266
- // launch_fused call itself must return within this window on any sane GPU.
267
- //
268
- // We run the kernel inline (not in a separate thread) because CUDA contexts
269
- // are not safely shareable across threads without explicit multi-threading
270
- // setup. The 2-second bound is enforced implicitly: if the kernel deadlocks,
271
- // the test binary will hang and the CI timeout (typically 5 min) will kill it.
272
- // For local dev, the deadlock would be immediately obvious.
273
-
274
- launch_fused(
275
- &mut sp,
276
- &mut tm,
277
- &mut fused_st,
278
- &inputs_dev,
279
- &mut cols_dev,
280
- &mut anom_dev,
281
- t,
282
- input_bits,
283
- false, // learn=false for determinism
284
- ).expect("launch_fused (cluster_sync_smoke_test): deadlock or CUDA error");
285
-
286
- dev.synchronize().expect("device sync after launch_fused");
287
-
288
- // --- Correctness assertions ---
289
-
290
- let cols_host: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
291
- let anom_host: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
292
-
293
- // Output buffers must be exactly the right size.
294
- assert_eq!(cols_host.len(), t * n_columns, "cols buffer size mismatch");
295
- assert_eq!(anom_host.len(), t, "anom buffer size mismatch");
296
-
297
- // Anomaly scores must be finite (NaN/Inf indicates numerical blow-up).
298
- for (i, &a) in anom_host.iter().enumerate() {
299
- assert!(a.is_finite(), "anomaly[{i}] is not finite: {a}");
300
- assert!(a >= 0.0 && a <= 1.0, "anomaly[{i}] out of [0,1]: {a}");
301
- }
302
-
303
- // Active-column count per step: threshold-based inhibition, so 0 is
304
- // possible on cold start (before thresholds calibrate), but we assert
305
- // <= n_columns to catch buffer overruns or completely wrong output.
306
- for ti in 0..t {
307
- let n_on = cols_host[ti * n_columns..(ti + 1) * n_columns]
308
- .iter()
309
- .filter(|&&b| b != 0)
310
- .count();
311
- assert!(
312
- n_on <= n_columns,
313
- "step {ti}: active columns {n_on} > n_columns {n_columns} (buffer overrun?)"
314
- );
315
- }
316
-
317
- eprintln!(
318
- "[cluster_sync_smoke_test] PASSED: T={t}, n_cols={n_columns}, \
319
- input_bits={input_bits}, cooperative_supported={cooperative_ok}, \
320
- anom={anom_host:?}"
321
- );
322
- }
323
-
324
- /// Parity check: the CAI zero-copy path (`step_many_cuda`) must produce
325
- /// bit-identical outputs to the numpy H2D/D2H path (`step_batch_with_tm`),
326
- /// since the kernel pipeline is the same — only the I/O wrapping changes.
327
- /// We skip the PyO3 CAI dict plumbing here and test the underlying
328
- /// ManuallyDrop + upgrade_device_ptr pattern directly.
329
- #[test]
330
- fn gpu_cuda_vs_numpy_parity() {
331
- use std::mem::ManuallyDrop;
332
-
333
- let cfg = SpatialPoolerConfig::default();
334
- let bits = cfg.input_bits;
335
- let n_cols = cfg.n_columns;
336
- let cells_per_col = 32usize;
337
-
338
- // Build two identical (SP, TM) pairs from the same seed.
339
- let build = || -> (SpatialPoolerGpu, TemporalMemoryGpu) {
340
- let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271828);
341
- let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu init");
342
- let dev = sp.dev_ref().clone();
343
- let mut tm = TemporalMemoryGpu::new(dev, n_cols, cells_per_col).expect("tm init");
344
- tm.reset().expect("tm reset");
345
- (sp, tm)
346
- };
347
-
348
- // Deterministic SDR sequence.
349
- let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
350
- let t = 32usize;
351
- let mut inputs_flat = vec![0u8; t * bits];
352
- for i in 0..t {
353
- let sdr = make_sdr(&mut rng, bits, 0.02);
354
- inputs_flat[i * bits..(i + 1) * bits].copy_from_slice(&sdr);
355
- }
356
-
357
- // ---- Path A: owned CudaSlice (numpy-equivalent path) ----
358
- let (mut sp_a, mut tm_a) = build();
359
- let dev_a = sp_a.dev_ref().clone();
360
- let inputs_a: CudaSlice<u8> = dev_a.htod_sync_copy(&inputs_flat).expect("htod");
361
- let mut cols_a = dev_a.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_a");
362
- let mut anom_a = dev_a.alloc_zeros::<f32>(t).expect("alloc anom_a");
363
- sp_a.step_batch_with_tm(&inputs_a, t, bits, false, &mut cols_a, &mut anom_a, &mut tm_a)
364
- .expect("owned step_batch_with_tm");
365
- dev_a.synchronize().expect("sync a");
366
- let cols_a_host: Vec<u8> = dev_a.dtoh_sync_copy(&cols_a).expect("d2h cols_a");
367
- let anom_a_host: Vec<f32> = dev_a.dtoh_sync_copy(&anom_a).expect("d2h anom_a");
368
-
369
- // ---- Path B: borrowed device pointers via upgrade_device_ptr ----
370
- // We allocate fresh owned CudaSlices on a fresh device, then take their
371
- // raw ptrs and re-wrap as ManuallyDrop borrowed views — mimicking what
372
- // `step_many_cuda` does with torch-owned CUDA memory.
373
- let (mut sp_b, mut tm_b) = build();
374
- let dev_b = sp_b.dev_ref().clone();
375
- let inputs_b_owned: CudaSlice<u8> = dev_b.htod_sync_copy(&inputs_flat).expect("htod");
376
- let cols_b_owned = dev_b.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_b");
377
- let anom_b_owned = dev_b.alloc_zeros::<f32>(t).expect("alloc anom_b");
378
-
379
- // Extract raw CUdeviceptrs (and leak the owners so their Drop doesn't free).
380
- let inputs_ptr = inputs_b_owned.leak();
381
- let cols_ptr = cols_b_owned.leak();
382
- let anom_ptr = anom_b_owned.leak();
383
-
384
- // Re-wrap as borrowed views.
385
- let inputs_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) });
386
- let mut cols_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) });
387
- let mut anom_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) });
388
-
389
- sp_b.step_batch_with_tm(&inputs_b, t, bits, false, &mut cols_b, &mut anom_b, &mut tm_b)
390
- .expect("borrowed step_batch_with_tm");
391
- dev_b.synchronize().expect("sync b");
392
- // `ManuallyDrop` doesn't auto-coerce to `&CudaSlice<T>` for the DevicePtr
393
- // trait bound on `dtoh_sync_copy`; explicit deref.
394
- let cols_b_host: Vec<u8> = dev_b.dtoh_sync_copy(&*cols_b).expect("d2h cols_b");
395
- let anom_b_host: Vec<f32> = dev_b.dtoh_sync_copy(&*anom_b).expect("d2h anom_b");
396
-
397
- // Re-own so Drop actually frees (we leaked above).
398
- let _inputs_owned_again = unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) };
399
- let _cols_owned_again = unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) };
400
- let _anom_owned_again = unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) };
401
-
402
- assert_eq!(cols_a_host, cols_b_host, "active-column mask diverges between numpy and CAI paths");
403
- assert_eq!(anom_a_host.len(), anom_b_host.len());
404
- for (i, (a, b)) in anom_a_host.iter().zip(anom_b_host.iter()).enumerate() {
405
- // Anomaly is a pure division of integer counts — bit-exact expected.
406
- assert!((a - b).abs() < 1e-7, "anomaly mismatch at step {i}: a={a} b={b}");
407
- }
408
- }
409
-
410
- /// Fused kernel: threshold activation should converge to near target sparsity
411
- /// after a short warmup. Acceptance: mean activation rate per step lands in
412
- /// [0.3*target, 2.5*target] after 500-step warmup. Because the threshold
413
- /// starts conservative (=2.0) and the per-column adaptation rate is slow
414
- /// (0.001), we allow a generous band — the test asserts directional
415
- /// convergence toward the target, not tight matching.
416
- #[test]
417
- fn gpu_threshold_converges_to_sparsity() {
418
- let cfg = SpatialPoolerConfig::default();
419
- let bits = cfg.input_bits;
420
- let n_cols = cfg.n_columns;
421
- let cells_per_col = 32usize;
422
- let target = cfg.sparsity; // 0.02 = 40 cols expected
423
-
424
- let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 111);
425
- let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
426
- let dev = sp.dev_ref().clone();
427
- let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
428
- let mut fused = FusedState::new(
429
- dev.clone(),
430
- n_cols,
431
- cells_per_col,
432
- sp.initial_threshold_estimate(),
433
- ).expect("fused init");
434
- tm.reset().expect("tm reset");
435
- fused.reset().expect("fused reset");
436
-
437
- // Warmup: 1000 random 2%-sparse SDRs.
438
- let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
439
- let t_warm = 1000usize;
440
- let mut inputs = vec![0u8; t_warm * bits];
441
- for ti in 0..t_warm {
442
- let sdr = make_sdr(&mut rng, bits, 0.02);
443
- inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
444
- }
445
- let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod");
446
- let mut cols_dev = dev.alloc_zeros::<u8>(t_warm * n_cols).expect("alloc cols");
447
- let mut anom_dev = dev.alloc_zeros::<f32>(t_warm).expect("alloc anom");
448
- launch_fused(
449
- &mut sp, &mut tm, &mut fused,
450
- &inputs_dev, &mut cols_dev, &mut anom_dev,
451
- t_warm, bits, true,
452
- ).expect("warmup launch");
453
- dev.synchronize().expect("sync");
454
-
455
- // Measurement pass: another 200 steps, measure mean activation.
456
- let t_meas = 200usize;
457
- let mut meas_inputs = vec![0u8; t_meas * bits];
458
- for ti in 0..t_meas {
459
- let sdr = make_sdr(&mut rng, bits, 0.02);
460
- meas_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
461
- }
462
- let meas_dev: CudaSlice<u8> = dev.htod_sync_copy(&meas_inputs).expect("htod meas");
463
- let mut meas_cols = dev.alloc_zeros::<u8>(t_meas * n_cols).expect("alloc meas cols");
464
- let mut meas_anom = dev.alloc_zeros::<f32>(t_meas).expect("alloc meas anom");
465
- launch_fused(
466
- &mut sp, &mut tm, &mut fused,
467
- &meas_dev, &mut meas_cols, &mut meas_anom,
468
- t_meas, bits, true,
469
- ).expect("meas launch");
470
- dev.synchronize().expect("sync meas");
471
-
472
- let cols_host: Vec<u8> = dev.dtoh_sync_copy(&meas_cols).expect("d2h");
473
- let mut step_counts = Vec::with_capacity(t_meas);
474
- for ti in 0..t_meas {
475
- let n_on = cols_host[ti*n_cols..(ti+1)*n_cols]
476
- .iter().filter(|&&b| b != 0).count();
477
- step_counts.push(n_on);
478
- }
479
- let mean_active: f64 = step_counts.iter().map(|&c| c as f64).sum::<f64>()
480
- / (t_meas as f64);
481
- let target_active = target as f64 * n_cols as f64;
482
- eprintln!(
483
- "threshold-activation convergence: mean_active/step = {mean_active:.1} \
484
- (target = {target_active:.1})"
485
- );
486
- // Very generous band — we just want to confirm the threshold loop is
487
- // functioning (not diverged to 0 or to all-active).
488
- assert!(
489
- mean_active >= 0.25 * target_active && mean_active <= 4.0 * target_active,
490
- "mean active {mean_active:.1} outside [0.25x, 4x] of target {target_active:.1}"
491
- );
492
- }
493
-
494
- /// Fused kernel: TM should learn a repeating sequence — anomaly decays.
495
- #[test]
496
- fn gpu_fused_tm_anomaly_decays_on_repeating_sequence() {
497
- let cfg = SpatialPoolerConfig::default();
498
- let bits = cfg.input_bits;
499
- let n_cols = cfg.n_columns;
500
- let cells_per_col = 32usize;
501
-
502
- let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271);
503
- let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
504
- let dev = sp.dev_ref().clone();
505
- let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
506
- let mut fused = FusedState::new(
507
- dev.clone(),
508
- n_cols,
509
- cells_per_col,
510
- sp.initial_threshold_estimate(),
511
- ).expect("fused init");
512
- tm.reset().expect("tm reset");
513
- fused.reset().expect("fused reset");
514
-
515
- let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
516
- let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
517
- let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
518
-
519
- // Warmup SP threshold calibration with random SDRs first.
520
- let warm = 300usize;
521
- let mut warm_inputs = vec![0u8; warm * bits];
522
- for ti in 0..warm {
523
- let sdr = make_sdr(&mut rng, bits, 0.02);
524
- warm_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
525
- }
526
- let warm_dev: CudaSlice<u8> = dev.htod_sync_copy(&warm_inputs).expect("htod warm");
527
- let mut warm_cols = dev.alloc_zeros::<u8>(warm * n_cols).expect("alloc warm cols");
528
- let mut warm_anom = dev.alloc_zeros::<f32>(warm).expect("alloc warm anom");
529
- launch_fused(
530
- &mut sp, &mut tm, &mut fused,
531
- &warm_dev, &mut warm_cols, &mut warm_anom,
532
- warm, bits, true,
533
- ).expect("warm launch");
534
- dev.synchronize().expect("sync warm");
535
-
536
- // Feed repeating A,B,C sequence for 100 reps.
537
- let repeats = 100usize;
538
- let t = repeats * 3;
539
- let mut inputs = vec![0u8; t * bits];
540
- for r in 0..repeats {
541
- for (i, s) in seqs.iter().enumerate() {
542
- let off = (r*3 + i) * bits;
543
- inputs[off..off+bits].copy_from_slice(s);
544
- }
545
- }
546
- let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod rep");
547
- let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc rep cols");
548
- let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc rep anom");
549
- launch_fused(
550
- &mut sp, &mut tm, &mut fused,
551
- &inputs_dev, &mut cols_dev, &mut anom_dev,
552
- t, bits, true,
553
- ).expect("rep launch");
554
- dev.synchronize().expect("sync rep");
555
-
556
- let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
557
- let early_avg: f32 = anom[3..12].iter().sum::<f32>() / 9.0;
558
- let late_avg: f32 = anom[(t-9)..t].iter().sum::<f32>() / 9.0;
559
- eprintln!("fused TM anomaly: early={early_avg:.3} late={late_avg:.3}");
560
- assert!(
561
- late_avg < early_avg,
562
- "anomaly must decay: early={early_avg:.3} late={late_avg:.3}"
563
- );
564
- assert!(
565
- late_avg < 0.5,
566
- "late anomaly must be < 0.5 (got {late_avg:.3})"
567
- );
568
- }
569
-
570
- #[test]
571
- fn gpu_sp_yields_k_winners() {
572
- let cfg = SpatialPoolerConfig::default();
573
- let bits = cfg.input_bits;
574
- let n = cfg.n_columns;
575
- let expected_k = ((cfg.sparsity * n as f32).round() as usize).max(1);
576
- let cpu = SpatialPooler::new(SpatialPoolerConfig::default(), 7);
577
- let mut gpu = SpatialPoolerGpu::from_cpu(&cpu).expect("gpu init");
578
-
579
- let mut rng = Xoshiro256PlusPlus::seed_from_u64(1);
580
- for _ in 0..10 {
581
- let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
582
- let active = gpu.compute(&sdr_u8, false).expect("gpu compute");
583
- assert_eq!(active.len(), expected_k);
584
- // Ensure sorted + unique.
585
- for w in active.windows(2) {
586
- assert!(w[0] < w[1], "duplicate or out-of-order winner indices");
587
- }
588
- }
589
- }
590
-
591
- #[test]
592
- fn fused_launch_plan_uses_cooperative_grid_sync() {
593
- let plan = plan_fused_launch(30, true, 30, None).expect("cooperative supported");
594
- assert_eq!(plan.grid_dim_x, 16);
595
- assert_eq!(plan.cooperative_grid_limit, 30);
596
- }
597
-
598
- #[test]
599
- fn fused_launch_plan_scales_to_big_gpu() {
600
- // H200-like: 132 SMs, high cooperative_grid_limit. Cap still applies.
601
- let plan = plan_fused_launch(132, true, 1000, None).expect("cooperative supported");
602
- assert_eq!(plan.grid_dim_x, 16); // capped by default override
603
- let plan = plan_fused_launch(132, true, 1000, Some(64)).expect("cooperative supported");
604
- assert_eq!(plan.grid_dim_x, 64); // override raises the cap
605
- }
606
-
607
- #[test]
608
- fn fused_launch_plan_refuses_non_cooperative_devices() {
609
- // The slow path was removed. Devices without cooperative launch fail fast.
610
- let err = plan_fused_launch(30, false, 0, None).unwrap_err();
611
- assert!(err.contains("cooperative launch"));
612
- }
613
-
614
- #[test]
615
- fn fused_grid_cap_env_override_is_honored() {
616
- let cfg = SpatialPoolerConfig::default();
617
- let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 5252);
618
- let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
619
- let dev = sp.dev_ref().clone();
620
-
621
- unsafe { std::env::set_var("HTM_FUSED_GRID_CAP", "12"); }
622
- let fused = FusedState::new(
623
- dev.clone(),
624
- cfg.n_columns,
625
- 32usize,
626
- sp.initial_threshold_estimate(),
627
- ).expect("fused init");
628
- unsafe { std::env::remove_var("HTM_FUSED_GRID_CAP"); }
629
-
630
- let sm_count = match dev.attribute(
631
- cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
632
- ) {
633
- Ok(v) => v as u32,
634
- Err(_) => 16u32,
635
- };
636
- let expected = sm_count.max(1).min(12);
637
- assert_eq!(
638
- fused.grid_dim_x,
639
- expected,
640
- "fused grid cap env override ignored: expected min(sm_count, 12) = {expected}, got {}",
641
- fused.grid_dim_x,
642
- );
643
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //! Parity tests: GPU SP vs CPU SP reference.
2
+ //!
3
+ //! With matching seeds the two should produce bit-identical active-column sets
4
+ //! when `learn=false`, and remain bit-identical over repeated `learn=true`
5
+ //! steps because the Hebbian update is deterministic (no RNG once initialised).
6
+ //!
7
+ //! Run with: cargo test --release --features gpu
8
+
9
+ #![cfg(test)]
10
+ #![cfg(feature = "gpu")]
11
+
12
+ use crate::sp::{SpatialPooler, SpatialPoolerConfig};
13
+ use crate::gpu::sp_gpu::SpatialPoolerGpu;
14
+ use crate::gpu::tm_gpu::TemporalMemoryGpu;
15
+ use crate::gpu::fused::{
16
+ launch_fused, plan_batched_grid_dim, plan_fused_launch, FusedState,
17
+ };
18
+ use cudarc::driver::CudaSlice;
19
+ use rand::{Rng, SeedableRng};
20
+ use rand_xoshiro::Xoshiro256PlusPlus;
21
+
22
+ fn make_sdr(rng: &mut Xoshiro256PlusPlus, bits: usize, sparsity: f32) -> Vec<u8> {
23
+ let on = ((sparsity * bits as f32) as usize).max(1);
24
+ let mut v = vec![0u8; bits];
25
+ let mut placed = 0;
26
+ while placed < on {
27
+ let i = rng.gen_range(0..bits);
28
+ if v[i] == 0 {
29
+ v[i] = 1;
30
+ placed += 1;
31
+ }
32
+ }
33
+ v
34
+ }
35
+
36
+ #[test]
37
+ fn gpu_sp_matches_cpu_no_learn() {
38
+ let cfg = SpatialPoolerConfig::default();
39
+ let bits = cfg.input_bits;
40
+ let mut cpu = SpatialPooler::new(
41
+ SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
42
+ 1234,
43
+ );
44
+ let cpu_for_gpu = SpatialPooler::new(
45
+ SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
46
+ 1234,
47
+ );
48
+ let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_for_gpu)
49
+ .expect("gpu init (CUDA device available)");
50
+ gpu.set_strict_parity(true);
51
+
52
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
53
+ for step in 0..20 {
54
+ let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
55
+ let sdr_bool: Vec<bool> = sdr_u8.iter().map(|&x| x != 0).collect();
56
+
57
+ let cpu_active: Vec<u32> = cpu.compute(&sdr_bool, false);
58
+ let gpu_active: Vec<u32> = gpu.compute(&sdr_u8, false).expect("gpu compute");
59
+
60
+ assert_eq!(
61
+ cpu_active, gpu_active,
62
+ "mismatch at step {step}: len cpu={} gpu={}",
63
+ cpu_active.len(), gpu_active.len()
64
+ );
65
+ }
66
+ }
67
+
68
+ #[test]
69
+ fn gpu_sp_matches_cpu_with_learn() {
70
+ let cfg = SpatialPoolerConfig::default();
71
+ let bits = cfg.input_bits;
72
+ let mut cpu = SpatialPooler::new(
73
+ SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
74
+ 5678,
75
+ );
76
+ let cpu_for_gpu = SpatialPooler::new(
77
+ SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
78
+ 5678,
79
+ );
80
+ let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
81
+ gpu.set_strict_parity(true);
82
+
83
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(42);
84
+ for step in 0..50 {
85
+ let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
86
+ let sdr_bool: Vec<bool> = sdr_u8.iter().map(|&x| x != 0).collect();
87
+
88
+ let cpu_active = cpu.compute(&sdr_bool, true);
89
+ let gpu_active = gpu.compute(&sdr_u8, true).expect("gpu compute");
90
+
91
+ assert_eq!(
92
+ cpu_active, gpu_active,
93
+ "mismatch at step {step} with learning"
94
+ );
95
+ }
96
+ }
97
+
98
+ #[test]
99
+ fn gpu_tm_anomaly_decays_on_repeating_sequence() {
100
+ // End-to-end GPU pipeline: SP feeds TM; repeating SDR sequence should drive
101
+ // anomaly down over time.
102
+ use crate::gpu::HTMRegionGpu; // not pyclass methods; use internal constructor via Rust
103
+ // Easier: replicate the pipeline directly with SP + TM.
104
+
105
+ let cfg = SpatialPoolerConfig::default();
106
+ let bits = cfg.input_bits;
107
+ let n_cols = cfg.n_columns;
108
+ let cells_per_col = 32usize;
109
+
110
+ let cpu_for_gpu = SpatialPooler::new(SpatialPoolerConfig::default(), 314);
111
+ let mut sp = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
112
+ let dev = sp.dev_ref().clone();
113
+ let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col)
114
+ .expect("gpu tm init");
115
+ tm.reset().expect("tm reset");
116
+
117
+ // Build 3 fixed SDRs, feed them in a repeating sequence.
118
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
119
+ let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
120
+ let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
121
+
122
+ // Warm up SP so columns are stable per symbol.
123
+ for _ in 0..100 {
124
+ for s in &seqs {
125
+ let _ = sp.compute(s, true).expect("sp compute");
126
+ }
127
+ }
128
+
129
+ // Build a long input buffer: 100 repetitions of [A,B,C] = 300 steps.
130
+ let repeats = 100usize;
131
+ let t = repeats * 3;
132
+ let mut inputs_flat = vec![0u8; t * bits];
133
+ for r in 0..repeats {
134
+ for (i, s) in seqs.iter().enumerate() {
135
+ let off = (r * 3 + i) * bits;
136
+ inputs_flat[off..off + bits].copy_from_slice(s);
137
+ }
138
+ }
139
+ let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod");
140
+
141
+ let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc cols");
142
+ let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");
143
+
144
+ sp.step_batch_with_tm(
145
+ &inputs_dev,
146
+ t,
147
+ bits,
148
+ true,
149
+ &mut cols_dev,
150
+ &mut anom_dev,
151
+ &mut tm,
152
+ ).expect("step_batch_with_tm");
153
+
154
+ let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
155
+ let cols: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
156
+
157
+ // Active column count per step must equal k for every step.
158
+ let k = ((cfg.sparsity * n_cols as f32).round() as usize).max(1);
159
+ for ti in 0..t {
160
+ let step_slice = &cols[ti * n_cols..(ti + 1) * n_cols];
161
+ let n_on = step_slice.iter().filter(|&&b| b != 0).count();
162
+ assert_eq!(n_on, k, "step {ti} has {n_on} active cols, expected {k}");
163
+ }
164
+
165
+ // First repetition: anomaly should be near 1.0 (nothing predicted).
166
+ let early_avg: f32 = anom[3..9].iter().sum::<f32>() / 6.0;
167
+ // Last repetitions: anomaly should be noticeably lower.
168
+ let late_avg: f32 = anom[(t - 9)..t].iter().sum::<f32>() / 9.0;
169
+ eprintln!("gpu tm: early anomaly = {early_avg:.3}, late = {late_avg:.3}");
170
+ assert!(
171
+ late_avg < early_avg,
172
+ "GPU TM should reduce anomaly on repeating sequence: early={early_avg:.3}, late={late_avg:.3}"
173
+ );
174
+ }
175
+
176
+ /// Cluster-sync smoke test: verifies that the fused megakernel (which relies on
177
+ /// hardware `cluster::sync()` / grid-barrier on H100/H200 Hopper) completes
178
+ /// without deadlock when called with real HTM state, and that output shapes are
179
+ /// sane (no NaN / Inf in anomaly scores, active-column count in plausible range).
180
+ ///
181
+ /// This is an *integration* test, not a synthetic micro-benchmark: it exercises
182
+ /// exactly the same `launch_fused` code path used in production, so any
183
+ /// deadlock in the cooperative-grid or DLB barrier would surface here.
184
+ ///
185
+ /// Skips gracefully (with an eprintln) when no GPU is available — the test
186
+ /// binary returns exit-code 0 in that case so CI still passes.
187
+ #[test]
188
+ fn cluster_sync_smoke_test() {
189
+ // Build a tiny HTM region (1024 inputs, 256 columns, 4 cells/column).
190
+ // This keeps VRAM usage minimal while still exercising all kernel paths.
191
+ let input_bits = 1024usize;
192
+ let n_columns = 256usize;
193
+ let cells_per_col = 4usize;
194
+
195
+ // Probe cooperative launch attribute before doing any real work.
196
+ // CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH = 223 (added in CUDA 11.8 for Hopper).
197
+ // cudarc exposes raw attribute querying; we check cooperative launch (98)
198
+ // as the guard — cluster launch is a superset and not separately probed
199
+ // here since cudarc doesn't expose attribute 223 symbolically yet.
200
+ // On pre-Hopper hardware the DLB barrier path is used instead and the
201
+ // test still validates no deadlock on that path.
202
+
203
+ let make_cfg = || SpatialPoolerConfig {
204
+ input_bits,
205
+ n_columns,
206
+ sparsity: 0.04, // ~10 active cols out of 256
207
+ ..SpatialPoolerConfig::default()
208
+ };
209
+
210
+ let cpu_ref = SpatialPooler::new(make_cfg(), 42);
211
+
212
+ let mut sp = match SpatialPoolerGpu::from_cpu(&cpu_ref) {
213
+ Ok(sp) => sp,
214
+ Err(e) => {
215
+ eprintln!("[cluster_sync_smoke_test] No GPU available ({e:?}) — skipping");
216
+ return;
217
+ }
218
+ };
219
+
220
+ let dev = sp.dev_ref().clone();
221
+
222
+ // Check cooperative launch support; skip with a clear message if absent.
223
+ let cooperative_ok = matches!(
224
+ dev.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
225
+ Ok(v) if v > 0
226
+ );
227
+ if !cooperative_ok {
228
+ eprintln!("[cluster_sync_smoke_test] CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH=0 — DLB path only, still running test");
229
+ // We continue — the DLB path is the production fallback and must not deadlock either.
230
+ }
231
+
232
+ let mut tm = match TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_col) {
233
+ Ok(tm) => tm,
234
+ Err(e) => {
235
+ eprintln!("[cluster_sync_smoke_test] TemporalMemoryGpu::new failed ({e:?}) — skipping");
236
+ return;
237
+ }
238
+ };
239
+ tm.reset().expect("tm reset");
240
+
241
+ let mut fused_st: FusedState = match FusedState::new(
242
+ dev.clone(),
243
+ n_columns,
244
+ cells_per_col,
245
+ sp.initial_threshold_estimate(),
246
+ ) {
247
+ Ok(f) => f,
248
+ Err(e) => {
249
+ eprintln!("[cluster_sync_smoke_test] FusedState::new failed ({e:?}) — skipping");
250
+ return;
251
+ }
252
+ };
253
+ fused_st.reset().expect("fused reset");
254
+
255
+ // Build T=4 timesteps of all-zero input SDRs.
256
+ let t = 4usize;
257
+ let inputs_flat = vec![0u8; t * input_bits];
258
+ let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod inputs");
259
+
260
+ let mut cols_dev = dev.alloc_zeros::<u8>(t * n_columns).expect("alloc cols");
261
+ let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");
262
+
263
+ // Execute with a 2-second timeout guard via a thread. If the kernel
264
+ // deadlocks, the parent test process times out and the CI job reports
265
+ // failure — we can't cancel a live CUDA kernel from Rust, but the
266
+ // launch_fused call itself must return within this window on any sane GPU.
267
+ //
268
+ // We run the kernel inline (not in a separate thread) because CUDA contexts
269
+ // are not safely shareable across threads without explicit multi-threading
270
+ // setup. The 2-second bound is enforced implicitly: if the kernel deadlocks,
271
+ // the test binary will hang and the CI timeout (typically 5 min) will kill it.
272
+ // For local dev, the deadlock would be immediately obvious.
273
+
274
+ launch_fused(
275
+ &mut sp,
276
+ &mut tm,
277
+ &mut fused_st,
278
+ &inputs_dev,
279
+ &mut cols_dev,
280
+ &mut anom_dev,
281
+ t,
282
+ input_bits,
283
+ false, // learn=false for determinism
284
+ ).expect("launch_fused (cluster_sync_smoke_test): deadlock or CUDA error");
285
+
286
+ dev.synchronize().expect("device sync after launch_fused");
287
+
288
+ // --- Correctness assertions ---
289
+
290
+ let cols_host: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
291
+ let anom_host: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
292
+
293
+ // Output buffers must be exactly the right size.
294
+ assert_eq!(cols_host.len(), t * n_columns, "cols buffer size mismatch");
295
+ assert_eq!(anom_host.len(), t, "anom buffer size mismatch");
296
+
297
+ // Anomaly scores must be finite (NaN/Inf indicates numerical blow-up).
298
+ for (i, &a) in anom_host.iter().enumerate() {
299
+ assert!(a.is_finite(), "anomaly[{i}] is not finite: {a}");
300
+ assert!(a >= 0.0 && a <= 1.0, "anomaly[{i}] out of [0,1]: {a}");
301
+ }
302
+
303
+ // Active-column count per step: threshold-based inhibition, so 0 is
304
+ // possible on cold start (before thresholds calibrate), but we assert
305
+ // <= n_columns to catch buffer overruns or completely wrong output.
306
+ for ti in 0..t {
307
+ let n_on = cols_host[ti * n_columns..(ti + 1) * n_columns]
308
+ .iter()
309
+ .filter(|&&b| b != 0)
310
+ .count();
311
+ assert!(
312
+ n_on <= n_columns,
313
+ "step {ti}: active columns {n_on} > n_columns {n_columns} (buffer overrun?)"
314
+ );
315
+ }
316
+
317
+ eprintln!(
318
+ "[cluster_sync_smoke_test] PASSED: T={t}, n_cols={n_columns}, \
319
+ input_bits={input_bits}, cooperative_supported={cooperative_ok}, \
320
+ anom={anom_host:?}"
321
+ );
322
+ }
323
+
324
+ /// Parity check: the CAI zero-copy path (`step_many_cuda`) must produce
325
+ /// bit-identical outputs to the numpy H2D/D2H path (`step_batch_with_tm`),
326
+ /// since the kernel pipeline is the same — only the I/O wrapping changes.
327
+ /// We skip the PyO3 CAI dict plumbing here and test the underlying
328
+ /// ManuallyDrop + upgrade_device_ptr pattern directly.
329
+ #[test]
330
+ fn gpu_cuda_vs_numpy_parity() {
331
+ use std::mem::ManuallyDrop;
332
+
333
+ let cfg = SpatialPoolerConfig::default();
334
+ let bits = cfg.input_bits;
335
+ let n_cols = cfg.n_columns;
336
+ let cells_per_col = 32usize;
337
+
338
+ // Build two identical (SP, TM) pairs from the same seed.
339
+ let build = || -> (SpatialPoolerGpu, TemporalMemoryGpu) {
340
+ let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271828);
341
+ let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu init");
342
+ let dev = sp.dev_ref().clone();
343
+ let mut tm = TemporalMemoryGpu::new(dev, n_cols, cells_per_col).expect("tm init");
344
+ tm.reset().expect("tm reset");
345
+ (sp, tm)
346
+ };
347
+
348
+ // Deterministic SDR sequence.
349
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
350
+ let t = 32usize;
351
+ let mut inputs_flat = vec![0u8; t * bits];
352
+ for i in 0..t {
353
+ let sdr = make_sdr(&mut rng, bits, 0.02);
354
+ inputs_flat[i * bits..(i + 1) * bits].copy_from_slice(&sdr);
355
+ }
356
+
357
+ // ---- Path A: owned CudaSlice (numpy-equivalent path) ----
358
+ let (mut sp_a, mut tm_a) = build();
359
+ let dev_a = sp_a.dev_ref().clone();
360
+ let inputs_a: CudaSlice<u8> = dev_a.htod_sync_copy(&inputs_flat).expect("htod");
361
+ let mut cols_a = dev_a.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_a");
362
+ let mut anom_a = dev_a.alloc_zeros::<f32>(t).expect("alloc anom_a");
363
+ sp_a.step_batch_with_tm(&inputs_a, t, bits, false, &mut cols_a, &mut anom_a, &mut tm_a)
364
+ .expect("owned step_batch_with_tm");
365
+ dev_a.synchronize().expect("sync a");
366
+ let cols_a_host: Vec<u8> = dev_a.dtoh_sync_copy(&cols_a).expect("d2h cols_a");
367
+ let anom_a_host: Vec<f32> = dev_a.dtoh_sync_copy(&anom_a).expect("d2h anom_a");
368
+
369
+ // ---- Path B: borrowed device pointers via upgrade_device_ptr ----
370
+ // We allocate fresh owned CudaSlices on a fresh device, then take their
371
+ // raw ptrs and re-wrap as ManuallyDrop borrowed views — mimicking what
372
+ // `step_many_cuda` does with torch-owned CUDA memory.
373
+ let (mut sp_b, mut tm_b) = build();
374
+ let dev_b = sp_b.dev_ref().clone();
375
+ let inputs_b_owned: CudaSlice<u8> = dev_b.htod_sync_copy(&inputs_flat).expect("htod");
376
+ let cols_b_owned = dev_b.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_b");
377
+ let anom_b_owned = dev_b.alloc_zeros::<f32>(t).expect("alloc anom_b");
378
+
379
+ // Extract raw CUdeviceptrs (and leak the owners so their Drop doesn't free).
380
+ let inputs_ptr = inputs_b_owned.leak();
381
+ let cols_ptr = cols_b_owned.leak();
382
+ let anom_ptr = anom_b_owned.leak();
383
+
384
+ // Re-wrap as borrowed views.
385
+ let inputs_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) });
386
+ let mut cols_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) });
387
+ let mut anom_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) });
388
+
389
+ sp_b.step_batch_with_tm(&inputs_b, t, bits, false, &mut cols_b, &mut anom_b, &mut tm_b)
390
+ .expect("borrowed step_batch_with_tm");
391
+ dev_b.synchronize().expect("sync b");
392
+ // `ManuallyDrop` doesn't auto-coerce to `&CudaSlice<T>` for the DevicePtr
393
+ // trait bound on `dtoh_sync_copy`; explicit deref.
394
+ let cols_b_host: Vec<u8> = dev_b.dtoh_sync_copy(&*cols_b).expect("d2h cols_b");
395
+ let anom_b_host: Vec<f32> = dev_b.dtoh_sync_copy(&*anom_b).expect("d2h anom_b");
396
+
397
+ // Re-own so Drop actually frees (we leaked above).
398
+ let _inputs_owned_again = unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) };
399
+ let _cols_owned_again = unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) };
400
+ let _anom_owned_again = unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) };
401
+
402
+ assert_eq!(cols_a_host, cols_b_host, "active-column mask diverges between numpy and CAI paths");
403
+ assert_eq!(anom_a_host.len(), anom_b_host.len());
404
+ for (i, (a, b)) in anom_a_host.iter().zip(anom_b_host.iter()).enumerate() {
405
+ // Anomaly is a pure division of integer counts — bit-exact expected.
406
+ assert!((a - b).abs() < 1e-7, "anomaly mismatch at step {i}: a={a} b={b}");
407
+ }
408
+ }
409
+
410
+ /// Fused kernel: threshold activation should converge to near target sparsity
411
+ /// after a short warmup. Acceptance: mean activation rate per step lands in
412
+ /// [0.3*target, 2.5*target] after 500-step warmup. Because the threshold
413
+ /// starts conservative (=2.0) and the per-column adaptation rate is slow
414
+ /// (0.001), we allow a generous band — the test asserts directional
415
+ /// convergence toward the target, not tight matching.
416
+ #[test]
417
+ fn gpu_threshold_converges_to_sparsity() {
418
+ let cfg = SpatialPoolerConfig::default();
419
+ let bits = cfg.input_bits;
420
+ let n_cols = cfg.n_columns;
421
+ let cells_per_col = 32usize;
422
+ let target = cfg.sparsity; // 0.02 = 40 cols expected
423
+
424
+ let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 111);
425
+ let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
426
+ let dev = sp.dev_ref().clone();
427
+ let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
428
+ let mut fused = FusedState::new(
429
+ dev.clone(),
430
+ n_cols,
431
+ cells_per_col,
432
+ sp.initial_threshold_estimate(),
433
+ ).expect("fused init");
434
+ tm.reset().expect("tm reset");
435
+ fused.reset().expect("fused reset");
436
+
437
+ // Warmup: 1000 random 2%-sparse SDRs.
438
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
439
+ let t_warm = 1000usize;
440
+ let mut inputs = vec![0u8; t_warm * bits];
441
+ for ti in 0..t_warm {
442
+ let sdr = make_sdr(&mut rng, bits, 0.02);
443
+ inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
444
+ }
445
+ let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod");
446
+ let mut cols_dev = dev.alloc_zeros::<u8>(t_warm * n_cols).expect("alloc cols");
447
+ let mut anom_dev = dev.alloc_zeros::<f32>(t_warm).expect("alloc anom");
448
+ launch_fused(
449
+ &mut sp, &mut tm, &mut fused,
450
+ &inputs_dev, &mut cols_dev, &mut anom_dev,
451
+ t_warm, bits, true,
452
+ ).expect("warmup launch");
453
+ dev.synchronize().expect("sync");
454
+
455
+ // Measurement pass: another 200 steps, measure mean activation.
456
+ let t_meas = 200usize;
457
+ let mut meas_inputs = vec![0u8; t_meas * bits];
458
+ for ti in 0..t_meas {
459
+ let sdr = make_sdr(&mut rng, bits, 0.02);
460
+ meas_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
461
+ }
462
+ let meas_dev: CudaSlice<u8> = dev.htod_sync_copy(&meas_inputs).expect("htod meas");
463
+ let mut meas_cols = dev.alloc_zeros::<u8>(t_meas * n_cols).expect("alloc meas cols");
464
+ let mut meas_anom = dev.alloc_zeros::<f32>(t_meas).expect("alloc meas anom");
465
+ launch_fused(
466
+ &mut sp, &mut tm, &mut fused,
467
+ &meas_dev, &mut meas_cols, &mut meas_anom,
468
+ t_meas, bits, true,
469
+ ).expect("meas launch");
470
+ dev.synchronize().expect("sync meas");
471
+
472
+ let cols_host: Vec<u8> = dev.dtoh_sync_copy(&meas_cols).expect("d2h");
473
+ let mut step_counts = Vec::with_capacity(t_meas);
474
+ for ti in 0..t_meas {
475
+ let n_on = cols_host[ti*n_cols..(ti+1)*n_cols]
476
+ .iter().filter(|&&b| b != 0).count();
477
+ step_counts.push(n_on);
478
+ }
479
+ let mean_active: f64 = step_counts.iter().map(|&c| c as f64).sum::<f64>()
480
+ / (t_meas as f64);
481
+ let target_active = target as f64 * n_cols as f64;
482
+ eprintln!(
483
+ "threshold-activation convergence: mean_active/step = {mean_active:.1} \
484
+ (target = {target_active:.1})"
485
+ );
486
+ // Very generous band — we just want to confirm the threshold loop is
487
+ // functioning (not diverged to 0 or to all-active).
488
+ assert!(
489
+ mean_active >= 0.25 * target_active && mean_active <= 4.0 * target_active,
490
+ "mean active {mean_active:.1} outside [0.25x, 4x] of target {target_active:.1}"
491
+ );
492
+ }
493
+
494
+ /// Fused kernel: TM should learn a repeating sequence — anomaly decays.
495
+ #[test]
496
+ fn gpu_fused_tm_anomaly_decays_on_repeating_sequence() {
497
+ let cfg = SpatialPoolerConfig::default();
498
+ let bits = cfg.input_bits;
499
+ let n_cols = cfg.n_columns;
500
+ let cells_per_col = 32usize;
501
+
502
+ let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271);
503
+ let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
504
+ let dev = sp.dev_ref().clone();
505
+ let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
506
+ let mut fused = FusedState::new(
507
+ dev.clone(),
508
+ n_cols,
509
+ cells_per_col,
510
+ sp.initial_threshold_estimate(),
511
+ ).expect("fused init");
512
+ tm.reset().expect("tm reset");
513
+ fused.reset().expect("fused reset");
514
+
515
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
516
+ let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
517
+ let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
518
+
519
+ // Warmup SP threshold calibration with random SDRs first.
520
+ let warm = 300usize;
521
+ let mut warm_inputs = vec![0u8; warm * bits];
522
+ for ti in 0..warm {
523
+ let sdr = make_sdr(&mut rng, bits, 0.02);
524
+ warm_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
525
+ }
526
+ let warm_dev: CudaSlice<u8> = dev.htod_sync_copy(&warm_inputs).expect("htod warm");
527
+ let mut warm_cols = dev.alloc_zeros::<u8>(warm * n_cols).expect("alloc warm cols");
528
+ let mut warm_anom = dev.alloc_zeros::<f32>(warm).expect("alloc warm anom");
529
+ launch_fused(
530
+ &mut sp, &mut tm, &mut fused,
531
+ &warm_dev, &mut warm_cols, &mut warm_anom,
532
+ warm, bits, true,
533
+ ).expect("warm launch");
534
+ dev.synchronize().expect("sync warm");
535
+
536
+ // Feed repeating A,B,C sequence for 100 reps.
537
+ let repeats = 100usize;
538
+ let t = repeats * 3;
539
+ let mut inputs = vec![0u8; t * bits];
540
+ for r in 0..repeats {
541
+ for (i, s) in seqs.iter().enumerate() {
542
+ let off = (r*3 + i) * bits;
543
+ inputs[off..off+bits].copy_from_slice(s);
544
+ }
545
+ }
546
+ let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod rep");
547
+ let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc rep cols");
548
+ let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc rep anom");
549
+ launch_fused(
550
+ &mut sp, &mut tm, &mut fused,
551
+ &inputs_dev, &mut cols_dev, &mut anom_dev,
552
+ t, bits, true,
553
+ ).expect("rep launch");
554
+ dev.synchronize().expect("sync rep");
555
+
556
+ let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
557
+ let early_avg: f32 = anom[3..12].iter().sum::<f32>() / 9.0;
558
+ let late_avg: f32 = anom[(t-9)..t].iter().sum::<f32>() / 9.0;
559
+ eprintln!("fused TM anomaly: early={early_avg:.3} late={late_avg:.3}");
560
+ assert!(
561
+ late_avg < early_avg,
562
+ "anomaly must decay: early={early_avg:.3} late={late_avg:.3}"
563
+ );
564
+ assert!(
565
+ late_avg < 0.5,
566
+ "late anomaly must be < 0.5 (got {late_avg:.3})"
567
+ );
568
+ }
569
+
570
+ #[test]
571
+ fn gpu_sp_yields_k_winners() {
572
+ let cfg = SpatialPoolerConfig::default();
573
+ let bits = cfg.input_bits;
574
+ let n = cfg.n_columns;
575
+ let expected_k = ((cfg.sparsity * n as f32).round() as usize).max(1);
576
+ let cpu = SpatialPooler::new(SpatialPoolerConfig::default(), 7);
577
+ let mut gpu = SpatialPoolerGpu::from_cpu(&cpu).expect("gpu init");
578
+
579
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(1);
580
+ for _ in 0..10 {
581
+ let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
582
+ let active = gpu.compute(&sdr_u8, false).expect("gpu compute");
583
+ assert_eq!(active.len(), expected_k);
584
+ // Ensure sorted + unique.
585
+ for w in active.windows(2) {
586
+ assert!(w[0] < w[1], "duplicate or out-of-order winner indices");
587
+ }
588
+ }
589
+ }
590
+
591
+ #[test]
592
+ fn fused_launch_plan_uses_cooperative_grid_sync() {
593
+ let plan = plan_fused_launch(30, true, 30, None).expect("cooperative supported");
594
+ assert_eq!(plan.grid_dim_x, 16);
595
+ assert_eq!(plan.cooperative_grid_limit, 30);
596
+ }
597
+
598
+ #[test]
599
+ fn fused_launch_plan_scales_to_big_gpu() {
600
+ // H200-like: 132 SMs, high cooperative_grid_limit. Cap still applies.
601
+ let plan = plan_fused_launch(132, true, 1000, None).expect("cooperative supported");
602
+ assert_eq!(plan.grid_dim_x, 16); // capped by default override
603
+ let plan = plan_fused_launch(132, true, 1000, Some(64)).expect("cooperative supported");
604
+ assert_eq!(plan.grid_dim_x, 64); // override raises the cap
605
+ }
606
+
607
+ #[test]
608
+ fn fused_launch_plan_refuses_non_cooperative_devices() {
609
+ // The slow path was removed. Devices without cooperative launch fail fast.
610
+ let err = plan_fused_launch(30, false, 0, None).unwrap_err();
611
+ assert!(err.contains("cooperative launch"));
612
+ }
613
+
614
+ #[test]
615
+ fn fused_grid_cap_env_override_is_honored() {
616
+ let cfg = SpatialPoolerConfig::default();
617
+ let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 5252);
618
+ let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
619
+ let dev = sp.dev_ref().clone();
620
+
621
+ unsafe { std::env::set_var("HTM_FUSED_GRID_CAP", "12"); }
622
+ let fused = FusedState::new(
623
+ dev.clone(),
624
+ cfg.n_columns,
625
+ 32usize,
626
+ sp.initial_threshold_estimate(),
627
+ ).expect("fused init");
628
+ unsafe { std::env::remove_var("HTM_FUSED_GRID_CAP"); }
629
+
630
+ let sm_count = match dev.attribute(
631
+ cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
632
+ ) {
633
+ Ok(v) => v as u32,
634
+ Err(_) => 16u32,
635
+ };
636
+ let expected = sm_count.max(1).min(12);
637
+ assert_eq!(
638
+ fused.grid_dim_x,
639
+ expected,
640
+ "fused grid cap env override ignored: expected min(sm_count, 12) = {expected}, got {}",
641
+ fused.grid_dim_x,
642
+ );
643
+ }
644
+
645
+ #[test]
646
+ fn batched_grid_plan_clamps_a10g_batch32_under_cooperative_limit() {
647
+ // A10G observed in HF Jobs: cooperative_grid_limit=400, B=32.
648
+ // grid_x=16 requests 512 cooperative blocks and fails; clamp to 12.
649
+ let grid_x = plan_batched_grid_dim(16, 400, 32, false).expect("fits after clamp");
650
+ assert_eq!(grid_x, 12);
651
+ }
652
+
653
+ #[test]
654
+ fn batched_grid_plan_reports_oversized_batch() {
655
+ let err = plan_batched_grid_dim(16, 31, 32, false).unwrap_err();
656
+ assert!(err.contains("COOPERATIVE_LAUNCH_TOO_LARGE"));
657
+ }
658
+
659
+ #[test]
660
+ fn batched_grid_plan_does_not_clamp_cluster_launches() {
661
+ let grid_x = plan_batched_grid_dim(16, 31, 32, true).expect("cluster path bypasses cooperative limit");
662
+ assert_eq!(grid_x, 16);
663
+ }
overlay/htm_rust/src/lib.rs CHANGED
@@ -1,198 +1,198 @@
1
- //! pyo3 bindings for HTMRegion (Numenta BAMI-spec HTM).
2
- //!
3
- //! Exposed class:
4
- //! HTMRegion(input_bits, n_columns, cells_per_column, seed) -> HTMRegion
5
- //! .step(input_sdr: np.ndarray[bool; input_bits], learn: bool = True)
6
- //! -> (active_columns: np.ndarray[bool; n_columns],
7
- //! active_cells: np.ndarray[bool; n_columns*cells_per_column],
8
- //! predicted_cells:np.ndarray[bool; n_columns*cells_per_column],
9
- //! anomaly: float)
10
- //! .reset()
11
- //! .n_columns -> int
12
- //! .cells_per_column -> int
13
- //! .input_bits -> int
14
- //!
15
- //! GIL is dropped during the heavy compute via `py.allow_threads(...)` so the
16
- //! region is effectively `Send` for Python-side threading.
17
-
18
- // pyo3 0.22 `#[pymethods]` expansion inserts an implicit `.into()` on the
19
- // returned `Result` to normalise the error type, which clippy reports as
20
- // `useless_conversion` when our methods already return `PyErr`. The emitted
21
- // code sits outside the user-written impl, so item-level allows don't reach
22
- // it; the module-wide allow is the documented workaround.
23
- #![allow(clippy::useless_conversion)]
24
-
25
- mod region;
26
- mod sp;
27
- mod tm;
28
-
29
- #[cfg(feature = "gpu")]
30
- mod gpu;
31
-
32
- use numpy::{
33
- IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2,
34
- PyUntypedArrayMethods,
35
- };
36
- use pyo3::prelude::*;
37
-
38
- use crate::region::HTMRegionCore;
39
-
40
- /// Result of one HTM step: (active_columns, active_cells, predicted_cells, anomaly).
41
- type StepOutput<'py> = (
42
- Bound<'py, PyArray1<bool>>,
43
- Bound<'py, PyArray1<bool>>,
44
- Bound<'py, PyArray1<bool>>,
45
- f32,
46
- );
47
-
48
- #[pyclass(module = "htm_rust")]
49
- pub struct HTMRegion {
50
- core: HTMRegionCore,
51
- }
52
-
53
- #[pymethods]
54
- impl HTMRegion {
55
- /// Create a new HTM region.
56
- ///
57
- /// Args:
58
- /// input_bits: length of binary input SDR
59
- /// n_columns: number of mini-columns in the SP (e.g. 2048)
60
- /// cells_per_column: cells per column in the TM (e.g. 32)
61
- /// seed: RNG seed for reproducibility
62
- #[new]
63
- #[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
64
- fn new(
65
- input_bits: usize,
66
- n_columns: usize,
67
- cells_per_column: usize,
68
- seed: u64,
69
- ) -> PyResult<Self> {
70
- if input_bits == 0 {
71
- return Err(pyo3::exceptions::PyValueError::new_err(
72
- "input_bits must be > 0",
73
- ));
74
- }
75
- if n_columns == 0 {
76
- return Err(pyo3::exceptions::PyValueError::new_err(
77
- "n_columns must be > 0",
78
- ));
79
- }
80
- if cells_per_column == 0 {
81
- return Err(pyo3::exceptions::PyValueError::new_err(
82
- "cells_per_column must be > 0",
83
- ));
84
- }
85
- Ok(Self {
86
- core: HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed),
87
- })
88
- }
89
-
90
- #[getter]
91
- fn input_bits(&self) -> usize { self.core.sp.cfg.input_bits }
92
-
93
- #[getter]
94
- fn n_columns(&self) -> usize { self.core.sp.cfg.n_columns }
95
-
96
- #[getter]
97
- fn cells_per_column(&self) -> usize { self.core.tm.cfg.cells_per_column }
98
-
99
- /// Process one timestep.
100
- ///
101
- /// Args:
102
- /// input_sdr: 1-D numpy boolean array of length `input_bits`.
103
- /// learn: if True, update SP permanences and TM synapses.
104
- ///
105
- /// Returns:
106
- /// (active_columns, active_cells, predicted_cells, anomaly)
107
- #[pyo3(signature = (input_sdr, learn=true))]
108
- fn step<'py>(
109
- &mut self,
110
- py: Python<'py>,
111
- input_sdr: PyReadonlyArray1<'py, bool>,
112
- learn: bool,
113
- ) -> PyResult<StepOutput<'py>> {
114
- let expected = self.core.sp.cfg.input_bits;
115
- let slice = input_sdr.as_slice()?;
116
- let got = slice.len();
117
- if got != expected {
118
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
119
- "input_sdr length {got} != expected input_bits {expected}",
120
- )));
121
- }
122
-
123
- // Copy input to an owned Vec so we can drop the GIL.
124
- let input_vec: Vec<bool> = slice.to_vec();
125
-
126
- let (active_cols, active_cells, predicted_cells, anomaly) =
127
- py.allow_threads(|| self.core.step(&input_vec, learn));
128
-
129
- let a: Bound<'py, PyArray1<bool>> = active_cols.into_pyarray_bound(py);
130
- let c: Bound<'py, PyArray1<bool>> = active_cells.into_pyarray_bound(py);
131
- let p: Bound<'py, PyArray1<bool>> = predicted_cells.into_pyarray_bound(py);
132
- Ok((a, c, p, anomaly))
133
- }
134
-
135
- /// Clear TM predictive state. Does NOT unlearn synapses.
136
- fn reset(&mut self) { self.core.reset(); }
137
-
138
- /// Process T timesteps from a `(T, input_bits)` bool ndarray.
139
- ///
140
- /// Returns:
141
- /// cols: (T, n_columns) float32 0/1 active-column mask
142
- /// anom: (T,) float32 anomaly scores
143
- ///
144
- /// Single GIL release for the whole pass, avoiding T × Python-call overhead.
145
- #[pyo3(signature = (inputs, learn=true))]
146
- fn step_many<'py>(
147
- &mut self,
148
- py: Python<'py>,
149
- inputs: PyReadonlyArray2<'py, bool>,
150
- learn: bool,
151
- ) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
152
- let shape = inputs.shape();
153
- if shape.len() != 2 {
154
- return Err(pyo3::exceptions::PyValueError::new_err(
155
- "inputs must be 2-D (T, input_bits)",
156
- ));
157
- }
158
- let t = shape[0];
159
- let bits = shape[1];
160
- let expected = self.core.sp.cfg.input_bits;
161
- if bits != expected {
162
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
163
- "inputs last dim {bits} != expected input_bits {expected}",
164
- )));
165
- }
166
- let slice = inputs.as_slice()?;
167
- let n_cols = self.core.sp.cfg.n_columns;
168
-
169
- // Own the input buffer so we can drop the GIL.
170
- let input_vec: Vec<bool> = slice.to_vec();
171
-
172
- let (cols_u8, anom) =
173
- py.allow_threads(|| self.core.step_many(&input_vec, bits, t, learn));
174
-
175
- // Convert u8 mask to f32 for direct numpy consumption.
176
- let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
177
-
178
- // Build (T, n_cols) and (T,) arrays.
179
- let cols_arr =
180
- numpy::PyArray1::from_vec_bound(py, cols_f32)
181
- .reshape([t, n_cols])
182
- .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
183
- let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
184
- Ok((cols_arr, anom_arr))
185
- }
186
- }
187
-
188
- /// Python module entry point.
189
- #[pymodule]
190
- fn htm_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
191
- m.add_class::<HTMRegion>()?;
192
- #[cfg(feature = "gpu")]
193
- {
194
- gpu::register(m)?;
195
- }
196
- m.add("__version__", env!("CARGO_PKG_VERSION"))?;
197
- Ok(())
198
- }
 
1
+ //! pyo3 bindings for HTMRegion (Numenta BAMI-spec HTM).
2
+ //!
3
+ //! Exposed class:
4
+ //! HTMRegion(input_bits, n_columns, cells_per_column, seed) -> HTMRegion
5
+ //! .step(input_sdr: np.ndarray[bool; input_bits], learn: bool = True)
6
+ //! -> (active_columns: np.ndarray[bool; n_columns],
7
+ //! active_cells: np.ndarray[bool; n_columns*cells_per_column],
8
+ //! predicted_cells:np.ndarray[bool; n_columns*cells_per_column],
9
+ //! anomaly: float)
10
+ //! .reset()
11
+ //! .n_columns -> int
12
+ //! .cells_per_column -> int
13
+ //! .input_bits -> int
14
+ //!
15
+ //! GIL is dropped during the heavy compute via `py.allow_threads(...)` so the
16
+ //! region is effectively `Send` for Python-side threading.
17
+
18
+ // pyo3 0.22 `#[pymethods]` expansion inserts an implicit `.into()` on the
19
+ // returned `Result` to normalise the error type, which clippy reports as
20
+ // `useless_conversion` when our methods already return `PyErr`. The emitted
21
+ // code sits outside the user-written impl, so item-level allows don't reach
22
+ // it; the module-wide allow is the documented workaround.
23
+ #![allow(clippy::useless_conversion)]
24
+
25
+ mod region;
26
+ mod sp;
27
+ mod tm;
28
+
29
+ #[cfg(feature = "gpu")]
30
+ mod gpu;
31
+
32
+ use numpy::{
33
+ IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2,
34
+ PyUntypedArrayMethods,
35
+ };
36
+ use pyo3::prelude::*;
37
+
38
+ use crate::region::HTMRegionCore;
39
+
40
+ /// Result of one HTM step: (active_columns, active_cells, predicted_cells, anomaly).
41
+ type StepOutput<'py> = (
42
+ Bound<'py, PyArray1<bool>>,
43
+ Bound<'py, PyArray1<bool>>,
44
+ Bound<'py, PyArray1<bool>>,
45
+ f32,
46
+ );
47
+
48
+ #[pyclass(module = "htm_rust")]
49
+ pub struct HTMRegion {
50
+ core: HTMRegionCore,
51
+ }
52
+
53
+ #[pymethods]
54
+ impl HTMRegion {
55
+ /// Create a new HTM region.
56
+ ///
57
+ /// Args:
58
+ /// input_bits: length of binary input SDR
59
+ /// n_columns: number of mini-columns in the SP (e.g. 2048)
60
+ /// cells_per_column: cells per column in the TM (e.g. 32)
61
+ /// seed: RNG seed for reproducibility
62
+ #[new]
63
+ #[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
64
+ fn new(
65
+ input_bits: usize,
66
+ n_columns: usize,
67
+ cells_per_column: usize,
68
+ seed: u64,
69
+ ) -> PyResult<Self> {
70
+ if input_bits == 0 {
71
+ return Err(pyo3::exceptions::PyValueError::new_err(
72
+ "input_bits must be > 0",
73
+ ));
74
+ }
75
+ if n_columns == 0 {
76
+ return Err(pyo3::exceptions::PyValueError::new_err(
77
+ "n_columns must be > 0",
78
+ ));
79
+ }
80
+ if cells_per_column == 0 {
81
+ return Err(pyo3::exceptions::PyValueError::new_err(
82
+ "cells_per_column must be > 0",
83
+ ));
84
+ }
85
+ Ok(Self {
86
+ core: HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed),
87
+ })
88
+ }
89
+
90
+ #[getter]
91
+ fn input_bits(&self) -> usize { self.core.sp.cfg.input_bits }
92
+
93
+ #[getter]
94
+ fn n_columns(&self) -> usize { self.core.sp.cfg.n_columns }
95
+
96
+ #[getter]
97
+ fn cells_per_column(&self) -> usize { self.core.tm.cfg.cells_per_column }
98
+
99
+ /// Process one timestep.
100
+ ///
101
+ /// Args:
102
+ /// input_sdr: 1-D numpy boolean array of length `input_bits`.
103
+ /// learn: if True, update SP permanences and TM synapses.
104
+ ///
105
+ /// Returns:
106
+ /// (active_columns, active_cells, predicted_cells, anomaly)
107
+ #[pyo3(signature = (input_sdr, learn=true))]
108
+ fn step<'py>(
109
+ &mut self,
110
+ py: Python<'py>,
111
+ input_sdr: PyReadonlyArray1<'py, bool>,
112
+ learn: bool,
113
+ ) -> PyResult<StepOutput<'py>> {
114
+ let expected = self.core.sp.cfg.input_bits;
115
+ let slice = input_sdr.as_slice()?;
116
+ let got = slice.len();
117
+ if got != expected {
118
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
119
+ "input_sdr length {got} != expected input_bits {expected}",
120
+ )));
121
+ }
122
+
123
+ // Copy input to an owned Vec so we can drop the GIL.
124
+ let input_vec: Vec<bool> = slice.to_vec();
125
+
126
+ let (active_cols, active_cells, predicted_cells, anomaly) =
127
+ py.allow_threads(|| self.core.step(&input_vec, learn));
128
+
129
+ let a: Bound<'py, PyArray1<bool>> = active_cols.into_pyarray_bound(py);
130
+ let c: Bound<'py, PyArray1<bool>> = active_cells.into_pyarray_bound(py);
131
+ let p: Bound<'py, PyArray1<bool>> = predicted_cells.into_pyarray_bound(py);
132
+ Ok((a, c, p, anomaly))
133
+ }
134
+
135
+ /// Clear TM predictive state. Does NOT unlearn synapses.
136
+ fn reset(&mut self) { self.core.reset(); }
137
+
138
+ /// Process T timesteps from a `(T, input_bits)` bool ndarray.
139
+ ///
140
+ /// Returns:
141
+ /// cols: (T, n_columns) float32 0/1 active-column mask
142
+ /// anom: (T,) float32 anomaly scores
143
+ ///
144
+ /// Single GIL release for the whole pass, avoiding T × Python-call overhead.
145
+ #[pyo3(signature = (inputs, learn=true))]
146
+ fn step_many<'py>(
147
+ &mut self,
148
+ py: Python<'py>,
149
+ inputs: PyReadonlyArray2<'py, bool>,
150
+ learn: bool,
151
+ ) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
152
+ let shape = inputs.shape();
153
+ if shape.len() != 2 {
154
+ return Err(pyo3::exceptions::PyValueError::new_err(
155
+ "inputs must be 2-D (T, input_bits)",
156
+ ));
157
+ }
158
+ let t = shape[0];
159
+ let bits = shape[1];
160
+ let expected = self.core.sp.cfg.input_bits;
161
+ if bits != expected {
162
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
163
+ "inputs last dim {bits} != expected input_bits {expected}",
164
+ )));
165
+ }
166
+ let slice = inputs.as_slice()?;
167
+ let n_cols = self.core.sp.cfg.n_columns;
168
+
169
+ // Own the input buffer so we can drop the GIL.
170
+ let input_vec: Vec<bool> = slice.to_vec();
171
+
172
+ let (cols_u8, anom) =
173
+ py.allow_threads(|| self.core.step_many(&input_vec, bits, t, learn));
174
+
175
+ // Convert u8 mask to f32 for direct numpy consumption.
176
+ let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
177
+
178
+ // Build (T, n_cols) and (T,) arrays.
179
+ let cols_arr =
180
+ numpy::PyArray1::from_vec_bound(py, cols_f32)
181
+ .reshape([t, n_cols])
182
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
183
+ let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
184
+ Ok((cols_arr, anom_arr))
185
+ }
186
+ }
187
+
188
+ /// Python module entry point.
189
+ #[pymodule]
190
+ fn htm_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
191
+ m.add_class::<HTMRegion>()?;
192
+ #[cfg(feature = "gpu")]
193
+ {
194
+ gpu::register(m)?;
195
+ }
196
+ m.add("__version__", env!("CARGO_PKG_VERSION"))?;
197
+ Ok(())
198
+ }
overlay/htm_rust/src/region.rs CHANGED
@@ -1,94 +1,94 @@
1
- //! HTMRegion: compose SpatialPooler + TemporalMemory into a single step().
2
-
3
- use crate::sp::{SpatialPooler, SpatialPoolerConfig};
4
- use crate::tm::{TemporalMemory, TemporalMemoryConfig};
5
-
6
- pub struct HTMRegionCore {
7
- pub sp: SpatialPooler,
8
- pub tm: TemporalMemory,
9
- }
10
-
11
- impl HTMRegionCore {
12
- pub fn new(
13
- input_bits: usize,
14
- n_columns: usize,
15
- cells_per_column: usize,
16
- seed: u64,
17
- ) -> Self {
18
- let defaults = SpatialPoolerConfig::default();
19
- let sp_cfg = SpatialPoolerConfig {
20
- input_bits,
21
- n_columns,
22
- // Scale potential_radius to at most the input size.
23
- potential_radius: defaults.potential_radius.min(input_bits),
24
- ..defaults
25
- };
26
-
27
- let tm_cfg = TemporalMemoryConfig {
28
- n_columns,
29
- cells_per_column,
30
- ..TemporalMemoryConfig::default()
31
- };
32
-
33
- Self {
34
- sp: SpatialPooler::new(sp_cfg, seed),
35
- tm: TemporalMemory::new(tm_cfg, seed.wrapping_add(0x9E3779B97F4A7C15)),
36
- }
37
- }
38
-
39
- /// Process one timestep. Returns (active_columns_mask,
40
- /// active_cells_mask, predicted_cells_mask, anomaly).
41
- pub fn step(
42
- &mut self,
43
- input_sdr: &[bool],
44
- learn: bool,
45
- ) -> (Vec<bool>, Vec<bool>, Vec<bool>, f32) {
46
- let active_cols = self.sp.compute(input_sdr, learn);
47
-
48
- let mut active_cols_mask = vec![false; self.sp.cfg.n_columns];
49
- for &c in &active_cols {
50
- active_cols_mask[c as usize] = true;
51
- }
52
-
53
- let anomaly = self.tm.compute(&active_cols, learn);
54
-
55
- // active_cells and predictive_cells are stored as Vec<bool> already.
56
- let active_cells_mask = self.tm.active_cells.clone();
57
- let predicted_cells_mask = self.tm.predictive_cells.clone();
58
-
59
- (active_cols_mask, active_cells_mask, predicted_cells_mask, anomaly)
60
- }
61
-
62
- pub fn reset(&mut self) {
63
- self.tm.reset();
64
- }
65
-
66
- /// Process T timesteps in one call. Returns flat `(T*n_columns)` active-column
67
- /// mask (u8 0/1) and `(T,)` anomaly scores.
68
- ///
69
- /// Amortises the per-step Python round-trip for training: one GIL release,
70
- /// one copy-out. Used by `HTMLayer.step_many`.
71
- pub fn step_many(
72
- &mut self,
73
- inputs_flat: &[bool],
74
- input_bits: usize,
75
- t: usize,
76
- learn: bool,
77
- ) -> (Vec<u8>, Vec<f32>) {
78
- let n_cols = self.sp.cfg.n_columns;
79
- debug_assert_eq!(inputs_flat.len(), t * input_bits);
80
- let mut cols = vec![0u8; t * n_cols];
81
- let mut anom = vec![0f32; t];
82
- for ti in 0..t {
83
- let off = ti * input_bits;
84
- let input = &inputs_flat[off..off + input_bits];
85
- let active_cols = self.sp.compute(input, learn);
86
- let co = ti * n_cols;
87
- for &c in &active_cols {
88
- cols[co + c as usize] = 1;
89
- }
90
- anom[ti] = self.tm.compute(&active_cols, learn);
91
- }
92
- (cols, anom)
93
- }
94
- }
 
1
+ //! HTMRegion: compose SpatialPooler + TemporalMemory into a single step().
2
+
3
+ use crate::sp::{SpatialPooler, SpatialPoolerConfig};
4
+ use crate::tm::{TemporalMemory, TemporalMemoryConfig};
5
+
6
+ pub struct HTMRegionCore {
7
+ pub sp: SpatialPooler,
8
+ pub tm: TemporalMemory,
9
+ }
10
+
11
+ impl HTMRegionCore {
12
+ pub fn new(
13
+ input_bits: usize,
14
+ n_columns: usize,
15
+ cells_per_column: usize,
16
+ seed: u64,
17
+ ) -> Self {
18
+ let defaults = SpatialPoolerConfig::default();
19
+ let sp_cfg = SpatialPoolerConfig {
20
+ input_bits,
21
+ n_columns,
22
+ // Scale potential_radius to at most the input size.
23
+ potential_radius: defaults.potential_radius.min(input_bits),
24
+ ..defaults
25
+ };
26
+
27
+ let tm_cfg = TemporalMemoryConfig {
28
+ n_columns,
29
+ cells_per_column,
30
+ ..TemporalMemoryConfig::default()
31
+ };
32
+
33
+ Self {
34
+ sp: SpatialPooler::new(sp_cfg, seed),
35
+ tm: TemporalMemory::new(tm_cfg, seed.wrapping_add(0x9E3779B97F4A7C15)),
36
+ }
37
+ }
38
+
39
+ /// Process one timestep. Returns (active_columns_mask,
40
+ /// active_cells_mask, predicted_cells_mask, anomaly).
41
+ pub fn step(
42
+ &mut self,
43
+ input_sdr: &[bool],
44
+ learn: bool,
45
+ ) -> (Vec<bool>, Vec<bool>, Vec<bool>, f32) {
46
+ let active_cols = self.sp.compute(input_sdr, learn);
47
+
48
+ let mut active_cols_mask = vec![false; self.sp.cfg.n_columns];
49
+ for &c in &active_cols {
50
+ active_cols_mask[c as usize] = true;
51
+ }
52
+
53
+ let anomaly = self.tm.compute(&active_cols, learn);
54
+
55
+ // active_cells and predictive_cells are stored as Vec<bool> already.
56
+ let active_cells_mask = self.tm.active_cells.clone();
57
+ let predicted_cells_mask = self.tm.predictive_cells.clone();
58
+
59
+ (active_cols_mask, active_cells_mask, predicted_cells_mask, anomaly)
60
+ }
61
+
62
+ pub fn reset(&mut self) {
63
+ self.tm.reset();
64
+ }
65
+
66
+ /// Process T timesteps in one call. Returns flat `(T*n_columns)` active-column
67
+ /// mask (u8 0/1) and `(T,)` anomaly scores.
68
+ ///
69
+ /// Amortises the per-step Python round-trip for training: one GIL release,
70
+ /// one copy-out. Used by `HTMLayer.step_many`.
71
+ pub fn step_many(
72
+ &mut self,
73
+ inputs_flat: &[bool],
74
+ input_bits: usize,
75
+ t: usize,
76
+ learn: bool,
77
+ ) -> (Vec<u8>, Vec<f32>) {
78
+ let n_cols = self.sp.cfg.n_columns;
79
+ debug_assert_eq!(inputs_flat.len(), t * input_bits);
80
+ let mut cols = vec![0u8; t * n_cols];
81
+ let mut anom = vec![0f32; t];
82
+ for ti in 0..t {
83
+ let off = ti * input_bits;
84
+ let input = &inputs_flat[off..off + input_bits];
85
+ let active_cols = self.sp.compute(input, learn);
86
+ let co = ti * n_cols;
87
+ for &c in &active_cols {
88
+ cols[co + c as usize] = 1;
89
+ }
90
+ anom[ti] = self.tm.compute(&active_cols, learn);
91
+ }
92
+ (cols, anom)
93
+ }
94
+ }
overlay/htm_rust/src/sp.rs CHANGED
@@ -1,302 +1,302 @@
1
- //! Numenta BAMI-spec Spatial Pooler.
2
- //!
3
- //! Implements:
4
- //! - 2048 (configurable) mini-columns with proximal dendrites
5
- //! - `potential_synapses` (default 40) synapses per column sampled from
6
- //! `potential_radius` (default 1024) random input bits
7
- //! - Permanence in [0.0, 1.0] (f32), connected_threshold = 0.5
8
- //! - syn_perm_active_inc = +0.04, syn_perm_inactive_dec = -0.008
9
- //! - Global k-WTA inhibition (top `sparsity` fraction of columns)
10
- //! - Boost factor with exponential duty-cycle tracking (Numenta formula)
11
- //!
12
- //! Reference: BAMI "Spatial Pooling Algorithm Details" (Numenta, 2017).
13
-
14
- use rand::Rng;
15
- use rand::SeedableRng;
16
- use rand::seq::SliceRandom;
17
- use rand_xoshiro::Xoshiro256PlusPlus;
18
-
19
- /// A single proximal dendrite: a sparse set of potential synapses onto
20
- /// specific input bit indices, with per-synapse permanence values.
21
- #[derive(Clone)]
22
- pub struct ProximalDendrite {
23
- /// Indices into the input SDR. Length == potential_synapses.
24
- pub inputs: Vec<u32>,
25
- /// Permanence for each potential synapse (same length as `inputs`).
26
- pub perms: Vec<f32>,
27
- }
28
-
29
- pub struct SpatialPoolerConfig {
30
- pub input_bits: usize,
31
- pub n_columns: usize,
32
- /// Size of the random input sample per column.
33
- pub potential_radius: usize,
34
- /// Number of potential synapses per column's proximal dendrite.
35
- pub potential_synapses: usize,
36
- pub connected_threshold: f32,
37
- pub syn_perm_active_inc: f32,
38
- pub syn_perm_inactive_dec: f32,
39
- /// Target fraction of columns active per step (e.g. 0.02 for 2%).
40
- pub sparsity: f32,
41
- /// Duty cycle EMA period.
42
- pub duty_cycle_period: f32,
43
- /// Boost strength. Set to 0.0 to disable boosting.
44
- pub boost_strength: f32,
45
- /// Initial permanence span around the connected threshold.
46
- pub init_perm_span: f32,
47
- }
48
-
49
- impl Default for SpatialPoolerConfig {
50
- fn default() -> Self {
51
- Self {
52
- input_bits: 16384,
53
- n_columns: 2048,
54
- potential_radius: 1024,
55
- potential_synapses: 40,
56
- connected_threshold: 0.5,
57
- syn_perm_active_inc: 0.04,
58
- syn_perm_inactive_dec: 0.008,
59
- sparsity: 0.02,
60
- duty_cycle_period: 1000.0,
61
- boost_strength: 1.0,
62
- init_perm_span: 0.1,
63
- }
64
- }
65
- }
66
-
67
- pub struct SpatialPooler {
68
- pub cfg: SpatialPoolerConfig,
69
- pub columns: Vec<ProximalDendrite>,
70
- /// Exponential moving average of "column was active" per step.
71
- pub active_duty_cycle: Vec<f32>,
72
- /// Exponential moving average of "overlap exceeded threshold" per step.
73
- pub overlap_duty_cycle: Vec<f32>,
74
- /// Boost factor per column.
75
- pub boost: Vec<f32>,
76
- rng: Xoshiro256PlusPlus,
77
- iter_count: u64,
78
- }
79
-
80
- impl SpatialPooler {
81
- pub fn new(cfg: SpatialPoolerConfig, seed: u64) -> Self {
82
- assert!(cfg.input_bits >= cfg.potential_radius,
83
- "input_bits ({}) must be >= potential_radius ({})",
84
- cfg.input_bits, cfg.potential_radius);
85
- assert!(cfg.potential_radius >= cfg.potential_synapses,
86
- "potential_radius ({}) must be >= potential_synapses ({})",
87
- cfg.potential_radius, cfg.potential_synapses);
88
-
89
- let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
90
-
91
- let mut columns = Vec::with_capacity(cfg.n_columns);
92
- for _ in 0..cfg.n_columns {
93
- // Sample `potential_radius` distinct input indices, then from those
94
- // pick `potential_synapses` as the actual proximal synapses.
95
- // Using partial Fisher-Yates via shuffle on a pool index range.
96
- let mut pool: Vec<u32> = (0..cfg.input_bits as u32).collect();
97
- // Efficient partial shuffle: swap the first `potential_radius`
98
- // items with random items from the rest (Durstenfeld step).
99
- for i in 0..cfg.potential_radius.min(pool.len()) {
100
- let j = rng.gen_range(i..pool.len());
101
- pool.swap(i, j);
102
- }
103
- let window = &mut pool[..cfg.potential_radius];
104
- window.shuffle(&mut rng);
105
- let mut inputs: Vec<u32> = window[..cfg.potential_synapses].to_vec();
106
- inputs.sort_unstable();
107
-
108
- let perms: Vec<f32> = (0..cfg.potential_synapses)
109
- .map(|_| {
110
- let delta: f32 = rng.gen_range(-cfg.init_perm_span..cfg.init_perm_span);
111
- (cfg.connected_threshold + delta).clamp(0.0, 1.0)
112
- })
113
- .collect();
114
-
115
- columns.push(ProximalDendrite { inputs, perms });
116
- }
117
-
118
- let n = cfg.n_columns;
119
- Self {
120
- cfg,
121
- columns,
122
- active_duty_cycle: vec![0.0; n],
123
- overlap_duty_cycle: vec![0.0; n],
124
- boost: vec![1.0; n],
125
- rng,
126
- iter_count: 0,
127
- }
128
- }
129
-
130
- /// Process one step: compute overlaps, inhibit, learn (if `learn`), update
131
- /// duty cycles and boosts. Returns the set of active column indices.
132
- pub fn compute(&mut self, input: &[bool], learn: bool) -> Vec<u32> {
133
- assert_eq!(input.len(), self.cfg.input_bits);
134
-
135
- // 1) Overlap score per column (sum of CONNECTED synapses onto active inputs).
136
- // Also track raw overlap for the overlap-duty-cycle.
137
- let n = self.cfg.n_columns;
138
- let mut overlaps: Vec<f32> = vec![0.0; n];
139
- let mut raw_overlaps: Vec<u32> = vec![0; n];
140
-
141
- for (ci, col) in self.columns.iter().enumerate() {
142
- let mut s: u32 = 0;
143
- for (syn_i, &inp) in col.inputs.iter().enumerate() {
144
- if input[inp as usize] && col.perms[syn_i] >= self.cfg.connected_threshold {
145
- s += 1;
146
- }
147
- }
148
- raw_overlaps[ci] = s;
149
- overlaps[ci] = (s as f32) * self.boost[ci];
150
- }
151
-
152
- // 2) Global k-WTA inhibition. Select top-k columns by boosted overlap.
153
- let k = ((self.cfg.sparsity * n as f32).round() as usize).max(1);
154
- let active: Vec<u32> = top_k(&overlaps, k);
155
-
156
- // 3) Hebbian learning on active columns.
157
- if learn {
158
- for &ci in &active {
159
- let col = &mut self.columns[ci as usize];
160
- for (syn_i, &inp) in col.inputs.iter().enumerate() {
161
- if input[inp as usize] {
162
- col.perms[syn_i] =
163
- (col.perms[syn_i] + self.cfg.syn_perm_active_inc).min(1.0);
164
- } else {
165
- col.perms[syn_i] =
166
- (col.perms[syn_i] - self.cfg.syn_perm_inactive_dec).max(0.0);
167
- }
168
- }
169
- }
170
- }
171
-
172
- // 4) Update duty cycles (EMA with period T -> alpha = 1/T).
173
- let period = self.cfg.duty_cycle_period.max(1.0);
174
- let alpha = 1.0 / period;
175
- // Column is "overlapping enough" if raw overlap >= stimulus_threshold.
176
- // Numenta uses min_overlap; we use 1 as a conservative floor.
177
- let stimulus_threshold = 1.0_f32;
178
-
179
- // Mark active columns.
180
- let mut active_mask = vec![false; n];
181
- for &ci in &active {
182
- active_mask[ci as usize] = true;
183
- }
184
-
185
- for i in 0..n {
186
- let active_sample = if active_mask[i] { 1.0 } else { 0.0 };
187
- let overlap_sample = if (raw_overlaps[i] as f32) >= stimulus_threshold {
188
- 1.0
189
- } else {
190
- 0.0
191
- };
192
- self.active_duty_cycle[i] =
193
- (1.0 - alpha) * self.active_duty_cycle[i] + alpha * active_sample;
194
- self.overlap_duty_cycle[i] =
195
- (1.0 - alpha) * self.overlap_duty_cycle[i] + alpha * overlap_sample;
196
- }
197
-
198
- // 5) Boost factor: b_i = exp(-boost_strength * (duty_i - mean_duty)).
199
- // Under-used columns (duty < mean) get boost > 1.
200
- if learn && self.cfg.boost_strength > 0.0 {
201
- let mean_duty: f32 =
202
- self.active_duty_cycle.iter().sum::<f32>() / (n as f32);
203
- for i in 0..n {
204
- self.boost[i] =
205
- (-self.cfg.boost_strength * (self.active_duty_cycle[i] - mean_duty)).exp();
206
- }
207
-
208
- // 6) Permanence bump for chronically under-stimulated columns.
209
- // If overlap_duty_cycle[i] < min_pct_overlap * max_duty_in_neighborhood,
210
- // bump all permanences by syn_perm_active_inc * 0.1.
211
- // With global inhibition, "neighborhood" = all columns.
212
- let max_overlap_duty = self
213
- .overlap_duty_cycle
214
- .iter()
215
- .cloned()
216
- .fold(0.0_f32, f32::max);
217
- let min_pct_overlap_duty = 0.001_f32 * max_overlap_duty;
218
- if max_overlap_duty > 0.0 {
219
- for i in 0..n {
220
- if self.overlap_duty_cycle[i] < min_pct_overlap_duty {
221
- for p in &mut self.columns[i].perms {
222
- *p = (*p + self.cfg.syn_perm_active_inc * 0.1).min(1.0);
223
- }
224
- }
225
- }
226
- }
227
- }
228
-
229
- self.iter_count = self.iter_count.wrapping_add(1);
230
- let _ = &mut self.rng; // suppress unused-mut when learn=false
231
- active
232
- }
233
- }
234
-
235
- /// Return the indices of the top-k values in `scores`.
236
- /// Ties broken by index order. Output is sorted ascending.
237
- fn top_k(scores: &[f32], k: usize) -> Vec<u32> {
238
- if k == 0 {
239
- return Vec::new();
240
- }
241
- let mut idx: Vec<u32> = (0..scores.len() as u32).collect();
242
- // Partial sort: put top-k at the front by descending score.
243
- // Use select_nth_unstable_by on (desc score, asc index).
244
- idx.select_nth_unstable_by(k - 1, |&a, &b| {
245
- let sa = scores[a as usize];
246
- let sb = scores[b as usize];
247
- // Reverse for descending.
248
- match sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal) {
249
- std::cmp::Ordering::Equal => a.cmp(&b),
250
- ord => ord,
251
- }
252
- });
253
- let mut winners: Vec<u32> = idx[..k].to_vec();
254
- winners.sort_unstable();
255
- winners
256
- }
257
-
258
- // ---------------------------------------------------------------------------
259
- // Tests
260
- // ---------------------------------------------------------------------------
261
-
262
- #[cfg(test)]
263
- mod tests {
264
- use super::*;
265
- use rand::Rng;
266
- use rand::SeedableRng;
267
- use rand_xoshiro::Xoshiro256PlusPlus;
268
-
269
- #[test]
270
- fn sp_sparsity_exact_2pct() {
271
- // BAMI says "top ~2%"; with 2048 columns that's round(0.02*2048) = 41.
272
- // The SP must produce *exactly* that count, no more, no less, and with
273
- // no duplicate indices.
274
- let cfg = SpatialPoolerConfig::default();
275
- let expected_k = (cfg.sparsity * cfg.n_columns as f32).round() as usize;
276
- assert!(expected_k > 0);
277
-
278
- let input_bits = cfg.input_bits;
279
- let mut sp = SpatialPooler::new(cfg, 42);
280
- let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
281
-
282
- for _ in 0..100 {
283
- // 2% sparse random input SDR.
284
- let on_bits = (0.02 * input_bits as f32) as usize;
285
- let mut sdr = vec![false; input_bits];
286
- for _ in 0..on_bits {
287
- let i = rng.gen_range(0..input_bits);
288
- sdr[i] = true;
289
- }
290
- let active = sp.compute(&sdr, true);
291
- assert_eq!(
292
- active.len(),
293
- expected_k,
294
- "SP must emit exactly {expected_k} active columns"
295
- );
296
- let mut a = active.clone();
297
- a.sort_unstable();
298
- a.dedup();
299
- assert_eq!(a.len(), expected_k);
300
- }
301
- }
302
- }
 
1
+ //! Numenta BAMI-spec Spatial Pooler.
2
+ //!
3
+ //! Implements:
4
+ //! - 2048 (configurable) mini-columns with proximal dendrites
5
+ //! - `potential_synapses` (default 40) synapses per column sampled from
6
+ //! `potential_radius` (default 1024) random input bits
7
+ //! - Permanence in [0.0, 1.0] (f32), connected_threshold = 0.5
8
+ //! - syn_perm_active_inc = +0.04, syn_perm_inactive_dec = -0.008
9
+ //! - Global k-WTA inhibition (top `sparsity` fraction of columns)
10
+ //! - Boost factor with exponential duty-cycle tracking (Numenta formula)
11
+ //!
12
+ //! Reference: BAMI "Spatial Pooling Algorithm Details" (Numenta, 2017).
13
+
14
+ use rand::Rng;
15
+ use rand::SeedableRng;
16
+ use rand::seq::SliceRandom;
17
+ use rand_xoshiro::Xoshiro256PlusPlus;
18
+
19
+ /// A single proximal dendrite: a sparse set of potential synapses onto
20
+ /// specific input bit indices, with per-synapse permanence values.
21
+ #[derive(Clone)]
22
+ pub struct ProximalDendrite {
23
+ /// Indices into the input SDR. Length == potential_synapses.
24
+ pub inputs: Vec<u32>,
25
+ /// Permanence for each potential synapse (same length as `inputs`).
26
+ pub perms: Vec<f32>,
27
+ }
28
+
29
+ pub struct SpatialPoolerConfig {
30
+ pub input_bits: usize,
31
+ pub n_columns: usize,
32
+ /// Size of the random input sample per column.
33
+ pub potential_radius: usize,
34
+ /// Number of potential synapses per column's proximal dendrite.
35
+ pub potential_synapses: usize,
36
+ pub connected_threshold: f32,
37
+ pub syn_perm_active_inc: f32,
38
+ pub syn_perm_inactive_dec: f32,
39
+ /// Target fraction of columns active per step (e.g. 0.02 for 2%).
40
+ pub sparsity: f32,
41
+ /// Duty cycle EMA period.
42
+ pub duty_cycle_period: f32,
43
+ /// Boost strength. Set to 0.0 to disable boosting.
44
+ pub boost_strength: f32,
45
+ /// Initial permanence span around the connected threshold.
46
+ pub init_perm_span: f32,
47
+ }
48
+
49
+ impl Default for SpatialPoolerConfig {
50
+ fn default() -> Self {
51
+ Self {
52
+ input_bits: 16384,
53
+ n_columns: 2048,
54
+ potential_radius: 1024,
55
+ potential_synapses: 40,
56
+ connected_threshold: 0.5,
57
+ syn_perm_active_inc: 0.04,
58
+ syn_perm_inactive_dec: 0.008,
59
+ sparsity: 0.02,
60
+ duty_cycle_period: 1000.0,
61
+ boost_strength: 1.0,
62
+ init_perm_span: 0.1,
63
+ }
64
+ }
65
+ }
66
+
67
+ pub struct SpatialPooler {
68
+ pub cfg: SpatialPoolerConfig,
69
+ pub columns: Vec<ProximalDendrite>,
70
+ /// Exponential moving average of "column was active" per step.
71
+ pub active_duty_cycle: Vec<f32>,
72
+ /// Exponential moving average of "overlap exceeded threshold" per step.
73
+ pub overlap_duty_cycle: Vec<f32>,
74
+ /// Boost factor per column.
75
+ pub boost: Vec<f32>,
76
+ rng: Xoshiro256PlusPlus,
77
+ iter_count: u64,
78
+ }
79
+
80
+ impl SpatialPooler {
81
+ pub fn new(cfg: SpatialPoolerConfig, seed: u64) -> Self {
82
+ assert!(cfg.input_bits >= cfg.potential_radius,
83
+ "input_bits ({}) must be >= potential_radius ({})",
84
+ cfg.input_bits, cfg.potential_radius);
85
+ assert!(cfg.potential_radius >= cfg.potential_synapses,
86
+ "potential_radius ({}) must be >= potential_synapses ({})",
87
+ cfg.potential_radius, cfg.potential_synapses);
88
+
89
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
90
+
91
+ let mut columns = Vec::with_capacity(cfg.n_columns);
92
+ for _ in 0..cfg.n_columns {
93
+ // Sample `potential_radius` distinct input indices, then from those
94
+ // pick `potential_synapses` as the actual proximal synapses.
95
+ // Using partial Fisher-Yates via shuffle on a pool index range.
96
+ let mut pool: Vec<u32> = (0..cfg.input_bits as u32).collect();
97
+ // Efficient partial shuffle: swap the first `potential_radius`
98
+ // items with random items from the rest (Durstenfeld step).
99
+ for i in 0..cfg.potential_radius.min(pool.len()) {
100
+ let j = rng.gen_range(i..pool.len());
101
+ pool.swap(i, j);
102
+ }
103
+ let window = &mut pool[..cfg.potential_radius];
104
+ window.shuffle(&mut rng);
105
+ let mut inputs: Vec<u32> = window[..cfg.potential_synapses].to_vec();
106
+ inputs.sort_unstable();
107
+
108
+ let perms: Vec<f32> = (0..cfg.potential_synapses)
109
+ .map(|_| {
110
+ let delta: f32 = rng.gen_range(-cfg.init_perm_span..cfg.init_perm_span);
111
+ (cfg.connected_threshold + delta).clamp(0.0, 1.0)
112
+ })
113
+ .collect();
114
+
115
+ columns.push(ProximalDendrite { inputs, perms });
116
+ }
117
+
118
+ let n = cfg.n_columns;
119
+ Self {
120
+ cfg,
121
+ columns,
122
+ active_duty_cycle: vec![0.0; n],
123
+ overlap_duty_cycle: vec![0.0; n],
124
+ boost: vec![1.0; n],
125
+ rng,
126
+ iter_count: 0,
127
+ }
128
+ }
129
+
130
+ /// Process one step: compute overlaps, inhibit, learn (if `learn`), update
131
+ /// duty cycles and boosts. Returns the set of active column indices.
132
+ pub fn compute(&mut self, input: &[bool], learn: bool) -> Vec<u32> {
133
+ assert_eq!(input.len(), self.cfg.input_bits);
134
+
135
+ // 1) Overlap score per column (sum of CONNECTED synapses onto active inputs).
136
+ // Also track raw overlap for the overlap-duty-cycle.
137
+ let n = self.cfg.n_columns;
138
+ let mut overlaps: Vec<f32> = vec![0.0; n];
139
+ let mut raw_overlaps: Vec<u32> = vec![0; n];
140
+
141
+ for (ci, col) in self.columns.iter().enumerate() {
142
+ let mut s: u32 = 0;
143
+ for (syn_i, &inp) in col.inputs.iter().enumerate() {
144
+ if input[inp as usize] && col.perms[syn_i] >= self.cfg.connected_threshold {
145
+ s += 1;
146
+ }
147
+ }
148
+ raw_overlaps[ci] = s;
149
+ overlaps[ci] = (s as f32) * self.boost[ci];
150
+ }
151
+
152
+ // 2) Global k-WTA inhibition. Select top-k columns by boosted overlap.
153
+ let k = ((self.cfg.sparsity * n as f32).round() as usize).max(1);
154
+ let active: Vec<u32> = top_k(&overlaps, k);
155
+
156
+ // 3) Hebbian learning on active columns.
157
+ if learn {
158
+ for &ci in &active {
159
+ let col = &mut self.columns[ci as usize];
160
+ for (syn_i, &inp) in col.inputs.iter().enumerate() {
161
+ if input[inp as usize] {
162
+ col.perms[syn_i] =
163
+ (col.perms[syn_i] + self.cfg.syn_perm_active_inc).min(1.0);
164
+ } else {
165
+ col.perms[syn_i] =
166
+ (col.perms[syn_i] - self.cfg.syn_perm_inactive_dec).max(0.0);
167
+ }
168
+ }
169
+ }
170
+ }
171
+
172
+ // 4) Update duty cycles (EMA with period T -> alpha = 1/T).
173
+ let period = self.cfg.duty_cycle_period.max(1.0);
174
+ let alpha = 1.0 / period;
175
+ // Column is "overlapping enough" if raw overlap >= stimulus_threshold.
176
+ // Numenta uses min_overlap; we use 1 as a conservative floor.
177
+ let stimulus_threshold = 1.0_f32;
178
+
179
+ // Mark active columns.
180
+ let mut active_mask = vec![false; n];
181
+ for &ci in &active {
182
+ active_mask[ci as usize] = true;
183
+ }
184
+
185
+ for i in 0..n {
186
+ let active_sample = if active_mask[i] { 1.0 } else { 0.0 };
187
+ let overlap_sample = if (raw_overlaps[i] as f32) >= stimulus_threshold {
188
+ 1.0
189
+ } else {
190
+ 0.0
191
+ };
192
+ self.active_duty_cycle[i] =
193
+ (1.0 - alpha) * self.active_duty_cycle[i] + alpha * active_sample;
194
+ self.overlap_duty_cycle[i] =
195
+ (1.0 - alpha) * self.overlap_duty_cycle[i] + alpha * overlap_sample;
196
+ }
197
+
198
+ // 5) Boost factor: b_i = exp(-boost_strength * (duty_i - mean_duty)).
199
+ // Under-used columns (duty < mean) get boost > 1.
200
+ if learn && self.cfg.boost_strength > 0.0 {
201
+ let mean_duty: f32 =
202
+ self.active_duty_cycle.iter().sum::<f32>() / (n as f32);
203
+ for i in 0..n {
204
+ self.boost[i] =
205
+ (-self.cfg.boost_strength * (self.active_duty_cycle[i] - mean_duty)).exp();
206
+ }
207
+
208
+ // 6) Permanence bump for chronically under-stimulated columns.
209
+ // If overlap_duty_cycle[i] < min_pct_overlap * max_duty_in_neighborhood,
210
+ // bump all permanences by syn_perm_active_inc * 0.1.
211
+ // With global inhibition, "neighborhood" = all columns.
212
+ let max_overlap_duty = self
213
+ .overlap_duty_cycle
214
+ .iter()
215
+ .cloned()
216
+ .fold(0.0_f32, f32::max);
217
+ let min_pct_overlap_duty = 0.001_f32 * max_overlap_duty;
218
+ if max_overlap_duty > 0.0 {
219
+ for i in 0..n {
220
+ if self.overlap_duty_cycle[i] < min_pct_overlap_duty {
221
+ for p in &mut self.columns[i].perms {
222
+ *p = (*p + self.cfg.syn_perm_active_inc * 0.1).min(1.0);
223
+ }
224
+ }
225
+ }
226
+ }
227
+ }
228
+
229
+ self.iter_count = self.iter_count.wrapping_add(1);
230
+ let _ = &mut self.rng; // suppress unused-mut when learn=false
231
+ active
232
+ }
233
+ }
234
+
235
+ /// Return the indices of the top-k values in `scores`.
236
+ /// Ties broken by index order. Output is sorted ascending.
237
+ fn top_k(scores: &[f32], k: usize) -> Vec<u32> {
238
+ if k == 0 {
239
+ return Vec::new();
240
+ }
241
+ let mut idx: Vec<u32> = (0..scores.len() as u32).collect();
242
+ // Partial sort: put top-k at the front by descending score.
243
+ // Use select_nth_unstable_by on (desc score, asc index).
244
+ idx.select_nth_unstable_by(k - 1, |&a, &b| {
245
+ let sa = scores[a as usize];
246
+ let sb = scores[b as usize];
247
+ // Reverse for descending.
248
+ match sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal) {
249
+ std::cmp::Ordering::Equal => a.cmp(&b),
250
+ ord => ord,
251
+ }
252
+ });
253
+ let mut winners: Vec<u32> = idx[..k].to_vec();
254
+ winners.sort_unstable();
255
+ winners
256
+ }
257
+
258
+ // ---------------------------------------------------------------------------
259
+ // Tests
260
+ // ---------------------------------------------------------------------------
261
+
262
+ #[cfg(test)]
263
+ mod tests {
264
+ use super::*;
265
+ use rand::Rng;
266
+ use rand::SeedableRng;
267
+ use rand_xoshiro::Xoshiro256PlusPlus;
268
+
269
+ #[test]
270
+ fn sp_sparsity_exact_2pct() {
271
+ // BAMI says "top ~2%"; with 2048 columns that's round(0.02*2048) = 41.
272
+ // The SP must produce *exactly* that count, no more, no less, and with
273
+ // no duplicate indices.
274
+ let cfg = SpatialPoolerConfig::default();
275
+ let expected_k = (cfg.sparsity * cfg.n_columns as f32).round() as usize;
276
+ assert!(expected_k > 0);
277
+
278
+ let input_bits = cfg.input_bits;
279
+ let mut sp = SpatialPooler::new(cfg, 42);
280
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
281
+
282
+ for _ in 0..100 {
283
+ // 2% sparse random input SDR.
284
+ let on_bits = (0.02 * input_bits as f32) as usize;
285
+ let mut sdr = vec![false; input_bits];
286
+ for _ in 0..on_bits {
287
+ let i = rng.gen_range(0..input_bits);
288
+ sdr[i] = true;
289
+ }
290
+ let active = sp.compute(&sdr, true);
291
+ assert_eq!(
292
+ active.len(),
293
+ expected_k,
294
+ "SP must emit exactly {expected_k} active columns"
295
+ );
296
+ let mut a = active.clone();
297
+ a.sort_unstable();
298
+ a.dedup();
299
+ assert_eq!(a.len(), expected_k);
300
+ }
301
+ }
302
+ }
overlay/htm_rust/src/tm.rs CHANGED
@@ -1,545 +1,545 @@
1
- //! Numenta BAMI-spec Temporal Memory.
2
- //!
3
- //! Key parameters (Numenta defaults):
4
- //! - cells_per_column = 32
5
- //! - max_segments_per_cell = 255
6
- //! - max_synapses_per_segment = 32
7
- //! - activation_threshold = 15 (CONNECTED synapses onto active cells)
8
- //! - learning_threshold = 13 (POTENTIAL synapses onto active cells)
9
- //! (often called `minThreshold` / match threshold in BAMI)
10
- //! - initial_permanence = 0.21
11
- //! - connected_permanence = 0.50
12
- //! - permanence_increment = 0.10
13
- //! - permanence_decrement = 0.10
14
- //! - predicted_segment_decrement = 0.10 (decay for segments that predicted
15
- //! inactive columns; called `predictedSegmentDecrement` in BAMI)
16
- //! - max_new_synapse_count = 20 (max synapses to grow on a new/reinforced seg)
17
- //!
18
- //! Algorithm (one step):
19
- //! Given `active_columns` from the Spatial Pooler, and segment activity
20
- //! caches `active_segments` and `matching_segments` computed *at the end of
21
- //! the previous step*:
22
- //!
23
- //! 1. For each active column:
24
- //! - If it contains any predicted cell (any cell with an active segment
25
- //! from the previous depolarization), mark those cells active and
26
- //! learn on the segment that predicted it.
27
- //! - Else BURST the column: mark all cells in it active, and grow a new
28
- //! segment on the best-matching cell in the column (or, if none,
29
- //! on the cell with the fewest segments).
30
- //! 2. For every column that was predicted but did NOT become active
31
- //! (matching segments on inactive columns), apply the
32
- //! `predicted_segment_decrement` decay so spurious predictions fade.
33
- //! 3. Winner cells = active cells chosen for learning (1 per active column).
34
- //! 4. Compute segment activity for NEXT step:
35
- //! - A segment's CONNECTED activity = #synapses with perm >= connected_perm
36
- //! whose presynaptic cell is in `active_cells`. If >= activation_threshold
37
- //! -> segment is "active" -> its cell is "predicted".
38
- //! - A segment's POTENTIAL activity = #synapses whose presynaptic cell is
39
- //! in `active_cells` (regardless of permanence). If >= learning_threshold
40
- //! -> segment is "matching".
41
- //!
42
- //! Anomaly score = (active columns with no prior predicted cells)
43
- //! / (# active columns).
44
-
45
- use rand::Rng;
46
- use rand::SeedableRng;
47
- use rand_xoshiro::Xoshiro256PlusPlus;
48
-
49
- type CellIdx = u32;
50
- type SegmentIdx = u32;
51
-
52
- #[derive(Clone)]
53
- pub struct Synapse {
54
- pub presynaptic_cell: CellIdx,
55
- pub permanence: f32,
56
- }
57
-
58
- #[derive(Clone)]
59
- pub struct Segment {
60
- pub cell: CellIdx,
61
- pub synapses: Vec<Synapse>,
62
- /// Cached counters; recomputed each step.
63
- pub num_active_connected: u32,
64
- pub num_active_potential: u32,
65
- /// Simple "last iter touched" stat for least-used cell selection.
66
- pub last_used_iteration: u64,
67
- }
68
-
69
- pub struct TemporalMemoryConfig {
70
- pub n_columns: usize,
71
- pub cells_per_column: usize,
72
- pub activation_threshold: u32,
73
- pub learning_threshold: u32,
74
- pub initial_permanence: f32,
75
- pub connected_permanence: f32,
76
- pub permanence_increment: f32,
77
- pub permanence_decrement: f32,
78
- pub predicted_segment_decrement: f32,
79
- pub max_segments_per_cell: usize,
80
- pub max_synapses_per_segment: usize,
81
- pub max_new_synapse_count: usize,
82
- }
83
-
84
- impl Default for TemporalMemoryConfig {
85
- fn default() -> Self {
86
- Self {
87
- n_columns: 2048,
88
- cells_per_column: 32,
89
- activation_threshold: 15,
90
- learning_threshold: 13,
91
- initial_permanence: 0.21,
92
- connected_permanence: 0.50,
93
- permanence_increment: 0.10,
94
- permanence_decrement: 0.10,
95
- predicted_segment_decrement: 0.10,
96
- max_segments_per_cell: 255,
97
- max_synapses_per_segment: 32,
98
- max_new_synapse_count: 20,
99
- }
100
- }
101
- }
102
-
103
- pub struct TemporalMemory {
104
- pub cfg: TemporalMemoryConfig,
105
- /// All segments in the region. Indexed by SegmentIdx.
106
- pub segments: Vec<Segment>,
107
- /// For each cell, the list of segments that belong to it.
108
- pub cell_segments: Vec<Vec<SegmentIdx>>,
109
- /// Active cells in the current step.
110
- pub active_cells: Vec<bool>,
111
- /// Winner cells (subset of active_cells, 1 per active column) for learning.
112
- pub winner_cells: Vec<bool>,
113
- /// Predictive cells for the current step = cells whose segment became
114
- /// active at the end of the previous step.
115
- pub predictive_cells: Vec<bool>,
116
- /// Cached list of segment indices that were "active" last compute().
117
- active_segments_prev: Vec<SegmentIdx>,
118
- /// Cached list of segment indices that were "matching" last compute().
119
- matching_segments_prev: Vec<SegmentIdx>,
120
- rng: Xoshiro256PlusPlus,
121
- iter_count: u64,
122
- }
123
-
124
- impl TemporalMemory {
125
- pub fn new(cfg: TemporalMemoryConfig, seed: u64) -> Self {
126
- let total = cfg.n_columns * cfg.cells_per_column;
127
- Self {
128
- cell_segments: vec![Vec::new(); total],
129
- active_cells: vec![false; total],
130
- winner_cells: vec![false; total],
131
- predictive_cells: vec![false; total],
132
- cfg,
133
- segments: Vec::new(),
134
- active_segments_prev: Vec::new(),
135
- matching_segments_prev: Vec::new(),
136
- rng: Xoshiro256PlusPlus::seed_from_u64(seed),
137
- iter_count: 0,
138
- }
139
- }
140
-
141
- pub fn reset(&mut self) {
142
- for v in self.active_cells.iter_mut() { *v = false; }
143
- for v in self.winner_cells.iter_mut() { *v = false; }
144
- for v in self.predictive_cells.iter_mut() { *v = false; }
145
- self.active_segments_prev.clear();
146
- self.matching_segments_prev.clear();
147
- }
148
-
149
- #[inline]
150
- fn col_of(&self, cell: CellIdx) -> usize {
151
- (cell as usize) / self.cfg.cells_per_column
152
- }
153
-
154
- #[inline]
155
- fn cells_in_col(&self, col: usize) -> std::ops::Range<CellIdx> {
156
- let base = (col * self.cfg.cells_per_column) as CellIdx;
157
- base..(base + self.cfg.cells_per_column as CellIdx)
158
- }
159
-
160
- /// Process one step.
161
- ///
162
- /// `active_columns` is the set of column indices activated by the Spatial
163
- /// Pooler this step. Returns the anomaly score in [0, 1].
164
- pub fn compute(&mut self, active_columns: &[u32], learn: bool) -> f32 {
165
- self.iter_count = self.iter_count.wrapping_add(1);
166
-
167
- // Snapshot previous-step cell activity (for learning on segments).
168
- let prev_active_cells = self.active_cells.clone();
169
- let prev_winner_cells = self.winner_cells.clone();
170
-
171
- // Move current "predictive" (computed at the end of the last step)
172
- // into local variables; we'll overwrite predictive_cells later.
173
- let predictive_prev = self.predictive_cells.clone();
174
-
175
- // Group active segments and matching segments by column of their
176
- // owning cell, for the columns that are active this step.
177
- let n_cols = self.cfg.n_columns;
178
-
179
- // active_segs_by_col[col] = segment indices whose cell is in col and
180
- // which were "active" in the previous depolarization.
181
- // matching_segs_by_col[col] = similarly for "matching".
182
- let mut active_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols];
183
- let mut matching_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols];
184
- for &seg in &self.active_segments_prev {
185
- let col = self.col_of(self.segments[seg as usize].cell);
186
- active_segs_by_col[col].push(seg);
187
- }
188
- for &seg in &self.matching_segments_prev {
189
- let col = self.col_of(self.segments[seg as usize].cell);
190
- matching_segs_by_col[col].push(seg);
191
- }
192
-
193
- // Columns that are active this step (for O(1) lookup).
194
- let mut active_col_mask = vec![false; n_cols];
195
- for &c in active_columns { active_col_mask[c as usize] = true; }
196
-
197
- // Zero out current cell activations.
198
- for v in self.active_cells.iter_mut() { *v = false; }
199
- for v in self.winner_cells.iter_mut() { *v = false; }
200
-
201
- // Track anomaly.
202
- let mut unpredicted_cols = 0u32;
203
-
204
- // We'll collect (segment, learn_mode) pairs for segment reinforcement
205
- // so we can batch-apply permanence adjustments using prev_active_cells.
206
- // learn_mode: "reinforce_correctly_predicted", "punish_incorrectly_matched"
207
- enum LearnOp {
208
- Reinforce(SegmentIdx), // correctly predicted
209
- Grow { // bursting column: grow on chosen segment
210
- segment: SegmentIdx,
211
- #[allow(dead_code)]
212
- winner_cell: CellIdx,
213
- },
214
- Punish(SegmentIdx), // matching segment on inactive column
215
- }
216
- let mut ops: Vec<LearnOp> = Vec::new();
217
-
218
- // ---- 1) Process active columns ----
219
- for &col in active_columns {
220
- let col = col as usize;
221
- let active_segs = &active_segs_by_col[col];
222
- if !active_segs.is_empty() {
223
- // "Activate predicted column": each cell with an active segment
224
- // becomes active and is a winner; reinforce that segment.
225
- let mut seen_cells: Vec<CellIdx> = Vec::new();
226
- for &seg_i in active_segs {
227
- let seg = &self.segments[seg_i as usize];
228
- let cell = seg.cell;
229
- if !seen_cells.contains(&cell) {
230
- self.active_cells[cell as usize] = true;
231
- self.winner_cells[cell as usize] = true;
232
- seen_cells.push(cell);
233
- }
234
- if learn {
235
- ops.push(LearnOp::Reinforce(seg_i));
236
- }
237
- }
238
- } else {
239
- // ----- BURST -----
240
- unpredicted_cols += 1;
241
- for c in self.cells_in_col(col) {
242
- self.active_cells[c as usize] = true;
243
- }
244
- // Pick a winner cell + segment for learning.
245
- if learn {
246
- let matching = &matching_segs_by_col[col];
247
- let (winner_cell, target_segment) = if !matching.is_empty() {
248
- // Best-matching segment = highest num_active_potential.
249
- let mut best = matching[0];
250
- let mut best_score = self.segments[best as usize].num_active_potential;
251
- for &s in &matching[1..] {
252
- let score = self.segments[s as usize].num_active_potential;
253
- if score > best_score {
254
- best_score = score;
255
- best = s;
256
- }
257
- }
258
- let wc = self.segments[best as usize].cell;
259
- (wc, Some(best))
260
- } else {
261
- // Least-used cell in column, then grow a new segment.
262
- let winner = self.least_used_cell(col);
263
- (winner, None)
264
- };
265
- self.winner_cells[winner_cell as usize] = true;
266
- let segment_id = match target_segment {
267
- Some(s) => s,
268
- None => {
269
- // Create a fresh empty segment on winner cell.
270
- self.create_segment(winner_cell)
271
- }
272
- };
273
- ops.push(LearnOp::Grow { segment: segment_id, winner_cell });
274
- } else {
275
- // No learning: still pick some winner cell (arbitrary)
276
- // so downstream code that inspects winner_cells isn't empty.
277
- let matching = &matching_segs_by_col[col];
278
- let winner_cell = if !matching.is_empty() {
279
- self.segments[matching[0] as usize].cell
280
- } else {
281
- self.least_used_cell(col)
282
- };
283
- self.winner_cells[winner_cell as usize] = true;
284
- }
285
- }
286
- }
287
-
288
- // ---- 2) Punish matching segments on INACTIVE columns ----
289
- if learn && self.cfg.predicted_segment_decrement > 0.0 {
290
- for &seg_i in &self.matching_segments_prev {
291
- let col = self.col_of(self.segments[seg_i as usize].cell);
292
- if !active_col_mask[col] {
293
- ops.push(LearnOp::Punish(seg_i));
294
- }
295
- }
296
- }
297
-
298
- // ---- 3) Apply learning ----
299
- if learn {
300
- for op in ops {
301
- match op {
302
- LearnOp::Reinforce(seg_i) => {
303
- self.reinforce_segment(seg_i, &prev_active_cells);
304
- // Optionally grow up to N new synapses to winner cells
305
- // of the previous step.
306
- self.grow_synapses_on_segment(seg_i, &prev_winner_cells);
307
- }
308
- LearnOp::Grow { segment, winner_cell: _ } => {
309
- self.reinforce_segment(segment, &prev_active_cells);
310
- self.grow_synapses_on_segment(segment, &prev_winner_cells);
311
- }
312
- LearnOp::Punish(seg_i) => {
313
- let dec = self.cfg.predicted_segment_decrement;
314
- for syn in &mut self.segments[seg_i as usize].synapses {
315
- if prev_active_cells[syn.presynaptic_cell as usize] {
316
- syn.permanence = (syn.permanence - dec).max(0.0);
317
- }
318
- }
319
- }
320
- }
321
- }
322
- }
323
-
324
- // ---- 4) Compute segment activity & predictive cells for NEXT step ----
325
- // We have to use the *current* active_cells (just set above).
326
- let mut next_active_segs: Vec<SegmentIdx> = Vec::new();
327
- let mut next_matching_segs: Vec<SegmentIdx> = Vec::new();
328
- for v in self.predictive_cells.iter_mut() { *v = false; }
329
-
330
- let conn = self.cfg.connected_permanence;
331
- let act_thr = self.cfg.activation_threshold;
332
- let learn_thr = self.cfg.learning_threshold;
333
-
334
- for (seg_i, seg) in self.segments.iter_mut().enumerate() {
335
- let mut n_conn: u32 = 0;
336
- let mut n_pot: u32 = 0;
337
- for syn in &seg.synapses {
338
- if self.active_cells[syn.presynaptic_cell as usize] {
339
- n_pot += 1;
340
- if syn.permanence >= conn { n_conn += 1; }
341
- }
342
- }
343
- seg.num_active_connected = n_conn;
344
- seg.num_active_potential = n_pot;
345
- if n_conn >= act_thr {
346
- next_active_segs.push(seg_i as SegmentIdx);
347
- self.predictive_cells[seg.cell as usize] = true;
348
- }
349
- if n_pot >= learn_thr {
350
- next_matching_segs.push(seg_i as SegmentIdx);
351
- }
352
- }
353
- self.active_segments_prev = next_active_segs;
354
- self.matching_segments_prev = next_matching_segs;
355
-
356
- // Keep predictive_prev unused-guard; we no longer need it but
357
- // retained to document intent.
358
- let _ = predictive_prev;
359
-
360
- // Anomaly.
361
- if active_columns.is_empty() {
362
- 0.0
363
- } else {
364
- (unpredicted_cols as f32) / (active_columns.len() as f32)
365
- }
366
- }
367
-
368
- /// Reinforce synapses on `seg`: +inc if presynaptic is active last step,
369
- /// -dec otherwise.
370
- fn reinforce_segment(&mut self, seg_i: SegmentIdx, prev_active_cells: &[bool]) {
371
- let inc = self.cfg.permanence_increment;
372
- let dec = self.cfg.permanence_decrement;
373
- let seg = &mut self.segments[seg_i as usize];
374
- seg.last_used_iteration = self.iter_count;
375
- for syn in &mut seg.synapses {
376
- if prev_active_cells[syn.presynaptic_cell as usize] {
377
- syn.permanence = (syn.permanence + inc).min(1.0);
378
- } else {
379
- syn.permanence = (syn.permanence - dec).max(0.0);
380
- }
381
- }
382
- }
383
-
384
- /// Grow up to `max_new_synapse_count - current_potential` new synapses
385
- /// from previous winner cells that are not already connected to this seg.
386
- fn grow_synapses_on_segment(
387
- &mut self,
388
- seg_i: SegmentIdx,
389
- prev_winner_cells: &[bool],
390
- ) {
391
- let initial_perm = self.cfg.initial_permanence;
392
- let cap = self.cfg.max_synapses_per_segment;
393
- let max_new = self.cfg.max_new_synapse_count;
394
-
395
- // Gather candidate cells (prev winners not already presynaptic to this seg).
396
- let already: Vec<CellIdx> = self.segments[seg_i as usize]
397
- .synapses
398
- .iter()
399
- .map(|s| s.presynaptic_cell)
400
- .collect();
401
- let mut candidates: Vec<CellIdx> = Vec::new();
402
- for (cell_i, &b) in prev_winner_cells.iter().enumerate() {
403
- if b && !already.contains(&(cell_i as CellIdx)) {
404
- candidates.push(cell_i as CellIdx);
405
- }
406
- }
407
-
408
- // How many can we add?
409
- let current_len = self.segments[seg_i as usize].synapses.len();
410
- let room = cap.saturating_sub(current_len);
411
- let mut to_add = max_new.min(candidates.len()).min(room);
412
-
413
- // Random sample without replacement from candidates.
414
- while to_add > 0 {
415
- let idx = self.rng.gen_range(0..candidates.len());
416
- let pre = candidates.swap_remove(idx);
417
- self.segments[seg_i as usize].synapses.push(Synapse {
418
- presynaptic_cell: pre,
419
- permanence: initial_perm,
420
- });
421
- to_add -= 1;
422
- }
423
- }
424
-
425
- fn create_segment(&mut self, cell: CellIdx) -> SegmentIdx {
426
- // Enforce per-cell segment cap by evicting least-recently-used segment
427
- // if necessary.
428
- let cell_segs = &mut self.cell_segments[cell as usize];
429
- if cell_segs.len() >= self.cfg.max_segments_per_cell {
430
- // Find LRU segment.
431
- let (lru_pos, &lru_id) = cell_segs
432
- .iter()
433
- .enumerate()
434
- .min_by_key(|(_, &sid)| self.segments[sid as usize].last_used_iteration)
435
- .expect("cell_segs non-empty");
436
- // Clear that segment in place and reuse its index.
437
- self.segments[lru_id as usize].synapses.clear();
438
- self.segments[lru_id as usize].num_active_connected = 0;
439
- self.segments[lru_id as usize].num_active_potential = 0;
440
- self.segments[lru_id as usize].last_used_iteration = self.iter_count;
441
- // Keep at same position in cell_segs.
442
- let _ = lru_pos;
443
- return lru_id;
444
- }
445
-
446
- let new_id = self.segments.len() as SegmentIdx;
447
- self.segments.push(Segment {
448
- cell,
449
- synapses: Vec::with_capacity(self.cfg.max_new_synapse_count),
450
- num_active_connected: 0,
451
- num_active_potential: 0,
452
- last_used_iteration: self.iter_count,
453
- });
454
- cell_segs.push(new_id);
455
- new_id
456
- }
457
-
458
- fn least_used_cell(&mut self, col: usize) -> CellIdx {
459
- // Cell with the fewest segments; break ties randomly.
460
- let mut min_segs = usize::MAX;
461
- let mut candidates: Vec<CellIdx> = Vec::new();
462
- for c in self.cells_in_col(col) {
463
- let n = self.cell_segments[c as usize].len();
464
- if n < min_segs {
465
- min_segs = n;
466
- candidates.clear();
467
- candidates.push(c);
468
- } else if n == min_segs {
469
- candidates.push(c);
470
- }
471
- }
472
- let idx = self.rng.gen_range(0..candidates.len());
473
- candidates[idx]
474
- }
475
- }
476
-
477
- // ---------------------------------------------------------------------------
478
- // Tests
479
- // ---------------------------------------------------------------------------
480
-
481
- #[cfg(test)]
482
- mod tests {
483
- use super::*;
484
- use crate::sp::{SpatialPooler, SpatialPoolerConfig};
485
- use rand::Rng;
486
- use rand::SeedableRng;
487
- use rand_xoshiro::Xoshiro256PlusPlus;
488
-
489
- #[test]
490
- fn tm_learns_repeating_sequence() {
491
- // Sequence A -> B -> C -> A -> B -> C -> ... should drive anomaly down.
492
- let cfg = SpatialPoolerConfig::default();
493
- let mut sp = SpatialPooler::new(cfg, 123);
494
- let mut tm = TemporalMemory::new(TemporalMemoryConfig::default(), 456);
495
-
496
- // Build 3 fixed random SDRs of 2% sparsity.
497
- let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
498
- let input_bits = sp.cfg.input_bits;
499
- let make_sdr = |rng: &mut Xoshiro256PlusPlus| {
500
- let mut v = vec![false; input_bits];
501
- let on = (0.02 * input_bits as f32) as usize;
502
- let mut placed = 0;
503
- while placed < on {
504
- let i = rng.gen_range(0..input_bits);
505
- if !v[i] {
506
- v[i] = true;
507
- placed += 1;
508
- }
509
- }
510
- v
511
- };
512
- let seqs = [make_sdr(&mut rng), make_sdr(&mut rng), make_sdr(&mut rng)];
513
-
514
- // Warm up SP first so that columns are reliable for each symbol.
515
- for _ in 0..200 {
516
- for s in &seqs {
517
- sp.compute(s, true);
518
- }
519
- }
520
-
521
- // Reset TM so prediction state is clean.
522
- tm.reset();
523
-
524
- // Record anomaly over a window early and late.
525
- let mut early_anoms: Vec<f32> = Vec::new();
526
- let mut late_anoms: Vec<f32> = Vec::new();
527
- for iter in 0..250 {
528
- for s in &seqs {
529
- let active = sp.compute(s, false);
530
- let anomaly = tm.compute(&active, true);
531
- if iter == 10 { early_anoms.push(anomaly); }
532
- if iter == 249 { late_anoms.push(anomaly); }
533
- }
534
- }
535
-
536
- let mean = |v: &[f32]| v.iter().sum::<f32>() / (v.len() as f32);
537
- let early = mean(&early_anoms);
538
- let late = mean(&late_anoms);
539
- println!("early_anomaly={early}, late_anomaly={late}");
540
- assert!(
541
- late < 0.5 * early + 1e-6,
542
- "late anomaly ({late}) should be < 0.5 * early anomaly ({early})"
543
- );
544
- }
545
- }
 
1
+ //! Numenta BAMI-spec Temporal Memory.
2
+ //!
3
+ //! Key parameters (Numenta defaults):
4
+ //! - cells_per_column = 32
5
+ //! - max_segments_per_cell = 255
6
+ //! - max_synapses_per_segment = 32
7
+ //! - activation_threshold = 15 (CONNECTED synapses onto active cells)
8
+ //! - learning_threshold = 13 (POTENTIAL synapses onto active cells)
9
+ //! (often called `minThreshold` / match threshold in BAMI)
10
+ //! - initial_permanence = 0.21
11
+ //! - connected_permanence = 0.50
12
+ //! - permanence_increment = 0.10
13
+ //! - permanence_decrement = 0.10
14
+ //! - predicted_segment_decrement = 0.10 (decay for segments that predicted
15
+ //! inactive columns; called `predictedSegmentDecrement` in BAMI)
16
+ //! - max_new_synapse_count = 20 (max synapses to grow on a new/reinforced seg)
17
+ //!
18
+ //! Algorithm (one step):
19
+ //! Given `active_columns` from the Spatial Pooler, and segment activity
20
+ //! caches `active_segments` and `matching_segments` computed *at the end of
21
+ //! the previous step*:
22
+ //!
23
+ //! 1. For each active column:
24
+ //! - If it contains any predicted cell (any cell with an active segment
25
+ //! from the previous depolarization), mark those cells active and
26
+ //! learn on the segment that predicted it.
27
+ //! - Else BURST the column: mark all cells in it active, and grow a new
28
+ //! segment on the best-matching cell in the column (or, if none,
29
+ //! on the cell with the fewest segments).
30
+ //! 2. For every column that was predicted but did NOT become active
31
+ //! (matching segments on inactive columns), apply the
32
+ //! `predicted_segment_decrement` decay so spurious predictions fade.
33
+ //! 3. Winner cells = active cells chosen for learning (1 per active column).
34
+ //! 4. Compute segment activity for NEXT step:
35
+ //! - A segment's CONNECTED activity = #synapses with perm >= connected_perm
36
+ //! whose presynaptic cell is in `active_cells`. If >= activation_threshold
37
+ //! -> segment is "active" -> its cell is "predicted".
38
+ //! - A segment's POTENTIAL activity = #synapses whose presynaptic cell is
39
+ //! in `active_cells` (regardless of permanence). If >= learning_threshold
40
+ //! -> segment is "matching".
41
+ //!
42
+ //! Anomaly score = (active columns with no prior predicted cells)
43
+ //! / (# active columns).
44
+
45
+ use rand::Rng;
46
+ use rand::SeedableRng;
47
+ use rand_xoshiro::Xoshiro256PlusPlus;
48
+
49
+ type CellIdx = u32;
50
+ type SegmentIdx = u32;
51
+
52
+ #[derive(Clone)]
53
+ pub struct Synapse {
54
+ pub presynaptic_cell: CellIdx,
55
+ pub permanence: f32,
56
+ }
57
+
58
+ #[derive(Clone)]
59
+ pub struct Segment {
60
+ pub cell: CellIdx,
61
+ pub synapses: Vec<Synapse>,
62
+ /// Cached counters; recomputed each step.
63
+ pub num_active_connected: u32,
64
+ pub num_active_potential: u32,
65
+ /// Simple "last iter touched" stat for least-used cell selection.
66
+ pub last_used_iteration: u64,
67
+ }
68
+
69
+ pub struct TemporalMemoryConfig {
70
+ pub n_columns: usize,
71
+ pub cells_per_column: usize,
72
+ pub activation_threshold: u32,
73
+ pub learning_threshold: u32,
74
+ pub initial_permanence: f32,
75
+ pub connected_permanence: f32,
76
+ pub permanence_increment: f32,
77
+ pub permanence_decrement: f32,
78
+ pub predicted_segment_decrement: f32,
79
+ pub max_segments_per_cell: usize,
80
+ pub max_synapses_per_segment: usize,
81
+ pub max_new_synapse_count: usize,
82
+ }
83
+
84
+ impl Default for TemporalMemoryConfig {
85
+ fn default() -> Self {
86
+ Self {
87
+ n_columns: 2048,
88
+ cells_per_column: 32,
89
+ activation_threshold: 15,
90
+ learning_threshold: 13,
91
+ initial_permanence: 0.21,
92
+ connected_permanence: 0.50,
93
+ permanence_increment: 0.10,
94
+ permanence_decrement: 0.10,
95
+ predicted_segment_decrement: 0.10,
96
+ max_segments_per_cell: 255,
97
+ max_synapses_per_segment: 32,
98
+ max_new_synapse_count: 20,
99
+ }
100
+ }
101
+ }
102
+
103
+ pub struct TemporalMemory {
104
+ pub cfg: TemporalMemoryConfig,
105
+ /// All segments in the region. Indexed by SegmentIdx.
106
+ pub segments: Vec<Segment>,
107
+ /// For each cell, the list of segments that belong to it.
108
+ pub cell_segments: Vec<Vec<SegmentIdx>>,
109
+ /// Active cells in the current step.
110
+ pub active_cells: Vec<bool>,
111
+ /// Winner cells (subset of active_cells, 1 per active column) for learning.
112
+ pub winner_cells: Vec<bool>,
113
+ /// Predictive cells for the current step = cells whose segment became
114
+ /// active at the end of the previous step.
115
+ pub predictive_cells: Vec<bool>,
116
+ /// Cached list of segment indices that were "active" last compute().
117
+ active_segments_prev: Vec<SegmentIdx>,
118
+ /// Cached list of segment indices that were "matching" last compute().
119
+ matching_segments_prev: Vec<SegmentIdx>,
120
+ rng: Xoshiro256PlusPlus,
121
+ iter_count: u64,
122
+ }
123
+
124
+ impl TemporalMemory {
125
+ pub fn new(cfg: TemporalMemoryConfig, seed: u64) -> Self {
126
+ let total = cfg.n_columns * cfg.cells_per_column;
127
+ Self {
128
+ cell_segments: vec![Vec::new(); total],
129
+ active_cells: vec![false; total],
130
+ winner_cells: vec![false; total],
131
+ predictive_cells: vec![false; total],
132
+ cfg,
133
+ segments: Vec::new(),
134
+ active_segments_prev: Vec::new(),
135
+ matching_segments_prev: Vec::new(),
136
+ rng: Xoshiro256PlusPlus::seed_from_u64(seed),
137
+ iter_count: 0,
138
+ }
139
+ }
140
+
141
+ pub fn reset(&mut self) {
142
+ for v in self.active_cells.iter_mut() { *v = false; }
143
+ for v in self.winner_cells.iter_mut() { *v = false; }
144
+ for v in self.predictive_cells.iter_mut() { *v = false; }
145
+ self.active_segments_prev.clear();
146
+ self.matching_segments_prev.clear();
147
+ }
148
+
149
+ #[inline]
150
+ fn col_of(&self, cell: CellIdx) -> usize {
151
+ (cell as usize) / self.cfg.cells_per_column
152
+ }
153
+
154
+ #[inline]
155
+ fn cells_in_col(&self, col: usize) -> std::ops::Range<CellIdx> {
156
+ let base = (col * self.cfg.cells_per_column) as CellIdx;
157
+ base..(base + self.cfg.cells_per_column as CellIdx)
158
+ }
159
+
160
+ /// Process one step.
161
+ ///
162
+ /// `active_columns` is the set of column indices activated by the Spatial
163
+ /// Pooler this step. Returns the anomaly score in [0, 1].
164
+ pub fn compute(&mut self, active_columns: &[u32], learn: bool) -> f32 {
165
+ self.iter_count = self.iter_count.wrapping_add(1);
166
+
167
+ // Snapshot previous-step cell activity (for learning on segments).
168
+ let prev_active_cells = self.active_cells.clone();
169
+ let prev_winner_cells = self.winner_cells.clone();
170
+
171
+ // Move current "predictive" (computed at the end of the last step)
172
+ // into local variables; we'll overwrite predictive_cells later.
173
+ let predictive_prev = self.predictive_cells.clone();
174
+
175
+ // Group active segments and matching segments by column of their
176
+ // owning cell, for the columns that are active this step.
177
+ let n_cols = self.cfg.n_columns;
178
+
179
+ // active_segs_by_col[col] = segment indices whose cell is in col and
180
+ // which were "active" in the previous depolarization.
181
+ // matching_segs_by_col[col] = similarly for "matching".
182
+ let mut active_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols];
183
+ let mut matching_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols];
184
+ for &seg in &self.active_segments_prev {
185
+ let col = self.col_of(self.segments[seg as usize].cell);
186
+ active_segs_by_col[col].push(seg);
187
+ }
188
+ for &seg in &self.matching_segments_prev {
189
+ let col = self.col_of(self.segments[seg as usize].cell);
190
+ matching_segs_by_col[col].push(seg);
191
+ }
192
+
193
+ // Columns that are active this step (for O(1) lookup).
194
+ let mut active_col_mask = vec![false; n_cols];
195
+ for &c in active_columns { active_col_mask[c as usize] = true; }
196
+
197
+ // Zero out current cell activations.
198
+ for v in self.active_cells.iter_mut() { *v = false; }
199
+ for v in self.winner_cells.iter_mut() { *v = false; }
200
+
201
+ // Track anomaly.
202
+ let mut unpredicted_cols = 0u32;
203
+
204
+ // We'll collect (segment, learn_mode) pairs for segment reinforcement
205
+ // so we can batch-apply permanence adjustments using prev_active_cells.
206
+ // learn_mode: "reinforce_correctly_predicted", "punish_incorrectly_matched"
207
+ enum LearnOp {
208
+ Reinforce(SegmentIdx), // correctly predicted
209
+ Grow { // bursting column: grow on chosen segment
210
+ segment: SegmentIdx,
211
+ #[allow(dead_code)]
212
+ winner_cell: CellIdx,
213
+ },
214
+ Punish(SegmentIdx), // matching segment on inactive column
215
+ }
216
+ let mut ops: Vec<LearnOp> = Vec::new();
217
+
218
+ // ---- 1) Process active columns ----
219
+ for &col in active_columns {
220
+ let col = col as usize;
221
+ let active_segs = &active_segs_by_col[col];
222
+ if !active_segs.is_empty() {
223
+ // "Activate predicted column": each cell with an active segment
224
+ // becomes active and is a winner; reinforce that segment.
225
+ let mut seen_cells: Vec<CellIdx> = Vec::new();
226
+ for &seg_i in active_segs {
227
+ let seg = &self.segments[seg_i as usize];
228
+ let cell = seg.cell;
229
+ if !seen_cells.contains(&cell) {
230
+ self.active_cells[cell as usize] = true;
231
+ self.winner_cells[cell as usize] = true;
232
+ seen_cells.push(cell);
233
+ }
234
+ if learn {
235
+ ops.push(LearnOp::Reinforce(seg_i));
236
+ }
237
+ }
238
+ } else {
239
+ // ----- BURST -----
240
+ unpredicted_cols += 1;
241
+ for c in self.cells_in_col(col) {
242
+ self.active_cells[c as usize] = true;
243
+ }
244
+ // Pick a winner cell + segment for learning.
245
+ if learn {
246
+ let matching = &matching_segs_by_col[col];
247
+ let (winner_cell, target_segment) = if !matching.is_empty() {
248
+ // Best-matching segment = highest num_active_potential.
249
+ let mut best = matching[0];
250
+ let mut best_score = self.segments[best as usize].num_active_potential;
251
+ for &s in &matching[1..] {
252
+ let score = self.segments[s as usize].num_active_potential;
253
+ if score > best_score {
254
+ best_score = score;
255
+ best = s;
256
+ }
257
+ }
258
+ let wc = self.segments[best as usize].cell;
259
+ (wc, Some(best))
260
+ } else {
261
+ // Least-used cell in column, then grow a new segment.
262
+ let winner = self.least_used_cell(col);
263
+ (winner, None)
264
+ };
265
+ self.winner_cells[winner_cell as usize] = true;
266
+ let segment_id = match target_segment {
267
+ Some(s) => s,
268
+ None => {
269
+ // Create a fresh empty segment on winner cell.
270
+ self.create_segment(winner_cell)
271
+ }
272
+ };
273
+ ops.push(LearnOp::Grow { segment: segment_id, winner_cell });
274
+ } else {
275
+ // No learning: still pick some winner cell (arbitrary)
276
+ // so downstream code that inspects winner_cells isn't empty.
277
+ let matching = &matching_segs_by_col[col];
278
+ let winner_cell = if !matching.is_empty() {
279
+ self.segments[matching[0] as usize].cell
280
+ } else {
281
+ self.least_used_cell(col)
282
+ };
283
+ self.winner_cells[winner_cell as usize] = true;
284
+ }
285
+ }
286
+ }
287
+
288
+ // ---- 2) Punish matching segments on INACTIVE columns ----
289
+ if learn && self.cfg.predicted_segment_decrement > 0.0 {
290
+ for &seg_i in &self.matching_segments_prev {
291
+ let col = self.col_of(self.segments[seg_i as usize].cell);
292
+ if !active_col_mask[col] {
293
+ ops.push(LearnOp::Punish(seg_i));
294
+ }
295
+ }
296
+ }
297
+
298
+ // ---- 3) Apply learning ----
299
+ if learn {
300
+ for op in ops {
301
+ match op {
302
+ LearnOp::Reinforce(seg_i) => {
303
+ self.reinforce_segment(seg_i, &prev_active_cells);
304
+ // Optionally grow up to N new synapses to winner cells
305
+ // of the previous step.
306
+ self.grow_synapses_on_segment(seg_i, &prev_winner_cells);
307
+ }
308
+ LearnOp::Grow { segment, winner_cell: _ } => {
309
+ self.reinforce_segment(segment, &prev_active_cells);
310
+ self.grow_synapses_on_segment(segment, &prev_winner_cells);
311
+ }
312
+ LearnOp::Punish(seg_i) => {
313
+ let dec = self.cfg.predicted_segment_decrement;
314
+ for syn in &mut self.segments[seg_i as usize].synapses {
315
+ if prev_active_cells[syn.presynaptic_cell as usize] {
316
+ syn.permanence = (syn.permanence - dec).max(0.0);
317
+ }
318
+ }
319
+ }
320
+ }
321
+ }
322
+ }
323
+
324
+ // ---- 4) Compute segment activity & predictive cells for NEXT step ----
325
+ // We have to use the *current* active_cells (just set above).
326
+ let mut next_active_segs: Vec<SegmentIdx> = Vec::new();
327
+ let mut next_matching_segs: Vec<SegmentIdx> = Vec::new();
328
+ for v in self.predictive_cells.iter_mut() { *v = false; }
329
+
330
+ let conn = self.cfg.connected_permanence;
331
+ let act_thr = self.cfg.activation_threshold;
332
+ let learn_thr = self.cfg.learning_threshold;
333
+
334
+ for (seg_i, seg) in self.segments.iter_mut().enumerate() {
335
+ let mut n_conn: u32 = 0;
336
+ let mut n_pot: u32 = 0;
337
+ for syn in &seg.synapses {
338
+ if self.active_cells[syn.presynaptic_cell as usize] {
339
+ n_pot += 1;
340
+ if syn.permanence >= conn { n_conn += 1; }
341
+ }
342
+ }
343
+ seg.num_active_connected = n_conn;
344
+ seg.num_active_potential = n_pot;
345
+ if n_conn >= act_thr {
346
+ next_active_segs.push(seg_i as SegmentIdx);
347
+ self.predictive_cells[seg.cell as usize] = true;
348
+ }
349
+ if n_pot >= learn_thr {
350
+ next_matching_segs.push(seg_i as SegmentIdx);
351
+ }
352
+ }
353
+ self.active_segments_prev = next_active_segs;
354
+ self.matching_segments_prev = next_matching_segs;
355
+
356
+ // Keep predictive_prev unused-guard; we no longer need it but
357
+ // retained to document intent.
358
+ let _ = predictive_prev;
359
+
360
+ // Anomaly.
361
+ if active_columns.is_empty() {
362
+ 0.0
363
+ } else {
364
+ (unpredicted_cols as f32) / (active_columns.len() as f32)
365
+ }
366
+ }
367
+
368
+ /// Reinforce synapses on `seg`: +inc if presynaptic is active last step,
369
+ /// -dec otherwise.
370
+ fn reinforce_segment(&mut self, seg_i: SegmentIdx, prev_active_cells: &[bool]) {
371
+ let inc = self.cfg.permanence_increment;
372
+ let dec = self.cfg.permanence_decrement;
373
+ let seg = &mut self.segments[seg_i as usize];
374
+ seg.last_used_iteration = self.iter_count;
375
+ for syn in &mut seg.synapses {
376
+ if prev_active_cells[syn.presynaptic_cell as usize] {
377
+ syn.permanence = (syn.permanence + inc).min(1.0);
378
+ } else {
379
+ syn.permanence = (syn.permanence - dec).max(0.0);
380
+ }
381
+ }
382
+ }
383
+
384
+ /// Grow up to `max_new_synapse_count - current_potential` new synapses
385
+ /// from previous winner cells that are not already connected to this seg.
386
+ fn grow_synapses_on_segment(
387
+ &mut self,
388
+ seg_i: SegmentIdx,
389
+ prev_winner_cells: &[bool],
390
+ ) {
391
+ let initial_perm = self.cfg.initial_permanence;
392
+ let cap = self.cfg.max_synapses_per_segment;
393
+ let max_new = self.cfg.max_new_synapse_count;
394
+
395
+ // Gather candidate cells (prev winners not already presynaptic to this seg).
396
+ let already: Vec<CellIdx> = self.segments[seg_i as usize]
397
+ .synapses
398
+ .iter()
399
+ .map(|s| s.presynaptic_cell)
400
+ .collect();
401
+ let mut candidates: Vec<CellIdx> = Vec::new();
402
+ for (cell_i, &b) in prev_winner_cells.iter().enumerate() {
403
+ if b && !already.contains(&(cell_i as CellIdx)) {
404
+ candidates.push(cell_i as CellIdx);
405
+ }
406
+ }
407
+
408
+ // How many can we add?
409
+ let current_len = self.segments[seg_i as usize].synapses.len();
410
+ let room = cap.saturating_sub(current_len);
411
+ let mut to_add = max_new.min(candidates.len()).min(room);
412
+
413
+ // Random sample without replacement from candidates.
414
+ while to_add > 0 {
415
+ let idx = self.rng.gen_range(0..candidates.len());
416
+ let pre = candidates.swap_remove(idx);
417
+ self.segments[seg_i as usize].synapses.push(Synapse {
418
+ presynaptic_cell: pre,
419
+ permanence: initial_perm,
420
+ });
421
+ to_add -= 1;
422
+ }
423
+ }
424
+
425
+ fn create_segment(&mut self, cell: CellIdx) -> SegmentIdx {
426
+ // Enforce per-cell segment cap by evicting least-recently-used segment
427
+ // if necessary.
428
+ let cell_segs = &mut self.cell_segments[cell as usize];
429
+ if cell_segs.len() >= self.cfg.max_segments_per_cell {
430
+ // Find LRU segment.
431
+ let (lru_pos, &lru_id) = cell_segs
432
+ .iter()
433
+ .enumerate()
434
+ .min_by_key(|(_, &sid)| self.segments[sid as usize].last_used_iteration)
435
+ .expect("cell_segs non-empty");
436
+ // Clear that segment in place and reuse its index.
437
+ self.segments[lru_id as usize].synapses.clear();
438
+ self.segments[lru_id as usize].num_active_connected = 0;
439
+ self.segments[lru_id as usize].num_active_potential = 0;
440
+ self.segments[lru_id as usize].last_used_iteration = self.iter_count;
441
+ // Keep at same position in cell_segs.
442
+ let _ = lru_pos;
443
+ return lru_id;
444
+ }
445
+
446
+ let new_id = self.segments.len() as SegmentIdx;
447
+ self.segments.push(Segment {
448
+ cell,
449
+ synapses: Vec::with_capacity(self.cfg.max_new_synapse_count),
450
+ num_active_connected: 0,
451
+ num_active_potential: 0,
452
+ last_used_iteration: self.iter_count,
453
+ });
454
+ cell_segs.push(new_id);
455
+ new_id
456
+ }
457
+
458
+ fn least_used_cell(&mut self, col: usize) -> CellIdx {
459
+ // Cell with the fewest segments; break ties randomly.
460
+ let mut min_segs = usize::MAX;
461
+ let mut candidates: Vec<CellIdx> = Vec::new();
462
+ for c in self.cells_in_col(col) {
463
+ let n = self.cell_segments[c as usize].len();
464
+ if n < min_segs {
465
+ min_segs = n;
466
+ candidates.clear();
467
+ candidates.push(c);
468
+ } else if n == min_segs {
469
+ candidates.push(c);
470
+ }
471
+ }
472
+ let idx = self.rng.gen_range(0..candidates.len());
473
+ candidates[idx]
474
+ }
475
+ }
476
+
477
+ // ---------------------------------------------------------------------------
478
+ // Tests
479
+ // ---------------------------------------------------------------------------
480
+
481
+ #[cfg(test)]
482
+ mod tests {
483
+ use super::*;
484
+ use crate::sp::{SpatialPooler, SpatialPoolerConfig};
485
+ use rand::Rng;
486
+ use rand::SeedableRng;
487
+ use rand_xoshiro::Xoshiro256PlusPlus;
488
+
489
+ #[test]
490
+ fn tm_learns_repeating_sequence() {
491
+ // Sequence A -> B -> C -> A -> B -> C -> ... should drive anomaly down.
492
+ let cfg = SpatialPoolerConfig::default();
493
+ let mut sp = SpatialPooler::new(cfg, 123);
494
+ let mut tm = TemporalMemory::new(TemporalMemoryConfig::default(), 456);
495
+
496
+ // Build 3 fixed random SDRs of 2% sparsity.
497
+ let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
498
+ let input_bits = sp.cfg.input_bits;
499
+ let make_sdr = |rng: &mut Xoshiro256PlusPlus| {
500
+ let mut v = vec![false; input_bits];
501
+ let on = (0.02 * input_bits as f32) as usize;
502
+ let mut placed = 0;
503
+ while placed < on {
504
+ let i = rng.gen_range(0..input_bits);
505
+ if !v[i] {
506
+ v[i] = true;
507
+ placed += 1;
508
+ }
509
+ }
510
+ v
511
+ };
512
+ let seqs = [make_sdr(&mut rng), make_sdr(&mut rng), make_sdr(&mut rng)];
513
+
514
+ // Warm up SP first so that columns are reliable for each symbol.
515
+ for _ in 0..200 {
516
+ for s in &seqs {
517
+ sp.compute(s, true);
518
+ }
519
+ }
520
+
521
+ // Reset TM so prediction state is clean.
522
+ tm.reset();
523
+
524
+ // Record anomaly over a window early and late.
525
+ let mut early_anoms: Vec<f32> = Vec::new();
526
+ let mut late_anoms: Vec<f32> = Vec::new();
527
+ for iter in 0..250 {
528
+ for s in &seqs {
529
+ let active = sp.compute(s, false);
530
+ let anomaly = tm.compute(&active, true);
531
+ if iter == 10 { early_anoms.push(anomaly); }
532
+ if iter == 249 { late_anoms.push(anomaly); }
533
+ }
534
+ }
535
+
536
+ let mean = |v: &[f32]| v.iter().sum::<f32>() / (v.len() as f32);
537
+ let early = mean(&early_anoms);
538
+ let late = mean(&late_anoms);
539
+ println!("early_anomaly={early}, late_anomaly={late}");
540
+ assert!(
541
+ late < 0.5 * early + 1e-6,
542
+ "late anomaly ({late}) should be < 0.5 * early anomaly ({early})"
543
+ );
544
+ }
545
+ }
overlay/hydra/__init__.py CHANGED
@@ -1,31 +1,37 @@
1
- """HYDRA training package.
2
-
3
- Thin facade re-exporting the public API used by train.py, the test suite,
4
- and external research scripts. Imports are lazy where possible to keep
5
- `import hydra` cheap (prepare.py and mamba-ssm are the heavy deps).
6
- """
7
-
8
- from hydra.config import PostSemClawConfig
9
- from hydra.engram import GPUEngram
10
- from hydra.model import PostSemClawModel, norm
11
- from hydra.optimizer import MuonAdamW, adamw_step_fused, muon_step_fused
12
-
13
- # config_from_dict is imported lazily (via attribute access on hydra.training)
14
- # to keep `import hydra` cheap; re-export here for convenience.
15
- def __getattr__(name: str):
16
- if name == "config_from_dict":
17
- from hydra.training import config_from_dict as _cfd
18
- return _cfd
19
- raise AttributeError(name)
20
-
21
-
22
- __all__ = [
23
- "PostSemClawConfig",
24
- "GPUEngram",
25
- "PostSemClawModel",
26
- "norm",
27
- "MuonAdamW",
28
- "adamw_step_fused",
29
- "muon_step_fused",
30
- "config_from_dict",
31
- ]
 
 
 
 
 
 
 
1
+ """HYDRA training package.
2
+
3
+ Thin facade re-exporting the public API used by train.py, the test suite,
4
+ and external research scripts. Imports are lazy where possible to keep
5
+ `import hydra` cheap (prepare.py and mamba-ssm are the heavy deps).
6
+ """
7
+
8
+ from hydra.config import PostSemClawConfig
9
+ from hydra.engram import GPUEngram
10
+ from hydra.optimizer import MuonAdamW, adamw_step_fused, muon_step_fused
11
+
12
+ # Heavy imports are resolved lazily so `import hydra` and `import hydra.hyena_block`
13
+ # keep working in local CPU/test environments that do not have the container-only
14
+ # mamba-ssm wheel stack installed.
15
+ def __getattr__(name: str):
16
+ if name == "PostSemClawModel":
17
+ from hydra.model import PostSemClawModel as _model
18
+ return _model
19
+ if name == "norm":
20
+ from hydra.model import norm as _norm
21
+ return _norm
22
+ if name == "config_from_dict":
23
+ from hydra.training import config_from_dict as _cfd
24
+ return _cfd
25
+ raise AttributeError(name)
26
+
27
+
28
+ __all__ = [
29
+ "PostSemClawConfig",
30
+ "GPUEngram",
31
+ "PostSemClawModel",
32
+ "norm",
33
+ "MuonAdamW",
34
+ "adamw_step_fused",
35
+ "muon_step_fused",
36
+ "config_from_dict",
37
+ ]
overlay/hydra/config.py CHANGED
@@ -1,220 +1,225 @@
1
- """HYDRA training configuration — dataclass + env-var constants.
2
-
3
- Extracted from the monolithic train.py as part of W1 modularization. All
4
- env-var reads and the PostSemClawConfig dataclass live here. The training
5
- body imports these constants; zero behavior change from the extraction.
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- import os
11
- from dataclasses import dataclass, field
12
-
13
-
14
- def _parse_hyena_layers_env() -> tuple[int, ...]:
15
- """Parse HYDRA_HYENA_LAYERS env var into a sorted tuple of layer indices.
16
-
17
- Used as the default_factory for PostSemClawConfig.hyena_layers so a fresh
18
- config construction reads the current env var, but once constructed the
19
- value is first-class and travels with checkpoints (see asdict(config) in
20
- save_ckpt). Ckpt-load sets the dataclass field explicitly, overriding the
21
- env-var default.
22
-
23
- Returns empty tuple when env var is unset/empty (byte-identical to
24
- pre-port behavior: no Hyena layers).
25
- """
26
- raw = os.environ.get("HYDRA_HYENA_LAYERS", "")
27
- if not raw:
28
- return ()
29
- return tuple(sorted({int(s.strip()) for s in raw.split(",") if s.strip()}))
30
-
31
-
32
- def _parse_gdn_layers_env() -> tuple[int, ...]:
33
- """Parse HYDRA_GDN_LAYERS env var into a sorted tuple of layer indices.
34
-
35
- Same contract as _parse_hyena_layers_env: layers whose index is listed
36
- here use GatedDeltaNet (fla.layers.GatedDeltaNet) as a drop-in
37
- replacement for Mamba3. Empty tuple = no GDN layers (byte-identical
38
- to baseline).
39
- """
40
- raw = os.environ.get("HYDRA_GDN_LAYERS", "")
41
- if not raw:
42
- return ()
43
- return tuple(sorted({int(s.strip()) for s in raw.split(",") if s.strip()}))
44
-
45
- # ---------------------------------------------------------------------------
46
- # CUDA env — set before importing torch in entry point. Kept here so any
47
- # module that `from hydra.config import ...` also benefits (import order is
48
- # top-down in Python, and train.py used to set these at module top).
49
- # ---------------------------------------------------------------------------
50
- os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
51
- if "/usr/local/cuda/bin" not in os.environ.get("PATH", ""):
52
- os.environ["PATH"] = "/usr/local/cuda/bin:" + os.environ.get("PATH", "")
53
- os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
54
-
55
-
56
- # ---------------------------------------------------------------------------
57
- # Model Configuration
58
- # ---------------------------------------------------------------------------
59
-
60
- @dataclass
61
- class PostSemClawConfig:
62
- """Full-architecture model config. Defaults reflect Phase-1 baseline;
63
- the training entry overrides d_model/n_layer/etc. from env vars."""
64
- # Sequence
65
- sequence_len: int = 2048
66
- vocab_size: int = 8192 # Must match prepare.py VOCAB_SIZE
67
-
68
- # Mamba-3 SSM
69
- n_layer: int = 6
70
- d_model: int = 384
71
- d_state: int = 64 # SSM state dimension
72
- headdim: int = 48 # head dimension for SSM
73
- n_heads: int = 8 # d_model // headdim
74
- expand: int = 2 # inner_dim = expand * d_model
75
-
76
- # Engram (conditional memory with Hebbian writes)
77
- engram_n_columns: int = 4096
78
- engram_key_dim: int = 64
79
- engram_layer_idx: int = 1 # which layer gets engram (0-indexed, mid-layer)
80
-
81
- # SemanticFoldingSDR (offline retina with STE; no-bypass, runs every step)
82
- sdr_n_bits: int = 16384 # retina width
83
- # Default 327 = 2% sparsity (Webber/Numenta canonical). Override with
84
- # HYDRA_SDR_TARGET_ACTIVE env var; value MUST match subsystems/sdr_retina.py
85
- # TARGET_ACTIVE (same env var is read there, so just setting it once works).
86
- sdr_target_active: int = int(os.environ.get("HYDRA_SDR_TARGET_ACTIVE", "327"))
87
- sdr_delta_rank: int = 32 # low-rank STE delta rank
88
- sdr_som_warmup: int = 500
89
- sdr_som_interval: int = 100
90
-
91
- # HTMLayer (Rust-backed, Hebbian; no-bypass, runs every step)
92
- htm_n_columns: int = 2048
93
- htm_cells_per_column: int = 32
94
-
95
- # Hyena supplement layer indices (sorted tuple). Defaults to the
96
- # HYDRA_HYENA_LAYERS env var at config-construction time, but once
97
- # persisted in a checkpoint the value is first-class and survives even
98
- # when the env var is unset at resume time. This fixes the ckpt-reload
99
- # crash path where a model trained with `HYDRA_HYENA_LAYERS=3,7` saves
100
- # HyenaBlock params but a fresh process without the env var would try
101
- # to build a pure-Mamba3 architecture and reject the state_dict as
102
- # `Missing/Unexpected key(s)`.
103
- hyena_layers: tuple[int, ...] = field(default_factory=_parse_hyena_layers_env)
104
-
105
- # GatedDeltaNet supplement layer indices (sorted tuple). Same semantics
106
- # as hyena_layers — a layer index listed here uses GDNBlock (fla-backed
107
- # Gated DeltaNet) instead of Mamba3. Selections are mutually exclusive
108
- # with hyena_layers at construction time (hyena wins on overlap; the
109
- # model loop checks hyena first).
110
- gdn_layers: tuple[int, ...] = field(default_factory=_parse_gdn_layers_env)
111
-
112
- # Label smoothing + Z-loss
113
- label_smoothing: float = 0.0 # disabled: any smoothing hurts in 5-min budget
114
- z_loss_weight: float = 1e-4
115
-
116
-
117
- # ---------------------------------------------------------------------------
118
- # Hyperparameters (autoresearch agent modifies these via env vars)
119
- # ---------------------------------------------------------------------------
120
-
121
- # Model architecture
122
- D_MODEL = int(os.environ.get("HYDRA_D_MODEL", "256"))
123
- N_LAYER = int(os.environ.get("HYDRA_N_LAYER", "4"))
124
- D_STATE = int(os.environ.get("HYDRA_D_STATE", "64"))
125
- HEADDIM = int(os.environ.get("HYDRA_HEADDIM", "32"))
126
- N_HEADS = D_MODEL // HEADDIM
127
- EXPAND = int(os.environ.get("HYDRA_EXPAND", "2"))
128
-
129
- # Engram
130
- ENGRAM_N_COLUMNS = int(os.environ.get("HYDRA_ENGRAM_N_COLUMNS", "1024"))
131
- ENGRAM_KEY_DIM = 64
132
- ENGRAM_LAYER_IDX = int(os.environ.get("HYDRA_ENGRAM_LAYER_IDX", "1"))
133
-
134
- # Optimization
135
- DEVICE_BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
136
- TOTAL_BATCH_SIZE = int(os.environ.get("HYDRA_TOTAL_BATCH", "32768"))
137
- MATRIX_LR = float(os.environ.get("HYDRA_MATRIX_LR", "0.12"))
138
- EMBEDDING_LR = float(os.environ.get("HYDRA_EMBED_LR", "1.0"))
139
- UNEMBEDDING_LR = float(os.environ.get("HYDRA_UNEMBED_LR", "0.005"))
140
- SCALAR_LR = 0.5
141
- WEIGHT_DECAY = 0.01
142
- ADAM_BETAS = (0.9, 0.95)
143
- WARMUP_RATIO = 0.0
144
- WARMDOWN_RATIO = 0.5
145
- FINAL_LR_FRAC = float(os.environ.get("HYDRA_LR_MIN_MULT", "0.0"))
146
-
147
- # Runtime
148
- SEED = int(os.environ.get("HYDRA_SEED", "42"))
149
- # BF16 TFLOPS peak (RTX 3060=25.5, A100 SXM4=312, H100 SXM5=989)
150
- GPU_BF16_PEAK_FLOPS = float(os.environ.get("HYDRA_GPU_BF16_TFLOPS", "25.5")) * 1e12
151
-
152
- # Loss / inference knobs read by the model
153
- CE_CHUNK = int(os.environ.get("HYDRA_CE_CHUNK", "1024"))
154
- DROPOUT = float(os.environ.get("HYDRA_DROPOUT", "0.2"))
155
- FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
156
-
157
- # ---------------------------------------------------------------------------
158
- # Learnability knobs (all OFF by default — zero behavior change unless set)
159
- # ---------------------------------------------------------------------------
160
- # 1) Multi-Token Prediction (Llama-3 style). K=1 disables (next-1 only). K=4
161
- # adds 3 extra weight-tied heads; loss = mean of K position-shifted CEs.
162
- MTP_K = int(os.environ.get("HYDRA_MTP_K", "1"))
163
- # 2) Exponential Moving Average of model weights (decay=0.999). Saves an
164
- # additional latest_ema.pt at the end of training.
165
- USE_EMA = os.environ.get("HYDRA_USE_EMA", "0") == "1"
166
- EMA_DECAY = float(os.environ.get("HYDRA_EMA_DECAY", "0.999"))
167
- # 3) Gradient checkpointing on Mamba3 block forward. Trades ~30% compute for
168
- # ~40% activation memory savings lets you push B upward on a 3060.
169
- GRAD_CKPT = os.environ.get("HYDRA_GRAD_CKPT", "0") == "1"
170
- # 4) Doc-separator masking in packed sequences: at every packed-BOS position
171
- # in the targets tensor, mask the loss (ignore_index=-1) so the model is
172
- # not forced to predict doc B from doc A's context.
173
- DOC_SEP_MASK = os.environ.get("HYDRA_DOC_SEP_MASK", "0") == "1"
174
- # 5) Stop-gradient on HTM state (belt-and-braces: htm_rust already runs under
175
- # torch.no_grad() so the tensor returned has requires_grad=False; this
176
- # simply detaches explicitly to harden graph hygiene against future refactors).
177
- HTM_STOP_GRAD = os.environ.get("HYDRA_HTM_STOP_GRAD", "0") == "1"
178
- # 6) Output entropy penalty: loss += -lambda * H(softmax(logits)). Negative
179
- # entropy penalizes peaked distributions and breaks repetition loops.
180
- ENTROPY_PENALTY = float(os.environ.get("HYDRA_ENTROPY_PENALTY", "0.0"))
181
- # 7) Curriculum: first N optimizer steps use short seq_len, then switch to
182
- # full. 0 disables (no curriculum).
183
- CURRICULUM_SHORT_STEPS = int(os.environ.get("HYDRA_CURRICULUM_SHORT_STEPS", "0"))
184
- CURRICULUM_SHORT_SEQ_LEN = int(os.environ.get("HYDRA_CURRICULUM_SHORT_SEQ_LEN", "256"))
185
-
186
- # ---------------------------------------------------------------------------
187
- # Hyena supplement (additional block type for selected layer indices).
188
- # Hyena replaces Mamba3 at the specified layer indices while all other layers
189
- # remain Mamba3. Empty string (default) → no Hyena layers, byte-identical to
190
- # pre-port behavior.
191
- # HYDRA_HYENA_LAYERS "3,7" — comma-separated 0-indexed layer ids
192
- # HYDRA_HYENA_ORDER 2 — Hyena recurrence order (>= 2)
193
- # HYDRA_HYENA_FILTER_DIM 64 — implicit-filter MLP hidden width
194
- # Hyena reference: https://arxiv.org/pdf/2302.10866.pdf (HazyResearch/safari).
195
- # ---------------------------------------------------------------------------
196
- HYENA_LAYERS = os.environ.get("HYDRA_HYENA_LAYERS", "")
197
- HYENA_ORDER = int(os.environ.get("HYDRA_HYENA_ORDER", "2"))
198
- HYENA_FILTER_DIM = int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64"))
199
- # Filter-rfft cache modes (see subsystems/hyena_pure.py):
200
- # HYDRA_HYENA_FILTER_CACHE=1 — eval-only cache. Safe under torch.no_grad()
201
- # where PyTorch never saves intermediate tensors. Off by default.
202
- # HYDRA_HYENA_TRAIN_CACHE=1 — training-safe cache using a deferred
203
- # gradient pattern. Cuts the implicit filter MLP forward to ONCE per
204
- # optimizer step regardless of grad-accumulation factor. Requires the
205
- # training loop (see hydra/lightning_module.py::optimizer_step) to
206
- # call `model.flush_hyena_pending_grads()` before optimizer.step().
207
- # Off by default.
208
- HYENA_FILTER_CACHE = os.environ.get("HYDRA_HYENA_FILTER_CACHE", "0") == "1"
209
- HYENA_TRAIN_CACHE = os.environ.get("HYDRA_HYENA_TRAIN_CACHE", "0") == "1"
210
-
211
- # Factual eval knobs
212
- FACTUAL_SAMPLES = int(os.environ.get("HYDRA_FACTUAL_SAMPLES", "3"))
213
- FACTUAL_BATCH = int(os.environ.get("HYDRA_FACTUAL_BATCH", "32"))
214
- # F6 (partial): Full incremental SSM decode integration deferred — would require
215
- # threading mamba_ssm InferenceParams through PostSemClawModel.forward and all
216
- # auxiliary subsystems (HTM, SDR, Engram) which currently run full-sequence each
217
- # call. As a stopgap we reduce default from 16 -> 4 so the per-prompt cost is
218
- # quartered (each gen-tok does a full re-encode of ctx+k tokens). Override with
219
- # HYDRA_FACTUAL_GEN_TOKENS to restore prior behavior. See docs/OPTIMIZATION_PLAN.md.
220
- FACTUAL_GEN_TOKENS = int(os.environ.get("HYDRA_FACTUAL_GEN_TOKENS", "2"))
 
 
 
 
 
 
1
+ """HYDRA training configuration — dataclass + env-var constants.
2
+
3
+ Extracted from the monolithic train.py as part of W1 modularization. All
4
+ env-var reads and the PostSemClawConfig dataclass live here. The training
5
+ body imports these constants; zero behavior change from the extraction.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+ from dataclasses import dataclass, field
12
+
13
+
14
+ def _parse_hyena_layers_env() -> tuple[int, ...]:
15
+ """Parse HYDRA_HYENA_LAYERS env var into a sorted tuple of layer indices.
16
+
17
+ Used as the default_factory for PostSemClawConfig.hyena_layers so a fresh
18
+ config construction reads the current env var, but once constructed the
19
+ value is first-class and travels with checkpoints (see asdict(config) in
20
+ save_ckpt). Ckpt-load sets the dataclass field explicitly, overriding the
21
+ env-var default.
22
+
23
+ Returns empty tuple when env var is unset/empty (byte-identical to
24
+ pre-port behavior: no Hyena layers).
25
+ """
26
+ raw = os.environ.get("HYDRA_HYENA_LAYERS", "")
27
+ if not raw:
28
+ return ()
29
+ return tuple(sorted({int(s.strip()) for s in raw.split(",") if s.strip()}))
30
+
31
+
32
+ def _parse_gdn_layers_env() -> tuple[int, ...]:
33
+ """Parse HYDRA_GDN_LAYERS env var into a sorted tuple of layer indices.
34
+
35
+ Same contract as _parse_hyena_layers_env: layers whose index is listed
36
+ here use GatedDeltaNet (fla.layers.GatedDeltaNet) as a drop-in
37
+ replacement for Mamba3. Empty tuple = no GDN layers (byte-identical
38
+ to baseline).
39
+ """
40
+ raw = os.environ.get("HYDRA_GDN_LAYERS", "")
41
+ if not raw:
42
+ return ()
43
+ return tuple(sorted({int(s.strip()) for s in raw.split(",") if s.strip()}))
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # CUDA env — set before importing torch in entry point. Kept here so any
47
+ # module that `from hydra.config import ...` also benefits (import order is
48
+ # top-down in Python, and train.py used to set these at module top).
49
+ # ---------------------------------------------------------------------------
50
+ os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
51
+ if "/usr/local/cuda/bin" not in os.environ.get("PATH", ""):
52
+ os.environ["PATH"] = "/usr/local/cuda/bin:" + os.environ.get("PATH", "")
53
+ os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
54
+
55
+
56
+ # ---------------------------------------------------------------------------
57
+ # Model Configuration
58
+ # ---------------------------------------------------------------------------
59
+
60
+ @dataclass
61
+ class PostSemClawConfig:
62
+ """Full-architecture model config. Defaults reflect Phase-1 baseline;
63
+ the training entry overrides d_model/n_layer/etc. from env vars."""
64
+ # Sequence
65
+ sequence_len: int = 2048
66
+ vocab_size: int = 8192 # Must match prepare.py VOCAB_SIZE
67
+
68
+ # Mamba-3 SSM
69
+ n_layer: int = 6
70
+ d_model: int = 384
71
+ d_state: int = 64 # SSM state dimension
72
+ headdim: int = 48 # head dimension for SSM
73
+ n_heads: int = 8 # d_model // headdim
74
+ expand: int = 2 # inner_dim = expand * d_model
75
+
76
+ # Engram (conditional memory with Hebbian writes)
77
+ engram_n_columns: int = 4096
78
+ engram_key_dim: int = 64
79
+ engram_layer_idx: int = 1 # which layer gets engram (0-indexed, mid-layer)
80
+
81
+ # SemanticFoldingSDR (offline retina with STE; no-bypass, runs every step)
82
+ sdr_n_bits: int = 16384 # retina width
83
+ # Default 327 = 2% sparsity (Webber/Numenta canonical). Override with
84
+ # HYDRA_SDR_TARGET_ACTIVE env var; value MUST match subsystems/sdr_retina.py
85
+ # TARGET_ACTIVE (same env var is read there, so just setting it once works).
86
+ sdr_target_active: int = int(os.environ.get("HYDRA_SDR_TARGET_ACTIVE", "327"))
87
+ sdr_delta_rank: int = 32 # low-rank STE delta rank
88
+ sdr_som_warmup: int = 500
89
+ sdr_som_interval: int = 100
90
+
91
+ # HTMLayer (Rust-backed, Hebbian; no-bypass, runs every step)
92
+ htm_n_columns: int = 2048
93
+ htm_cells_per_column: int = 32
94
+
95
+ # Hyena supplement layer indices (sorted tuple). Defaults to the
96
+ # HYDRA_HYENA_LAYERS env var at config-construction time, but once
97
+ # persisted in a checkpoint the value is first-class and survives even
98
+ # when the env var is unset at resume time. This fixes the ckpt-reload
99
+ # crash path where a model trained with `HYDRA_HYENA_LAYERS=3,7` saves
100
+ # HyenaBlock params but a fresh process without the env var would try
101
+ # to build a pure-Mamba3 architecture and reject the state_dict as
102
+ # `Missing/Unexpected key(s)`.
103
+ hyena_layers: tuple[int, ...] = field(default_factory=_parse_hyena_layers_env)
104
+
105
+ # GatedDeltaNet supplement layer indices (sorted tuple). Same semantics
106
+ # as hyena_layers — a layer index listed here uses GDNBlock (fla-backed
107
+ # Gated DeltaNet) instead of Mamba3. Selections are mutually exclusive
108
+ # with hyena_layers at construction time (hyena wins on overlap; the
109
+ # model loop checks hyena first).
110
+ gdn_layers: tuple[int, ...] = field(default_factory=_parse_gdn_layers_env)
111
+
112
+ # Label smoothing + Z-loss
113
+ label_smoothing: float = field(default_factory=lambda: float(os.environ.get("HYDRA_LABEL_SMOOTHING", "0.0")))
114
+ z_loss_weight: float = field(default_factory=lambda: float(os.environ.get("HYDRA_Z_LOSS_WEIGHT", "1e-4")))
115
+
116
+
117
+ # ---------------------------------------------------------------------------
118
+ # Hyperparameters (autoresearch agent modifies these via env vars)
119
+ # ---------------------------------------------------------------------------
120
+
121
+ # Model architecture
122
+ D_MODEL = int(os.environ.get("HYDRA_D_MODEL", "256"))
123
+ N_LAYER = int(os.environ.get("HYDRA_N_LAYER", "4"))
124
+ D_STATE = int(os.environ.get("HYDRA_D_STATE", "64"))
125
+ HEADDIM = int(os.environ.get("HYDRA_HEADDIM", "32"))
126
+ N_HEADS = D_MODEL // HEADDIM
127
+ EXPAND = int(os.environ.get("HYDRA_EXPAND", "2"))
128
+
129
+ # Engram
130
+ ENGRAM_N_COLUMNS = int(os.environ.get("HYDRA_ENGRAM_N_COLUMNS", "1024"))
131
+ ENGRAM_KEY_DIM = 64
132
+ ENGRAM_LAYER_IDX = int(os.environ.get("HYDRA_ENGRAM_LAYER_IDX", "1"))
133
+
134
+ # Optimization
135
+ DEVICE_BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
136
+ TOTAL_BATCH_SIZE = int(os.environ.get("HYDRA_TOTAL_BATCH", "32768"))
137
+ MATRIX_LR = float(os.environ.get("HYDRA_MATRIX_LR", "0.12"))
138
+ EMBEDDING_LR = float(os.environ.get("HYDRA_EMBED_LR", "1.0"))
139
+ UNEMBEDDING_LR = float(os.environ.get("HYDRA_UNEMBED_LR", "0.005"))
140
+ # Scalar/vector params include Hyena implicit-filter vectors, norms, gate/bias
141
+ # terms, and SDR delta_u/delta_v. They are AdamW-scaled by d_model and can be
142
+ # the hidden instability path when the high-throughput HF recipe pushes a large
143
+ # device batch for hours. Keep the historical default, but make it controllable
144
+ # from launch scripts so cloud jobs can cool scalars without editing code.
145
+ SCALAR_LR = float(os.environ.get("HYDRA_SCALAR_LR", "0.5"))
146
+ WEIGHT_DECAY = float(os.environ.get("HYDRA_WEIGHT_DECAY", "0.01"))
147
+ ADAM_BETAS = (0.9, 0.95)
148
+ WARMUP_RATIO = float(os.environ.get("HYDRA_WARMUP_RATIO", "0.0"))
149
+ WARMDOWN_RATIO = 0.5
150
+ FINAL_LR_FRAC = float(os.environ.get("HYDRA_LR_MIN_MULT", "0.0"))
151
+
152
+ # Runtime
153
+ SEED = int(os.environ.get("HYDRA_SEED", "42"))
154
+ # BF16 TFLOPS peak (RTX 3060=25.5, A100 SXM4=312, H100 SXM5=989)
155
+ GPU_BF16_PEAK_FLOPS = float(os.environ.get("HYDRA_GPU_BF16_TFLOPS", "25.5")) * 1e12
156
+
157
+ # Loss / inference knobs read by the model
158
+ CE_CHUNK = int(os.environ.get("HYDRA_CE_CHUNK", "1024"))
159
+ DROPOUT = float(os.environ.get("HYDRA_DROPOUT", "0.2"))
160
+ FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
161
+
162
+ # ---------------------------------------------------------------------------
163
+ # Learnability knobs (all OFF by default zero behavior change unless set)
164
+ # ---------------------------------------------------------------------------
165
+ # 1) Multi-Token Prediction (Llama-3 style). K=1 disables (next-1 only). K=4
166
+ # adds 3 extra weight-tied heads; loss = mean of K position-shifted CEs.
167
+ MTP_K = int(os.environ.get("HYDRA_MTP_K", "1"))
168
+ # 2) Exponential Moving Average of model weights (decay=0.999). Saves an
169
+ # additional latest_ema.pt at the end of training.
170
+ USE_EMA = os.environ.get("HYDRA_USE_EMA", "0") == "1"
171
+ EMA_DECAY = float(os.environ.get("HYDRA_EMA_DECAY", "0.999"))
172
+ # 3) Gradient checkpointing on Mamba3 block forward. Trades ~30% compute for
173
+ # ~40% activation memory savings lets you push B upward on a 3060.
174
+ GRAD_CKPT = os.environ.get("HYDRA_GRAD_CKPT", "0") == "1"
175
+ # 4) Doc-separator masking in packed sequences: at every packed-BOS position
176
+ # in the targets tensor, mask the loss (ignore_index=-1) so the model is
177
+ # not forced to predict doc B from doc A's context.
178
+ DOC_SEP_MASK = os.environ.get("HYDRA_DOC_SEP_MASK", "0") == "1"
179
+ # 5) Stop-gradient on HTM state (belt-and-braces: htm_rust already runs under
180
+ # torch.no_grad() so the tensor returned has requires_grad=False; this
181
+ # simply detaches explicitly to harden graph hygiene against future refactors).
182
+ HTM_STOP_GRAD = os.environ.get("HYDRA_HTM_STOP_GRAD", "0") == "1"
183
+ # 6) Output entropy penalty: loss += -lambda * H(softmax(logits)). Negative
184
+ # entropy penalizes peaked distributions and breaks repetition loops.
185
+ ENTROPY_PENALTY = float(os.environ.get("HYDRA_ENTROPY_PENALTY", "0.0"))
186
+ # 7) Curriculum: first N optimizer steps use short seq_len, then switch to
187
+ # full. 0 disables (no curriculum).
188
+ CURRICULUM_SHORT_STEPS = int(os.environ.get("HYDRA_CURRICULUM_SHORT_STEPS", "0"))
189
+ CURRICULUM_SHORT_SEQ_LEN = int(os.environ.get("HYDRA_CURRICULUM_SHORT_SEQ_LEN", "256"))
190
+
191
+ # ---------------------------------------------------------------------------
192
+ # Hyena supplement (additional block type for selected layer indices).
193
+ # Hyena replaces Mamba3 at the specified layer indices while all other layers
194
+ # remain Mamba3. Empty string (default) → no Hyena layers, byte-identical to
195
+ # pre-port behavior.
196
+ # HYDRA_HYENA_LAYERS "3,7" — comma-separated 0-indexed layer ids
197
+ # HYDRA_HYENA_ORDER 2 — Hyena recurrence order (>= 2)
198
+ # HYDRA_HYENA_FILTER_DIM 64 — implicit-filter MLP hidden width
199
+ # Hyena reference: https://arxiv.org/pdf/2302.10866.pdf (HazyResearch/safari).
200
+ # ---------------------------------------------------------------------------
201
+ HYENA_LAYERS = os.environ.get("HYDRA_HYENA_LAYERS", "")
202
+ HYENA_ORDER = int(os.environ.get("HYDRA_HYENA_ORDER", "2"))
203
+ HYENA_FILTER_DIM = int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64"))
204
+ # Filter-rfft cache modes (see subsystems/hyena_pure.py):
205
+ # HYDRA_HYENA_FILTER_CACHE=1 eval-only cache. Safe under torch.no_grad()
206
+ # where PyTorch never saves intermediate tensors. Off by default.
207
+ # HYDRA_HYENA_TRAIN_CACHE=1 — training-safe cache using a deferred
208
+ # gradient pattern. Cuts the implicit filter MLP forward to ONCE per
209
+ # optimizer step regardless of grad-accumulation factor. Requires the
210
+ # training loop (see hydra/lightning_module.py::optimizer_step) to
211
+ # call `model.flush_hyena_pending_grads()` before optimizer.step().
212
+ # Off by default.
213
+ HYENA_FILTER_CACHE = os.environ.get("HYDRA_HYENA_FILTER_CACHE", "0") == "1"
214
+ HYENA_TRAIN_CACHE = os.environ.get("HYDRA_HYENA_TRAIN_CACHE", "0") == "1"
215
+
216
+ # Factual eval knobs
217
+ FACTUAL_SAMPLES = int(os.environ.get("HYDRA_FACTUAL_SAMPLES", "3"))
218
+ FACTUAL_BATCH = int(os.environ.get("HYDRA_FACTUAL_BATCH", "32"))
219
+ # F6 (partial): Full incremental SSM decode integration deferred — would require
220
+ # threading mamba_ssm InferenceParams through PostSemClawModel.forward and all
221
+ # auxiliary subsystems (HTM, SDR, Engram) which currently run full-sequence each
222
+ # call. As a stopgap we reduce default from 16 -> 4 so the per-prompt cost is
223
+ # quartered (each gen-tok does a full re-encode of ctx+k tokens). Override with
224
+ # HYDRA_FACTUAL_GEN_TOKENS to restore prior behavior. See docs/OPTIMIZATION_PLAN.md.
225
+ FACTUAL_GEN_TOKENS = int(os.environ.get("HYDRA_FACTUAL_GEN_TOKENS", "2"))
overlay/hydra/data_module.py CHANGED
@@ -1,288 +1,288 @@
1
- """Lightning DataModule + IterableDataset for HYDRA pretraining.
2
-
3
- Replaces the custom threading/queue pipeline in prepare_nemotron.make_dataloader
4
- with a standard multiprocessing DataLoader approach.
5
-
6
- Design:
7
- • IterableStreamDataset: each worker opens its own HF streams for the 7-way
8
- blend, tokenizes with rustbpe, packs into (T+1,) rows via best-fit, and
9
- yields one row per __next__.
10
- • HydraDataModule: wraps the dataset with a standard DataLoader using
11
- num_workers>=1, prefetch_factor=4, pin_memory=True. Lightning handles
12
- device transfer.
13
- • Val stream: deterministic seed 12345, weights match training blend.
14
-
15
- The worker RNG is seeded per-worker so the weighted-sampling schedule is
16
- independent across workers (else all workers request the same config at
17
- the same step and prefetching serializes).
18
-
19
- Env vars (all preserved from prepare_nemotron):
20
- HYDRA_SEQ_LEN — sequence length T (default 512)
21
- HYDRA_BATCH_SIZE — batch size B (default 1) — passed through
22
- to DataLoader
23
- HYDRA_STREAM_SHUFFLE_BUFFER — HF shuffle buffer (default 2048)
24
- HYDRA_USE_FULL_BLEND — 7-way blend vs 5-way Nemotron phase
25
- HYDRA_USE_NEMOTRON — enables streaming path (else shard path)
26
- HYDRA_FACTUAL_INJECT_RATE — factual doc injection cadence
27
- HYDRA_NEMOTRON_PHASE — phase1|phase2 (when not full blend)
28
- HYDRA_DATA_NUM_WORKERS — DataLoader num_workers (default 2)
29
- HYDRA_DATA_PREFETCH — DataLoader prefetch_factor (default 4)
30
- HYDRA_DATA_BUFFER — doc_buffer size for best-fit packing
31
- (default 1000)
32
- """
33
- from __future__ import annotations
34
-
35
- import os
36
- import random
37
- from typing import Iterator
38
-
39
- import numpy as np
40
- import torch
41
- import lightning as L
42
- from torch.utils.data import DataLoader, IterableDataset, get_worker_info
43
-
44
- import prepare as _prepare
45
- import prepare_nemotron as _p_nemo
46
- from prepare_nemotron import (
47
- FULL_BLEND_WEIGHTS,
48
- PHASE1_WEIGHTS,
49
- PHASE2_WEIGHTS,
50
- _BLEND_REGISTRY,
51
- _extract_text,
52
- _open_stream,
53
- )
54
-
55
-
56
- # ---------------------------------------------------------------------------
57
- # Worker-local weighted stream. A stripped version of prepare_nemotron's
58
- # _WeightedStream that is constructed inside each worker. Adds worker sharding:
59
- # when num_workers > 1 the RNG is seeded per-worker, so different workers
60
- # sample different config sequences and pull disjoint shard assignments from
61
- # HF's shuffle buffer.
62
- # ---------------------------------------------------------------------------
63
-
64
-
65
- class _WorkerWeightedStream:
66
- def __init__(self, weights: dict[str, float], base_seed: int, worker_id: int):
67
- self.configs = list(weights.keys())
68
- self.weights = [weights[c] for c in self.configs]
69
- self.base_seed = base_seed
70
- self.worker_id = worker_id
71
- # Each worker opens its own HF streams. _open_stream returns an iter()
72
- # over a streaming dataset, with an internal shuffle buffer.
73
- self.streams = {c: _open_stream(c, "train") for c in self.configs}
74
- # Per-worker RNG so the config-choice trajectory is independent.
75
- self.rng = random.Random(base_seed + worker_id * 7919)
76
- self.epoch = 1
77
-
78
- # Lazy-init factual docs (once per worker). The main-process version
79
- # in prepare_nemotron._WeightedStream reads these on first __next__.
80
- self._factual_docs: list[str] | None = None
81
- self._factual_idx = 0
82
- self._inject_counter = 0
83
- inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
84
- self._inject_rate = inject_rate
85
- if inject_rate > 0:
86
- factual_path = os.path.join(
87
- os.path.dirname(os.path.abspath(_p_nemo.__file__)),
88
- "data", "factual", "facts.txt",
89
- )
90
- if os.path.exists(factual_path):
91
- with open(factual_path) as fh:
92
- self._factual_docs = fh.read().strip().split("\n")
93
-
94
- def _reopen(self, config: str) -> None:
95
- self.streams[config] = _open_stream(config, "train")
96
- self.epoch += 1
97
-
98
- def __iter__(self):
99
- return self
100
-
101
- def __next__(self) -> tuple[str, int]:
102
- # Factual injection (preserves prepare_nemotron cadence).
103
- if self._inject_rate > 0 and self._factual_docs:
104
- self._inject_counter += 1
105
- if self._inject_counter >= self._inject_rate:
106
- self._inject_counter = 0
107
- doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
108
- self._factual_idx += 1
109
- return doc, self.epoch
110
-
111
- config = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
112
- try:
113
- row = next(self.streams[config])
114
- except StopIteration:
115
- self._reopen(config)
116
- row = next(self.streams[config])
117
- return _extract_text(row), self.epoch
118
-
119
-
120
- # ---------------------------------------------------------------------------
121
- # IterableStreamDataset — yields (T+1,) packed rows. No threads. No queues.
122
- # Lives inside each DataLoader worker. DataLoader's own multiprocessing stacks
123
- # rows into batches of shape (B, T+1) and sends them to the main process.
124
- # ---------------------------------------------------------------------------
125
-
126
-
127
- class IterableStreamDataset(IterableDataset):
128
- """Streams docs, tokenizes, packs into (T+1,) rows via best-fit.
129
-
130
- Each worker gets its own instance (via fork/spawn) and initializes its
131
- own HF streams + rustbpe tokenizer + factual injector. The tokenizer
132
- pickled blob is small (~1 MB) and thread-safe per tiktoken docs.
133
- """
134
-
135
- def __init__(
136
- self,
137
- split: str,
138
- seq_len: int,
139
- *,
140
- base_seed: int = 0,
141
- doc_buffer_size: int = 1000,
142
- tokenizer_batch: int = 128,
143
- ):
144
- super().__init__()
145
- assert split in ("train", "val"), split
146
- self.split = split
147
- self.seq_len = seq_len
148
- self.row_capacity = seq_len + 1
149
- self.base_seed = base_seed
150
- self.doc_buffer_size = doc_buffer_size
151
- self.tokenizer_batch = tokenizer_batch
152
-
153
- def _pick_weights(self) -> dict[str, float]:
154
- if self.split == "val":
155
- if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
156
- return FULL_BLEND_WEIGHTS
157
- return {"Nemotron-Pretraining-Multiple-Choice": 1.0}
158
- if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
159
- return FULL_BLEND_WEIGHTS
160
- phase = os.environ.get("HYDRA_NEMOTRON_PHASE", "phase1").strip().lower()
161
- return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS
162
-
163
- def __iter__(self) -> Iterator[torch.Tensor]:
164
- info = get_worker_info()
165
- worker_id = 0 if info is None else info.id
166
-
167
- # Each worker builds its own tokenizer instance. tiktoken's Encoding
168
- # object is pickleable and the underlying C++ BPE is thread-safe;
169
- # per-worker instantiation avoids cross-process sharing headaches.
170
- tokenizer = _prepare.Tokenizer.from_directory()
171
- bos = tokenizer.get_bos_token_id()
172
-
173
- # Each worker gets its own weighted HF stream. Seed offset ensures
174
- # disjoint config-choice trajectories; HF's own shuffle buffer handles
175
- # shard randomization.
176
- val_seed = 12345 # deterministic val
177
- seed = val_seed if self.split == "val" else self.base_seed
178
- stream = _WorkerWeightedStream(
179
- self._pick_weights(), base_seed=seed, worker_id=worker_id,
180
- )
181
-
182
- row_capacity = self.row_capacity
183
- doc_buffer: list[list[int]] = []
184
- doc_batch_size = self.tokenizer_batch
185
-
186
- def refill_buffer() -> None:
187
- # Collect doc_batch_size text strings, then batch-tokenize.
188
- texts: list[str] = []
189
- for _ in range(doc_batch_size):
190
- text, _epoch = next(stream)
191
- if text:
192
- texts.append(text)
193
- if texts:
194
- token_lists = tokenizer.encode(texts, prepend=bos)
195
- doc_buffer.extend(token_lists)
196
-
197
- while True:
198
- pos = 0
199
- row = torch.empty(row_capacity, dtype=torch.long)
200
- while pos < row_capacity:
201
- while len(doc_buffer) < self.doc_buffer_size:
202
- refill_buffer()
203
-
204
- remaining = row_capacity - pos
205
-
206
- # Best-fit packing: largest doc that fully fits.
207
- best_idx = -1
208
- best_len = 0
209
- for i, doc in enumerate(doc_buffer):
210
- dlen = len(doc)
211
- if dlen <= remaining and dlen > best_len:
212
- best_idx = i
213
- best_len = dlen
214
-
215
- if best_idx >= 0:
216
- doc = doc_buffer.pop(best_idx)
217
- row[pos : pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
218
- pos += len(doc)
219
- else:
220
- # No doc fits remaining space — crop shortest to fill.
221
- shortest_idx = min(
222
- range(len(doc_buffer)),
223
- key=lambda i: len(doc_buffer[i]),
224
- )
225
- doc = doc_buffer.pop(shortest_idx)
226
- row[pos : pos + remaining] = torch.tensor(
227
- doc[:remaining], dtype=torch.long,
228
- )
229
- pos += remaining
230
-
231
- yield row
232
-
233
-
234
- # ---------------------------------------------------------------------------
235
- # LightningDataModule
236
- # ---------------------------------------------------------------------------
237
-
238
-
239
- class HydraDataModule(L.LightningDataModule):
240
- def __init__(
241
- self,
242
- batch_size: int | None = None,
243
- seq_len: int | None = None,
244
- num_workers: int | None = None,
245
- prefetch_factor: int | None = None,
246
- ):
247
- super().__init__()
248
- self.batch_size = batch_size or int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
249
- self.seq_len = seq_len or int(os.environ.get("HYDRA_SEQ_LEN", "512"))
250
- self.num_workers = (
251
- num_workers
252
- if num_workers is not None
253
- else int(os.environ.get("HYDRA_DATA_NUM_WORKERS", "2"))
254
- )
255
- self.prefetch_factor = (
256
- prefetch_factor
257
- if prefetch_factor is not None
258
- else int(os.environ.get("HYDRA_DATA_PREFETCH", "4"))
259
- )
260
- self.doc_buffer = int(os.environ.get("HYDRA_DATA_BUFFER", "1000"))
261
-
262
- def _make_loader(self, split: str, seed: int) -> DataLoader:
263
- dataset = IterableStreamDataset(
264
- split=split,
265
- seq_len=self.seq_len,
266
- base_seed=seed,
267
- doc_buffer_size=self.doc_buffer,
268
- )
269
- # num_workers=0 → main-process iteration (useful for debugging). With
270
- # IterableDataset the DataLoader batches the rows into (B, T+1) via
271
- # default torch.stack-collate.
272
- kw: dict = dict(
273
- dataset=dataset,
274
- batch_size=self.batch_size,
275
- num_workers=self.num_workers,
276
- pin_memory=True,
277
- drop_last=True,
278
- )
279
- if self.num_workers > 0:
280
- kw["prefetch_factor"] = self.prefetch_factor
281
- kw["persistent_workers"] = True
282
- return DataLoader(**kw)
283
-
284
- def train_dataloader(self) -> DataLoader:
285
- return self._make_loader("train", seed=0)
286
-
287
- def val_dataloader(self) -> DataLoader:
288
- return self._make_loader("val", seed=12345)
 
1
+ """Lightning DataModule + IterableDataset for HYDRA pretraining.
2
+
3
+ Replaces the custom threading/queue pipeline in prepare_nemotron.make_dataloader
4
+ with a standard multiprocessing DataLoader approach.
5
+
6
+ Design:
7
+ • IterableStreamDataset: each worker opens its own HF streams for the 7-way
8
+ blend, tokenizes with rustbpe, packs into (T+1,) rows via best-fit, and
9
+ yields one row per __next__.
10
+ • HydraDataModule: wraps the dataset with a standard DataLoader using
11
+ num_workers>=1, prefetch_factor=4, pin_memory=True. Lightning handles
12
+ device transfer.
13
+ • Val stream: deterministic seed 12345, weights match training blend.
14
+
15
+ The worker RNG is seeded per-worker so the weighted-sampling schedule is
16
+ independent across workers (else all workers request the same config at
17
+ the same step and prefetching serializes).
18
+
19
+ Env vars (all preserved from prepare_nemotron):
20
+ HYDRA_SEQ_LEN — sequence length T (default 512)
21
+ HYDRA_BATCH_SIZE — batch size B (default 1) — passed through
22
+ to DataLoader
23
+ HYDRA_STREAM_SHUFFLE_BUFFER — HF shuffle buffer (default 2048)
24
+ HYDRA_USE_FULL_BLEND — 7-way blend vs 5-way Nemotron phase
25
+ HYDRA_USE_NEMOTRON — enables streaming path (else shard path)
26
+ HYDRA_FACTUAL_INJECT_RATE — factual doc injection cadence
27
+ HYDRA_NEMOTRON_PHASE — phase1|phase2 (when not full blend)
28
+ HYDRA_DATA_NUM_WORKERS — DataLoader num_workers (default 2)
29
+ HYDRA_DATA_PREFETCH — DataLoader prefetch_factor (default 4)
30
+ HYDRA_DATA_BUFFER — doc_buffer size for best-fit packing
31
+ (default 1000)
32
+ """
33
+ from __future__ import annotations
34
+
35
+ import os
36
+ import random
37
+ from typing import Iterator
38
+
39
+ import numpy as np
40
+ import torch
41
+ import lightning as L
42
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
43
+
44
+ import prepare as _prepare
45
+ import prepare_nemotron as _p_nemo
46
+ from prepare_nemotron import (
47
+ FULL_BLEND_WEIGHTS,
48
+ PHASE1_WEIGHTS,
49
+ PHASE2_WEIGHTS,
50
+ _BLEND_REGISTRY,
51
+ _extract_text,
52
+ _open_stream,
53
+ )
54
+
55
+
56
+ # ---------------------------------------------------------------------------
57
+ # Worker-local weighted stream. A stripped version of prepare_nemotron's
58
+ # _WeightedStream that is constructed inside each worker. Adds worker sharding:
59
+ # when num_workers > 1 the RNG is seeded per-worker, so different workers
60
+ # sample different config sequences and pull disjoint shard assignments from
61
+ # HF's shuffle buffer.
62
+ # ---------------------------------------------------------------------------
63
+
64
+
65
+ class _WorkerWeightedStream:
66
+ def __init__(self, weights: dict[str, float], base_seed: int, worker_id: int):
67
+ self.configs = list(weights.keys())
68
+ self.weights = [weights[c] for c in self.configs]
69
+ self.base_seed = base_seed
70
+ self.worker_id = worker_id
71
+ # Each worker opens its own HF streams. _open_stream returns an iter()
72
+ # over a streaming dataset, with an internal shuffle buffer.
73
+ self.streams = {c: _open_stream(c, "train") for c in self.configs}
74
+ # Per-worker RNG so the config-choice trajectory is independent.
75
+ self.rng = random.Random(base_seed + worker_id * 7919)
76
+ self.epoch = 1
77
+
78
+ # Lazy-init factual docs (once per worker). The main-process version
79
+ # in prepare_nemotron._WeightedStream reads these on first __next__.
80
+ self._factual_docs: list[str] | None = None
81
+ self._factual_idx = 0
82
+ self._inject_counter = 0
83
+ inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
84
+ self._inject_rate = inject_rate
85
+ if inject_rate > 0:
86
+ factual_path = os.path.join(
87
+ os.path.dirname(os.path.abspath(_p_nemo.__file__)),
88
+ "data", "factual", "facts.txt",
89
+ )
90
+ if os.path.exists(factual_path):
91
+ with open(factual_path) as fh:
92
+ self._factual_docs = fh.read().strip().split("\n")
93
+
94
+ def _reopen(self, config: str) -> None:
95
+ self.streams[config] = _open_stream(config, "train")
96
+ self.epoch += 1
97
+
98
+ def __iter__(self):
99
+ return self
100
+
101
+ def __next__(self) -> tuple[str, int]:
102
+ # Factual injection (preserves prepare_nemotron cadence).
103
+ if self._inject_rate > 0 and self._factual_docs:
104
+ self._inject_counter += 1
105
+ if self._inject_counter >= self._inject_rate:
106
+ self._inject_counter = 0
107
+ doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
108
+ self._factual_idx += 1
109
+ return doc, self.epoch
110
+
111
+ config = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
112
+ try:
113
+ row = next(self.streams[config])
114
+ except StopIteration:
115
+ self._reopen(config)
116
+ row = next(self.streams[config])
117
+ return _extract_text(row), self.epoch
118
+
119
+
120
+ # ---------------------------------------------------------------------------
121
+ # IterableStreamDataset — yields (T+1,) packed rows. No threads. No queues.
122
+ # Lives inside each DataLoader worker. DataLoader's own multiprocessing stacks
123
+ # rows into batches of shape (B, T+1) and sends them to the main process.
124
+ # ---------------------------------------------------------------------------
125
+
126
+
127
+ class IterableStreamDataset(IterableDataset):
128
+ """Streams docs, tokenizes, packs into (T+1,) rows via best-fit.
129
+
130
+ Each worker gets its own instance (via fork/spawn) and initializes its
131
+ own HF streams + rustbpe tokenizer + factual injector. The tokenizer
132
+ pickled blob is small (~1 MB) and thread-safe per tiktoken docs.
133
+ """
134
+
135
+ def __init__(
136
+ self,
137
+ split: str,
138
+ seq_len: int,
139
+ *,
140
+ base_seed: int = 0,
141
+ doc_buffer_size: int = 1000,
142
+ tokenizer_batch: int = 128,
143
+ ):
144
+ super().__init__()
145
+ assert split in ("train", "val"), split
146
+ self.split = split
147
+ self.seq_len = seq_len
148
+ self.row_capacity = seq_len + 1
149
+ self.base_seed = base_seed
150
+ self.doc_buffer_size = doc_buffer_size
151
+ self.tokenizer_batch = tokenizer_batch
152
+
153
+ def _pick_weights(self) -> dict[str, float]:
154
+ if self.split == "val":
155
+ if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
156
+ return FULL_BLEND_WEIGHTS
157
+ return {"Nemotron-Pretraining-Multiple-Choice": 1.0}
158
+ if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
159
+ return FULL_BLEND_WEIGHTS
160
+ phase = os.environ.get("HYDRA_NEMOTRON_PHASE", "phase1").strip().lower()
161
+ return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS
162
+
163
+ def __iter__(self) -> Iterator[torch.Tensor]:
164
+ info = get_worker_info()
165
+ worker_id = 0 if info is None else info.id
166
+
167
+ # Each worker builds its own tokenizer instance. tiktoken's Encoding
168
+ # object is pickleable and the underlying C++ BPE is thread-safe;
169
+ # per-worker instantiation avoids cross-process sharing headaches.
170
+ tokenizer = _prepare.Tokenizer.from_directory()
171
+ bos = tokenizer.get_bos_token_id()
172
+
173
+ # Each worker gets its own weighted HF stream. Seed offset ensures
174
+ # disjoint config-choice trajectories; HF's own shuffle buffer handles
175
+ # shard randomization.
176
+ val_seed = 12345 # deterministic val
177
+ seed = val_seed if self.split == "val" else self.base_seed
178
+ stream = _WorkerWeightedStream(
179
+ self._pick_weights(), base_seed=seed, worker_id=worker_id,
180
+ )
181
+
182
+ row_capacity = self.row_capacity
183
+ doc_buffer: list[list[int]] = []
184
+ doc_batch_size = self.tokenizer_batch
185
+
186
+ def refill_buffer() -> None:
187
+ # Collect doc_batch_size text strings, then batch-tokenize.
188
+ texts: list[str] = []
189
+ for _ in range(doc_batch_size):
190
+ text, _epoch = next(stream)
191
+ if text:
192
+ texts.append(text)
193
+ if texts:
194
+ token_lists = tokenizer.encode(texts, prepend=bos)
195
+ doc_buffer.extend(token_lists)
196
+
197
+ while True:
198
+ pos = 0
199
+ row = torch.empty(row_capacity, dtype=torch.long)
200
+ while pos < row_capacity:
201
+ while len(doc_buffer) < self.doc_buffer_size:
202
+ refill_buffer()
203
+
204
+ remaining = row_capacity - pos
205
+
206
+ # Best-fit packing: largest doc that fully fits.
207
+ best_idx = -1
208
+ best_len = 0
209
+ for i, doc in enumerate(doc_buffer):
210
+ dlen = len(doc)
211
+ if dlen <= remaining and dlen > best_len:
212
+ best_idx = i
213
+ best_len = dlen
214
+
215
+ if best_idx >= 0:
216
+ doc = doc_buffer.pop(best_idx)
217
+ row[pos : pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
218
+ pos += len(doc)
219
+ else:
220
+ # No doc fits remaining space — crop shortest to fill.
221
+ shortest_idx = min(
222
+ range(len(doc_buffer)),
223
+ key=lambda i: len(doc_buffer[i]),
224
+ )
225
+ doc = doc_buffer.pop(shortest_idx)
226
+ row[pos : pos + remaining] = torch.tensor(
227
+ doc[:remaining], dtype=torch.long,
228
+ )
229
+ pos += remaining
230
+
231
+ yield row
232
+
233
+
234
+ # ---------------------------------------------------------------------------
235
+ # LightningDataModule
236
+ # ---------------------------------------------------------------------------
237
+
238
+
239
+ class HydraDataModule(L.LightningDataModule):
240
+ def __init__(
241
+ self,
242
+ batch_size: int | None = None,
243
+ seq_len: int | None = None,
244
+ num_workers: int | None = None,
245
+ prefetch_factor: int | None = None,
246
+ ):
247
+ super().__init__()
248
+ self.batch_size = batch_size or int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
249
+ self.seq_len = seq_len or int(os.environ.get("HYDRA_SEQ_LEN", "512"))
250
+ self.num_workers = (
251
+ num_workers
252
+ if num_workers is not None
253
+ else int(os.environ.get("HYDRA_DATA_NUM_WORKERS", "2"))
254
+ )
255
+ self.prefetch_factor = (
256
+ prefetch_factor
257
+ if prefetch_factor is not None
258
+ else int(os.environ.get("HYDRA_DATA_PREFETCH", "4"))
259
+ )
260
+ self.doc_buffer = int(os.environ.get("HYDRA_DATA_BUFFER", "1000"))
261
+
262
+ def _make_loader(self, split: str, seed: int) -> DataLoader:
263
+ dataset = IterableStreamDataset(
264
+ split=split,
265
+ seq_len=self.seq_len,
266
+ base_seed=seed,
267
+ doc_buffer_size=self.doc_buffer,
268
+ )
269
+ # num_workers=0 → main-process iteration (useful for debugging). With
270
+ # IterableDataset the DataLoader batches the rows into (B, T+1) via
271
+ # default torch.stack-collate.
272
+ kw: dict = dict(
273
+ dataset=dataset,
274
+ batch_size=self.batch_size,
275
+ num_workers=self.num_workers,
276
+ pin_memory=True,
277
+ drop_last=True,
278
+ )
279
+ if self.num_workers > 0:
280
+ kw["prefetch_factor"] = self.prefetch_factor
281
+ kw["persistent_workers"] = True
282
+ return DataLoader(**kw)
283
+
284
+ def train_dataloader(self) -> DataLoader:
285
+ return self._make_loader("train", seed=0)
286
+
287
+ def val_dataloader(self) -> DataLoader:
288
+ return self._make_loader("val", seed=12345)
overlay/hydra/diffusion_loss.py CHANGED
@@ -1,236 +1,236 @@
1
- """MDLM Rao-Blackwellized Masked Diffusion Loss.
2
-
3
- Implements the masked-diffusion ELBO from:
4
- Sahoo et al., "Simple and Effective Masked Diffusion Language Models" (MDLM),
5
- NeurIPS 2024, arXiv:2406.07524.
6
-
7
- Equations referenced:
8
- - Forward process: eq. 2 (per-token Bernoulli masking at rate 1 - alpha_t)
9
- - Log-linear schedule: alpha_t = 1 - t, t ~ Uniform(0, 1)
10
- - RB-ELBO: eq. 7-8 L_RB = E_t E_q [ (1/alpha_t) * CE(x_theta(x_t), x_0) ]
11
- where the expectation over masked positions.
12
-
13
- Key insight: the Rao-Blackwellized estimate replaces an average over all masks
14
- (exponential) by a closed-form weighted CE that applies weight 1/alpha_t only
15
- on the positions that were masked, and 0 on unmasked positions. This gives an
16
- unbiased estimator with lower variance than a naive Monte Carlo over mask
17
- patterns.
18
-
19
- Reference implementation cross-checked against:
20
- https://github.com/kuleshov-group/mdlm (diffusion.py::DiffusionModel._loss)
21
- """
22
-
23
- from __future__ import annotations
24
-
25
- from typing import Literal
26
-
27
- import torch
28
- import torch.nn.functional as F
29
-
30
-
31
- # Clamping weight keeps gradients finite while still up-weighting high-noise
32
- # positions. Historical value 1/eps=1000 blew up HYDRA training on a 12h v2
33
- # launch (2026-04-22): loss 26 → 42 → NaN in 13 steps under Muon lr=7e-3
34
- # because per-token CE × 1000 saturated the 100-unit FAIL guard. The MDLM
35
- # paper reports stable training at Adam lr=1e-4; HYDRA uses Muon at 7e-3
36
- # (70× larger), so the weight clamp needs to compensate.
37
- #
38
- # Tunable via HYDRA_MDLM_MAX_WEIGHT (default 5.0). Set =1.0 to disable
39
- # weighting entirely (flat masked-LM CE, no RB reweighting — simpler and
40
- # more stable, sacrifices the theoretical ELBO property).
41
- import os as _os
42
- _MAX_WEIGHT: float = float(_os.environ.get("HYDRA_MDLM_MAX_WEIGHT", "5.0"))
43
- _MIN_ALPHA: float = 1.0 / _MAX_WEIGHT # so clamp(alpha, min=_MIN_ALPHA) gives 1/alpha <= _MAX_WEIGHT
44
-
45
-
46
- # ---------------------------------------------------------------------------
47
- # Public API
48
- # ---------------------------------------------------------------------------
49
-
50
- def mdlm_masked_forward_process(
51
- targets: torch.Tensor,
52
- mask_token_id: int,
53
- t: torch.Tensor | None = None,
54
- alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
55
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
- """MDLM forward (noising) process: mask tokens and compute RB weights.
57
-
58
- Args:
59
- targets: (B, T) int64 token ids — the clean sequence x_0.
60
- mask_token_id: The special token id used to represent a masked token.
61
- t: (B,) float in (0, 1). If None, samples Uniform(0, 1) per batch
62
- element. t=0 means fully clean; t=1 means fully masked.
63
- alpha_schedule: Noise schedule.
64
- "loglinear" (MDLM default): alpha_t = 1 - t
65
- "linear": identical formula — both are provided for completeness
66
- since the paper calls the 1-t schedule "log-linear" in the context
67
- of the ELBO derivation.
68
-
69
- Returns:
70
- x_t : (B, T) int64 — noised sequence; masked positions hold
71
- mask_token_id, unmasked positions equal targets.
72
- mask_positions: (B, T) bool — True where the token was masked.
73
- loss_weights : (B, T) float32 — RB weighting factor. On masked
74
- positions: 1/alpha_t (clamped to _MAX_WEIGHT). On
75
- unmasked positions: 0.0. Summing
76
- (CE * loss_weights * mask_positions).sum() / mask.sum()
77
- gives the per-sample RB-ELBO estimator.
78
- """
79
- B, T = targets.shape
80
- device = targets.device
81
- dtype = torch.float32
82
-
83
- # --- sample or validate t ---
84
- if t is None:
85
- # Uniform(0, 1) per batch element; avoid exactly 0 and 1.
86
- t = torch.rand(B, device=device, dtype=dtype)
87
- else:
88
- t = t.to(device=device, dtype=dtype)
89
- if t.shape != (B,):
90
- raise ValueError(f"t must be shape (B,)={(B,)}, got {t.shape}")
91
- if (t < 0).any() or (t > 1).any():
92
- raise ValueError("t must be in [0, 1]")
93
-
94
- # --- noise schedule: alpha_t = probability that a token is NOT masked ---
95
- # Both "linear" and "loglinear" in MDLM use alpha_t = 1 - t; the paper
96
- # refers to "log-linear" because the schedule is linear in the *log* domain
97
- # of the forward process probability. We expose both names for clarity.
98
- if alpha_schedule in ("linear", "loglinear"):
99
- alpha_t = 1.0 - t # (B,) float, in [0, 1]
100
- else:
101
- raise ValueError(f"Unknown alpha_schedule: {alpha_schedule!r}. Use 'linear' or 'loglinear'.")
102
-
103
- # --- per-token Bernoulli mask ---
104
- # alpha_t[:, None] broadcasts to (B, T).
105
- alpha_t_expanded = alpha_t[:, None] # (B, 1)
106
- # Bernoulli(1 - alpha_t) = 1 means "mask this token".
107
- # We sample independently per token, per batch element.
108
- rand = torch.rand(B, T, device=device, dtype=dtype)
109
- mask_positions = rand > alpha_t_expanded # (B, T) bool
110
- # True → masked position
111
- # False → unmasked (kept as original)
112
-
113
- # --- build x_t ---
114
- x_t = targets.clone()
115
- x_t = torch.where(mask_positions, torch.full_like(x_t, mask_token_id), x_t)
116
-
117
- # --- RB loss weights: 1/alpha_t on masked positions, 0 elsewhere ---
118
- # Clamp alpha_t so weights stay finite near t→1.
119
- safe_alpha = alpha_t.clamp(min=_MIN_ALPHA) # (B,)
120
- weight_per_sample = 1.0 / safe_alpha # (B,)
121
- # Broadcast to (B, T) and zero out unmasked positions.
122
- loss_weights = weight_per_sample[:, None].expand(B, T).to(dtype=dtype) # (B, T)
123
- loss_weights = loss_weights * mask_positions.float()
124
-
125
- return x_t, mask_positions, loss_weights
126
-
127
-
128
- def mdlm_rb_loss(
129
- logits: torch.Tensor,
130
- targets: torch.Tensor,
131
- mask_positions: torch.Tensor,
132
- loss_weights: torch.Tensor,
133
- ignore_index: int = -100,
134
- ) -> torch.Tensor:
135
- """Rao-Blackwellized negative ELBO.
136
-
137
- Applies the MDLM loss: cross-entropy on masked positions only, weighted
138
- per-token by loss_weights, averaged over the batch.
139
-
140
- The formula (eq. 7-8 of arXiv:2406.07524):
141
- L_RB = mean_B [ sum_T (weight_t * CE(logits_i, target_i) * mask_i)
142
- / max(sum_T(mask_i), 1) ]
143
-
144
- Args:
145
- logits : (B, T, V) raw logits. May be bf16; internally cast to
146
- float32 for CE computation.
147
- targets : (B, T) int64 true token ids (x_0).
148
- mask_positions: (B, T) bool — True = masked position.
149
- loss_weights : (B, T) float32 — 1/alpha_t on masked positions, 0 elsewhere.
150
- ignore_index : Passed to F.cross_entropy; positions with this label
151
- are excluded from the loss.
152
-
153
- Returns:
154
- Scalar float32 loss. Returns 0.0 tensor if no positions are masked.
155
- """
156
- B, T, V = logits.shape
157
-
158
- # Ensure float32 for numerical stability; F.cross_entropy accepts fp16/bf16
159
- # logits but accumulates in float internally anyway. Being explicit avoids
160
- # silent precision surprises.
161
- logits_f = logits.float() # (B, T, V)
162
-
163
- # Build targets with ignore_index on UNmasked positions so CE only fires
164
- # where mask_positions is True. We also honour any pre-existing -100 values
165
- # (e.g. doc-separator masking upstream).
166
- targets_masked = torch.where(
167
- mask_positions & (targets != ignore_index),
168
- targets,
169
- torch.full_like(targets, ignore_index),
170
- )
171
-
172
- # Per-token CE; shape (B, T). Positions with ignore_index → 0 from CE.
173
- per_tok_ce = F.cross_entropy(
174
- logits_f.reshape(B * T, V),
175
- targets_masked.reshape(B * T),
176
- ignore_index=ignore_index,
177
- reduction="none",
178
- ).reshape(B, T) # (B, T) float32
179
-
180
- # Apply RB weight. loss_weights already has 0 on unmasked positions.
181
- weighted = per_tok_ce * loss_weights # (B, T)
182
-
183
- # Per-sample mean over masked positions, then average over batch.
184
- mask_f = mask_positions.float() # (B, T)
185
- per_sample_mask_count = mask_f.sum(dim=1).clamp(min=1) # (B,)
186
- per_sample_loss = weighted.sum(dim=1) / per_sample_mask_count # (B,)
187
-
188
- return per_sample_loss.mean() # scalar float32
189
-
190
-
191
- def mdlm_loss(
192
- logits: torch.Tensor,
193
- targets: torch.Tensor,
194
- mask_token_id: int,
195
- t: torch.Tensor | None = None,
196
- alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
197
- ignore_index: int = -100,
198
- ) -> torch.Tensor:
199
- """Convenience wrapper: forward process + RB-ELBO in one call.
200
-
201
- Suitable for the common case where the caller has full-vocab logits and
202
- wants a drop-in replacement for a standard masked-LM CE loss.
203
-
204
- Args:
205
- logits : (B, T, V) raw logits.
206
- targets : (B, T) int64 clean token ids.
207
- mask_token_id : The MASK token id used to corrupt the input.
208
- t : Optional (B,) timestep in (0, 1). Sampled if None.
209
- alpha_schedule: "loglinear" (default) or "linear".
210
- ignore_index : Token id to ignore in the loss (e.g. padding).
211
-
212
- Returns:
213
- Scalar float32 MDLM RB-ELBO loss.
214
-
215
- Note on sampled-softmax / partial logits:
216
- If your model only computes logits for a subset of vocab positions
217
- (e.g. HYDRA's sampled-softmax head), call mdlm_masked_forward_process
218
- and mdlm_rb_loss separately. mdlm_rb_loss expects full-vocab logits.
219
- """
220
- x_t, mask_positions, loss_weights = mdlm_masked_forward_process(
221
- targets=targets,
222
- mask_token_id=mask_token_id,
223
- t=t,
224
- alpha_schedule=alpha_schedule,
225
- )
226
- # x_t is produced for the model's input (not used by this convenience
227
- # wrapper since logits are already provided by the caller). In a real
228
- # training loop the caller feeds x_t into the model to get logits, THEN
229
- # calls this function. See the orchestrator wiring note in training.py.
230
- return mdlm_rb_loss(
231
- logits=logits,
232
- targets=targets,
233
- mask_positions=mask_positions,
234
- loss_weights=loss_weights,
235
- ignore_index=ignore_index,
236
- )
 
1
+ """MDLM Rao-Blackwellized Masked Diffusion Loss.
2
+
3
+ Implements the masked-diffusion ELBO from:
4
+ Sahoo et al., "Simple and Effective Masked Diffusion Language Models" (MDLM),
5
+ NeurIPS 2024, arXiv:2406.07524.
6
+
7
+ Equations referenced:
8
+ - Forward process: eq. 2 (per-token Bernoulli masking at rate 1 - alpha_t)
9
+ - Log-linear schedule: alpha_t = 1 - t, t ~ Uniform(0, 1)
10
+ - RB-ELBO: eq. 7-8 L_RB = E_t E_q [ (1/alpha_t) * CE(x_theta(x_t), x_0) ]
11
+ where the expectation over masked positions.
12
+
13
+ Key insight: the Rao-Blackwellized estimate replaces an average over all masks
14
+ (exponential) by a closed-form weighted CE that applies weight 1/alpha_t only
15
+ on the positions that were masked, and 0 on unmasked positions. This gives an
16
+ unbiased estimator with lower variance than a naive Monte Carlo over mask
17
+ patterns.
18
+
19
+ Reference implementation cross-checked against:
20
+ https://github.com/kuleshov-group/mdlm (diffusion.py::DiffusionModel._loss)
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from typing import Literal
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+
30
+
31
+ # Clamping weight keeps gradients finite while still up-weighting high-noise
32
+ # positions. Historical value 1/eps=1000 blew up HYDRA training on a 12h v2
33
+ # launch (2026-04-22): loss 26 → 42 → NaN in 13 steps under Muon lr=7e-3
34
+ # because per-token CE × 1000 saturated the 100-unit FAIL guard. The MDLM
35
+ # paper reports stable training at Adam lr=1e-4; HYDRA uses Muon at 7e-3
36
+ # (70× larger), so the weight clamp needs to compensate.
37
+ #
38
+ # Tunable via HYDRA_MDLM_MAX_WEIGHT (default 5.0). Set =1.0 to disable
39
+ # weighting entirely (flat masked-LM CE, no RB reweighting — simpler and
40
+ # more stable, sacrifices the theoretical ELBO property).
41
+ import os as _os
42
+ _MAX_WEIGHT: float = float(_os.environ.get("HYDRA_MDLM_MAX_WEIGHT", "5.0"))
43
+ _MIN_ALPHA: float = 1.0 / _MAX_WEIGHT # so clamp(alpha, min=_MIN_ALPHA) gives 1/alpha <= _MAX_WEIGHT
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Public API
48
+ # ---------------------------------------------------------------------------
49
+
50
+ def mdlm_masked_forward_process(
51
+ targets: torch.Tensor,
52
+ mask_token_id: int,
53
+ t: torch.Tensor | None = None,
54
+ alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
55
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
+ """MDLM forward (noising) process: mask tokens and compute RB weights.
57
+
58
+ Args:
59
+ targets: (B, T) int64 token ids — the clean sequence x_0.
60
+ mask_token_id: The special token id used to represent a masked token.
61
+ t: (B,) float in (0, 1). If None, samples Uniform(0, 1) per batch
62
+ element. t=0 means fully clean; t=1 means fully masked.
63
+ alpha_schedule: Noise schedule.
64
+ "loglinear" (MDLM default): alpha_t = 1 - t
65
+ "linear": identical formula — both are provided for completeness
66
+ since the paper calls the 1-t schedule "log-linear" in the context
67
+ of the ELBO derivation.
68
+
69
+ Returns:
70
+ x_t : (B, T) int64 — noised sequence; masked positions hold
71
+ mask_token_id, unmasked positions equal targets.
72
+ mask_positions: (B, T) bool — True where the token was masked.
73
+ loss_weights : (B, T) float32 — RB weighting factor. On masked
74
+ positions: 1/alpha_t (clamped to _MAX_WEIGHT). On
75
+ unmasked positions: 0.0. Summing
76
+ (CE * loss_weights * mask_positions).sum() / mask.sum()
77
+ gives the per-sample RB-ELBO estimator.
78
+ """
79
+ B, T = targets.shape
80
+ device = targets.device
81
+ dtype = torch.float32
82
+
83
+ # --- sample or validate t ---
84
+ if t is None:
85
+ # Uniform(0, 1) per batch element; avoid exactly 0 and 1.
86
+ t = torch.rand(B, device=device, dtype=dtype)
87
+ else:
88
+ t = t.to(device=device, dtype=dtype)
89
+ if t.shape != (B,):
90
+ raise ValueError(f"t must be shape (B,)={(B,)}, got {t.shape}")
91
+ if (t < 0).any() or (t > 1).any():
92
+ raise ValueError("t must be in [0, 1]")
93
+
94
+ # --- noise schedule: alpha_t = probability that a token is NOT masked ---
95
+ # Both "linear" and "loglinear" in MDLM use alpha_t = 1 - t; the paper
96
+ # refers to "log-linear" because the schedule is linear in the *log* domain
97
+ # of the forward process probability. We expose both names for clarity.
98
+ if alpha_schedule in ("linear", "loglinear"):
99
+ alpha_t = 1.0 - t # (B,) float, in [0, 1]
100
+ else:
101
+ raise ValueError(f"Unknown alpha_schedule: {alpha_schedule!r}. Use 'linear' or 'loglinear'.")
102
+
103
+ # --- per-token Bernoulli mask ---
104
+ # alpha_t[:, None] broadcasts to (B, T).
105
+ alpha_t_expanded = alpha_t[:, None] # (B, 1)
106
+ # Bernoulli(1 - alpha_t) = 1 means "mask this token".
107
+ # We sample independently per token, per batch element.
108
+ rand = torch.rand(B, T, device=device, dtype=dtype)
109
+ mask_positions = rand > alpha_t_expanded # (B, T) bool
110
+ # True → masked position
111
+ # False → unmasked (kept as original)
112
+
113
+ # --- build x_t ---
114
+ x_t = targets.clone()
115
+ x_t = torch.where(mask_positions, torch.full_like(x_t, mask_token_id), x_t)
116
+
117
+ # --- RB loss weights: 1/alpha_t on masked positions, 0 elsewhere ---
118
+ # Clamp alpha_t so weights stay finite near t→1.
119
+ safe_alpha = alpha_t.clamp(min=_MIN_ALPHA) # (B,)
120
+ weight_per_sample = 1.0 / safe_alpha # (B,)
121
+ # Broadcast to (B, T) and zero out unmasked positions.
122
+ loss_weights = weight_per_sample[:, None].expand(B, T).to(dtype=dtype) # (B, T)
123
+ loss_weights = loss_weights * mask_positions.float()
124
+
125
+ return x_t, mask_positions, loss_weights
126
+
127
+
128
+ def mdlm_rb_loss(
129
+ logits: torch.Tensor,
130
+ targets: torch.Tensor,
131
+ mask_positions: torch.Tensor,
132
+ loss_weights: torch.Tensor,
133
+ ignore_index: int = -100,
134
+ ) -> torch.Tensor:
135
+ """Rao-Blackwellized negative ELBO.
136
+
137
+ Applies the MDLM loss: cross-entropy on masked positions only, weighted
138
+ per-token by loss_weights, averaged over the batch.
139
+
140
+ The formula (eq. 7-8 of arXiv:2406.07524):
141
+ L_RB = mean_B [ sum_T (weight_t * CE(logits_i, target_i) * mask_i)
142
+ / max(sum_T(mask_i), 1) ]
143
+
144
+ Args:
145
+ logits : (B, T, V) raw logits. May be bf16; internally cast to
146
+ float32 for CE computation.
147
+ targets : (B, T) int64 true token ids (x_0).
148
+ mask_positions: (B, T) bool — True = masked position.
149
+ loss_weights : (B, T) float32 — 1/alpha_t on masked positions, 0 elsewhere.
150
+ ignore_index : Passed to F.cross_entropy; positions with this label
151
+ are excluded from the loss.
152
+
153
+ Returns:
154
+ Scalar float32 loss. Returns 0.0 tensor if no positions are masked.
155
+ """
156
+ B, T, V = logits.shape
157
+
158
+ # Ensure float32 for numerical stability; F.cross_entropy accepts fp16/bf16
159
+ # logits but accumulates in float internally anyway. Being explicit avoids
160
+ # silent precision surprises.
161
+ logits_f = logits.float() # (B, T, V)
162
+
163
+ # Build targets with ignore_index on UNmasked positions so CE only fires
164
+ # where mask_positions is True. We also honour any pre-existing -100 values
165
+ # (e.g. doc-separator masking upstream).
166
+ targets_masked = torch.where(
167
+ mask_positions & (targets != ignore_index),
168
+ targets,
169
+ torch.full_like(targets, ignore_index),
170
+ )
171
+
172
+ # Per-token CE; shape (B, T). Positions with ignore_index → 0 from CE.
173
+ per_tok_ce = F.cross_entropy(
174
+ logits_f.reshape(B * T, V),
175
+ targets_masked.reshape(B * T),
176
+ ignore_index=ignore_index,
177
+ reduction="none",
178
+ ).reshape(B, T) # (B, T) float32
179
+
180
+ # Apply RB weight. loss_weights already has 0 on unmasked positions.
181
+ weighted = per_tok_ce * loss_weights # (B, T)
182
+
183
+ # Per-sample mean over masked positions, then average over batch.
184
+ mask_f = mask_positions.float() # (B, T)
185
+ per_sample_mask_count = mask_f.sum(dim=1).clamp(min=1) # (B,)
186
+ per_sample_loss = weighted.sum(dim=1) / per_sample_mask_count # (B,)
187
+
188
+ return per_sample_loss.mean() # scalar float32
189
+
190
+
191
+ def mdlm_loss(
192
+ logits: torch.Tensor,
193
+ targets: torch.Tensor,
194
+ mask_token_id: int,
195
+ t: torch.Tensor | None = None,
196
+ alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
197
+ ignore_index: int = -100,
198
+ ) -> torch.Tensor:
199
+ """Convenience wrapper: forward process + RB-ELBO in one call.
200
+
201
+ Suitable for the common case where the caller has full-vocab logits and
202
+ wants a drop-in replacement for a standard masked-LM CE loss.
203
+
204
+ Args:
205
+ logits : (B, T, V) raw logits.
206
+ targets : (B, T) int64 clean token ids.
207
+ mask_token_id : The MASK token id used to corrupt the input.
208
+ t : Optional (B,) timestep in (0, 1). Sampled if None.
209
+ alpha_schedule: "loglinear" (default) or "linear".
210
+ ignore_index : Token id to ignore in the loss (e.g. padding).
211
+
212
+ Returns:
213
+ Scalar float32 MDLM RB-ELBO loss.
214
+
215
+ Note on sampled-softmax / partial logits:
216
+ If your model only computes logits for a subset of vocab positions
217
+ (e.g. HYDRA's sampled-softmax head), call mdlm_masked_forward_process
218
+ and mdlm_rb_loss separately. mdlm_rb_loss expects full-vocab logits.
219
+ """
220
+ x_t, mask_positions, loss_weights = mdlm_masked_forward_process(
221
+ targets=targets,
222
+ mask_token_id=mask_token_id,
223
+ t=t,
224
+ alpha_schedule=alpha_schedule,
225
+ )
226
+ # x_t is produced for the model's input (not used by this convenience
227
+ # wrapper since logits are already provided by the caller). In a real
228
+ # training loop the caller feeds x_t into the model to get logits, THEN
229
+ # calls this function. See the orchestrator wiring note in training.py.
230
+ return mdlm_rb_loss(
231
+ logits=logits,
232
+ targets=targets,
233
+ mask_positions=mask_positions,
234
+ loss_weights=loss_weights,
235
+ ignore_index=ignore_index,
236
+ )
overlay/hydra/engram.py CHANGED
@@ -1,175 +1,160 @@
1
- """GPU Engram — Top-k Sparse Hopfield retrieval, scales to n_columns >= 32768.
2
-
3
- ## What changed (scatter-gather → top-k Hopfield)
4
-
5
- The original forward used `self.memory[indices]` (scatter-gather), which misses
6
- L2 cache at n_columns > 4096 and creates a hard tps ceiling.
7
-
8
- An earlier Hopfield implementation used `entmax15` for sparse attention, but
9
- entmax's internal `torch.sort` over the full n_columns dimension allocates
10
- ~1 GB scratch at (B*T=8192, n_columns=32768) and OOMs on a 6 GB card.
11
-
12
- This module replaces the sort-based entmax with **top-k softmax**, which is
13
- O(B*T*K) in memory and O(B*T*K * log n_columns) in compute (the top-k is
14
- radix-selection under the hood — not a full sort). Sparsity is still exact:
15
- only K columns have non-zero weight per (batch, position).
16
-
17
- ## Why this scales where entmax didn't
18
-
19
- - `scores = x @ memory.T` is (B, T, n_columns) 268 MB at bf16 with n_columns=32768.
20
- - `scores.topk(K)` allocates only (B, T, K) ~2 MB at K=64. No full sort.
21
- - `memory[topk_idx]` gathers (B, T, K, d_model) — ~32 MB at bf16. Gather is
22
- on the LAST axis of memory (columns), contiguous stride-1 rows, cache-friendly.
23
- - `retrieved = einsum(topk_w, selected_mem)` — ~4 MB. Final reduction.
24
-
25
- Peak working set well under 400 MB at any reasonable n_columns + K. The weights
26
- tensor is never densified (which would have been the (B, T, n_columns) killer).
27
-
28
- ## Gradient flow
29
-
30
- Both the topk gather and the einsum are autograd-tracked, so `self.memory`
31
- receives gradient from the LM loss (which the Hebbian scatter-gather path did
32
- not). `topk` indices are detached — gradient flows through `topk_vals` via the
33
- selected memory rows.
34
-
35
- ## Sparsity
36
-
37
- Exactly K columns have non-zero weight per position. Default K=64, tunable via
38
- HYDRA_ENGRAM_TOPK.
39
-
40
- ## token_ids argument
41
-
42
- Accepted for API compatibility with hydra/model.py; unused in retrieval. The
43
- optional Hebbian boost (hebbian_boost=True) uses the hash-indexed path for
44
- its EMA write only.
45
-
46
- ## Checkpoint compatibility
47
-
48
- `self.memory` shape (n_columns, d_model) is unchanged; existing .pt/.ckpt
49
- files load without migration.
50
- """
51
-
52
- from __future__ import annotations
53
-
54
- import os
55
-
56
- import torch
57
- import torch.nn as nn
58
-
59
-
60
- # Top-k width how many memory columns get non-zero weight per position.
61
- # Default 64 matches the entmax sparsity fraction we observed empirically
62
- # (~0.2% of 32768 columns == 64). HYDRA_ENGRAM_TOPK env var overrides.
63
- _ENGRAM_TOPK = int(os.environ.get("HYDRA_ENGRAM_TOPK", "64"))
64
-
65
-
66
- class GPUEngram(nn.Module):
67
- """GPU Engram: Top-k Sparse Hopfield retrieval.
68
-
69
- Args:
70
- d_model: Model dimension — must match the surrounding transformer.
71
- n_columns: Number of memory columns (key-value pairs). Safe up to
72
- n_columns = 65536 at d_model = 384 on a 6 GB card with
73
- B*T <= 8192.
74
- max_ngram: Retained for API compatibility; unused in retrieval.
75
- hebbian_boost: If True, also run a Hebbian EMA write on the memory bank
76
- during training. Default False — the top-k gradient path
77
- provides learning signal without this.
78
- """
79
-
80
- def __init__(
81
- self,
82
- d_model: int,
83
- n_columns: int = 1024,
84
- max_ngram: int = 3,
85
- hebbian_boost: bool = False,
86
- ) -> None:
87
- super().__init__()
88
- self.n_columns = n_columns
89
- self.max_ngram = max_ngram
90
- self.hebbian_boost = hebbian_boost
91
- # Shape unchanged from original — existing checkpoints load cleanly.
92
- self.memory = nn.Parameter(torch.randn(n_columns, d_model) * 0.01)
93
- self.gate = nn.Linear(d_model, 1, bias=True)
94
- nn.init.constant_(self.gate.bias, 0.0) # START OPEN
95
- # Clamp topk K to n_columns so topk doesn't error at small engram.
96
- self.topk_k = min(_ENGRAM_TOPK, n_columns)
97
- # Retained for any external code that reads these attrs.
98
- self.primes = [2654435761, 2246822519, 3266489917]
99
- self.hebbian_lr = 0.01
100
-
101
- # ------------------------------------------------------------------
102
- # _hash: retained for API/checkpoint compat; unused in retrieval path.
103
- # ------------------------------------------------------------------
104
-
105
- def _hash(self, token_ids: torch.Tensor) -> torch.Tensor:
106
- """N-gram hash → column index (Hebbian-write target only, not retrieval)."""
107
- B, T = token_ids.shape
108
- h = token_ids * self.primes[0]
109
- if T > 1:
110
- shifted1 = torch.roll(token_ids, 1, dims=1)
111
- shifted1[:, 0] = 0
112
- h = h ^ (shifted1 * self.primes[1])
113
- if T > 2:
114
- shifted2 = torch.roll(token_ids, 2, dims=1)
115
- shifted2[:, :2] = 0
116
- h = h ^ (shifted2 * self.primes[2])
117
- return h % self.n_columns
118
-
119
- # ------------------------------------------------------------------
120
- # forward
121
- # ------------------------------------------------------------------
122
-
123
- def forward(self, x: torch.Tensor, token_ids: torch.Tensor):
124
- """Top-k Hopfield retrieve + soft gate + residual.
125
-
126
- Args:
127
- x: (B, T, d_model) input activations.
128
- token_ids: (B, T) accepted for API compat; only used in the
129
- optional Hebbian boost path.
130
-
131
- Returns:
132
- (x + alpha * retrieved, hit_rate)
133
- - x + alpha * retrieved: (B, T, d_model)
134
- - hit_rate: scalar tensor fraction of gate values > 0.1
135
- """
136
- B, T, D = x.shape
137
-
138
- # ---- 1. Similarity scores (coalesced GEMM) ----------------------
139
- # scores[b, t, c] = dot(x[b,t], memory[c])
140
- scores = x @ self.memory.T # (B, T, n_columns)
141
-
142
- # ---- 2. Top-k sparse attention ----------------------------------
143
- # topk uses radix select, not a sort — O(n_columns) memory, not O(n_columns log n_columns).
144
- # Never materializes a dense (B, T, n_columns) weights tensor.
145
- topk_vals, topk_idx = scores.topk(self.topk_k, dim=-1) # (B, T, K), (B, T, K)
146
- topk_w = torch.softmax(topk_vals, dim=-1) # (B, T, K)
147
-
148
- # ---- 3. Gather selected memory rows -----------------------------
149
- # memory[topk_idx] is a gather along axis 0 of memory (n_columns, d_model).
150
- # Output shape (B, T, K, d_model) — K is small, so gather bandwidth is
151
- # O(B*T*K*d_model), independent of n_columns.
152
- selected_mem = self.memory[topk_idx] # (B, T, K, d_model)
153
-
154
- # ---- 4. Weighted sum → retrieved vector -------------------------
155
- retrieved = torch.einsum('btk,btkd->btd', topk_w, selected_mem) # (B, T, d_model)
156
-
157
- # ---- 5. Soft gate -----------------------------------------------
158
- alpha = torch.sigmoid(self.gate(x)) # (B, T, 1)
159
-
160
- # ---- 6. Optional Hebbian EMA write ------------------------------
161
- if self.training and self.hebbian_boost:
162
- with torch.no_grad():
163
- indices = self._hash(token_ids)
164
- flat_idx = indices.reshape(-1) # (B*T,)
165
- flat_x = x.detach().reshape(-1, D) # (B*T, d_model)
166
- mem_dtype = self.memory.data.dtype
167
- updates = (
168
- self.hebbian_lr * flat_x
169
- - self.hebbian_lr * self.memory.data[flat_idx]
170
- ).to(mem_dtype)
171
- self.memory.data.index_add_(0, flat_idx, updates)
172
-
173
- # ---- 7. Residual + hit_rate -------------------------------------
174
- hit_rate = (alpha.detach() > 0.1).float().mean()
175
- return x + alpha * retrieved, hit_rate
 
1
+ """GPU Engram — Top-k Sparse Hopfield retrieval with optional Cantor/SDR nerve constraint."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ _ENGRAM_TOPK = int(os.environ.get("HYDRA_ENGRAM_TOPK", "64"))
12
+
13
+
14
+ class GPUEngram(nn.Module):
15
+ """GPU Engram: Top-k Sparse Hopfield retrieval.
16
+
17
+ Default `routing_mode=flat` preserves the existing full-memory top-k path.
18
+ `cantor_sdr` constrains candidates to the current Cantor leaf shard and SDR
19
+ active offsets. `auto` only uses that local path when it is cheaper than the
20
+ full score matrix (`K * d_model < n_columns`).
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ d_model: int,
26
+ n_columns: int = 1024,
27
+ max_ngram: int = 3,
28
+ hebbian_boost: bool = False,
29
+ ) -> None:
30
+ super().__init__()
31
+ self.n_columns = n_columns
32
+ self.max_ngram = max_ngram
33
+ self.hebbian_boost = hebbian_boost
34
+ self.memory = nn.Parameter(torch.randn(n_columns, d_model) * 0.01)
35
+ self.gate = nn.Linear(d_model, 1, bias=True)
36
+ nn.init.constant_(self.gate.bias, 0.0)
37
+ self.topk_k = min(_ENGRAM_TOPK, n_columns)
38
+ self.primes = [2654435761, 2246822519, 3266489917]
39
+ self.hebbian_lr = 0.01
40
+ self.routing_mode = os.environ.get("HYDRA_ENGRAM_ROUTING", "auto").lower()
41
+
42
+ def _hash(self, token_ids: torch.Tensor) -> torch.Tensor:
43
+ B, T = token_ids.shape
44
+ h = token_ids * self.primes[0]
45
+ if T > 1:
46
+ shifted1 = torch.roll(token_ids, 1, dims=1)
47
+ shifted1[:, 0] = 0
48
+ h = h ^ (shifted1 * self.primes[1])
49
+ if T > 2:
50
+ shifted2 = torch.roll(token_ids, 2, dims=1)
51
+ shifted2[:, :2] = 0
52
+ h = h ^ (shifted2 * self.primes[2])
53
+ return h % self.n_columns
54
+
55
+ def _validate_active_indices(self, sdr_active_indices: torch.Tensor, x: torch.Tensor) -> None:
56
+ if not torch.is_floating_point(sdr_active_indices) and sdr_active_indices.dtype != torch.bool:
57
+ pass
58
+ else:
59
+ raise ValueError("Engram Cantor/SDR routing expects compact active indices, not a dense SDR mask")
60
+ if sdr_active_indices.dim() not in (2, 3):
61
+ raise ValueError("compact active indices must have shape (B,T,K) or (B*T,K)")
62
+ # Dense SDR masks arrive with K ~= n_bits; compact buffers are small
63
+ # (retina target_active or RealityBridge l0_k). Refuse obviously dense
64
+ # masks so forced cantor_sdr cannot silently route 0/1 values as offsets.
65
+ if sdr_active_indices.shape[-1] > 1024 or sdr_active_indices.shape[-1] > self.n_columns:
66
+ raise ValueError("Engram Cantor/SDR routing expects compact active indices, not a dense SDR mask")
67
+
68
+ def _cantor_sdr_candidates(
69
+ self,
70
+ sdr_active_indices: torch.Tensor,
71
+ cantor_leaf_ids: torch.Tensor,
72
+ n_leaves: int,
73
+ ) -> torch.Tensor:
74
+ """Map SDR active offsets into each Cantor leaf's Engram column shard."""
75
+ self._validate_active_indices(sdr_active_indices, cantor_leaf_ids)
76
+ if sdr_active_indices.dim() == 2:
77
+ B, T = cantor_leaf_ids.shape
78
+ sdr_active_indices = sdr_active_indices.view(B, T, -1)
79
+ sdr = sdr_active_indices.to(device=cantor_leaf_ids.device, dtype=torch.long)
80
+ leaves = cantor_leaf_ids.to(dtype=torch.long).clamp(min=0, max=max(0, n_leaves - 1))
81
+ cols_per_leaf = max(1, self.n_columns // max(1, n_leaves))
82
+ offsets = sdr.remainder(cols_per_leaf)
83
+ base = leaves.unsqueeze(-1) * cols_per_leaf
84
+ return (base + offsets).clamp(max=self.n_columns - 1)
85
+
86
+ def _flat_retrieve(self, x: torch.Tensor) -> torch.Tensor:
87
+ scores = x @ self.memory.T
88
+ topk_vals, topk_idx = scores.topk(self.topk_k, dim=-1)
89
+ topk_w = torch.softmax(topk_vals, dim=-1)
90
+ selected_mem = self.memory[topk_idx]
91
+ return torch.einsum('btk,btkd->btd', topk_w, selected_mem)
92
+
93
+ def _cantor_sdr_retrieve(
94
+ self,
95
+ x: torch.Tensor,
96
+ sdr_active_indices: torch.Tensor,
97
+ cantor_leaf_ids: torch.Tensor,
98
+ cantor_n_leaves: int,
99
+ ) -> torch.Tensor:
100
+ candidates = self._cantor_sdr_candidates(
101
+ sdr_active_indices,
102
+ cantor_leaf_ids,
103
+ n_leaves=cantor_n_leaves,
104
+ )
105
+ cand_mem = self.memory[candidates]
106
+ scores = torch.einsum('btd,btkd->btk', x, cand_mem)
107
+ k = min(self.topk_k, scores.shape[-1])
108
+ topk_vals, local_idx = scores.topk(k, dim=-1)
109
+ topk_w = torch.softmax(topk_vals, dim=-1)
110
+ global_idx = candidates.gather(-1, local_idx)
111
+ selected_mem = self.memory[global_idx]
112
+ return torch.einsum('btk,btkd->btd', topk_w, selected_mem)
113
+
114
+ def forward(
115
+ self,
116
+ x: torch.Tensor,
117
+ token_ids: torch.Tensor,
118
+ sdr_active_indices: torch.Tensor | None = None,
119
+ cantor_leaf_ids: torch.Tensor | None = None,
120
+ cantor_n_leaves: int | None = None,
121
+ ):
122
+ B, T, D = x.shape
123
+ mode = self.routing_mode
124
+ use_cantor = (
125
+ mode in {"cantor_sdr", "auto"}
126
+ and sdr_active_indices is not None
127
+ and cantor_leaf_ids is not None
128
+ and cantor_n_leaves is not None
129
+ )
130
+ if mode == "auto" and use_cantor:
131
+ k_active = sdr_active_indices.shape[-1]
132
+ # Compare actual retrieval candidates against the full-memory scan.
133
+ # The previous `(k_active * D) < n_columns` check mixed candidate
134
+ # count with feature dimension, so d256/k64 fell back to flat
135
+ # retrieval even though Cantor/SDR scores only 64 candidates vs
136
+ # 8k-16k memory columns. That kept required subsystems active but
137
+ # spent tens of billions of extra MACs per forward.
138
+ use_cantor = k_active < self.n_columns
139
+
140
+ if use_cantor and mode in {"cantor_sdr", "auto"}:
141
+ retrieved = self._cantor_sdr_retrieve(x, sdr_active_indices, cantor_leaf_ids, cantor_n_leaves)
142
+ else:
143
+ retrieved = self._flat_retrieve(x)
144
+
145
+ alpha = torch.sigmoid(self.gate(x))
146
+
147
+ if self.training and self.hebbian_boost:
148
+ with torch.no_grad():
149
+ indices = self._hash(token_ids)
150
+ flat_idx = indices.reshape(-1)
151
+ flat_x = x.detach().reshape(-1, D)
152
+ mem_dtype = self.memory.data.dtype
153
+ updates = (
154
+ self.hebbian_lr * flat_x
155
+ - self.hebbian_lr * self.memory.data[flat_idx]
156
+ ).to(mem_dtype)
157
+ self.memory.data.index_add_(0, flat_idx, updates)
158
+
159
+ hit_rate = (alpha.detach() > 0.1).float().mean()
160
+ return x + alpha * retrieved, hit_rate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
overlay/hydra/eval.py CHANGED
@@ -1,217 +1,210 @@
1
- """Evaluation: factual probes + sampled factual English scoring.
2
-
3
- Extracted from train.py (W1 modularization). Semantics unchanged.
4
-
5
- Perf optimizations (eval_perf_fix):
6
- - Probe mode: single forward per prompt instead of autoregressive gen
7
- - Batch decode: all GPU work first, all CPU decode after
8
- - Batched factual probes: single padded forward instead of N sequential
9
- """
10
-
11
- from __future__ import annotations
12
-
13
- import os
14
- import re as _re
15
-
16
- import torch
17
-
18
- from hydra.config import FACTUAL_SAMPLES, FACTUAL_BATCH, FACTUAL_GEN_TOKENS
19
-
20
- # Default to probe mode (1 forward per prompt); set HYDRA_FACTUAL_MODE=gen for
21
- # the original autoregressive generation path.
22
- FACTUAL_MODE = os.environ.get("HYDRA_FACTUAL_MODE", "probe")
23
-
24
- FACTUAL_EVAL = [
25
- # Hard factual recall — requires specific knowledge memorization
26
- ("The capital of France is", ["Paris", "paris"]),
27
- ("Water boils at", ["100", "boiling"]),
28
- ("The largest planet in our solar system is", ["Jupiter", "jupiter"]),
29
- # Easier completions — common collocations / patterns the model may pick up
30
- ("Once upon a", ["time"]),
31
- ("Hello, my name", ["is", "'s"]),
32
- ("The cat sat on the", ["mat", "floor", "rug", "table", "couch", "chair", "ground"]),
33
- ("She opened the door and", ["walked", "saw", "found", "stepped", "looked", "went", "ran"]),
34
- # Original hard ones kept for completeness
35
- ("The speed of light is approximately", ["299", "300", "186,000", "light speed"]),
36
- ("Two plus two equals", ["4", "four"]),
37
- ]
38
-
39
- _FACTUAL_PROBES = [
40
- "The capital of France is",
41
- "Water boils at",
42
- "The largest planet in our solar system is",
43
- "The speed of light is approximately",
44
- "Shakespeare wrote",
45
- ]
46
-
47
-
48
- def run_factual_probes(model, tokenizer, device, autocast_ctx) -> None:
49
- """Top-5 next-token predictions for canonical factual prompts.
50
-
51
- Batched: pads all prompts into a single forward pass instead of N
52
- sequential passes.
53
- """
54
- print("\n--- Factual Probes ---")
55
- model.eval()
56
-
57
- # Process probes one at a time to avoid cooperative launch limit
58
- # (batched forward with B=len(probes) can exceed SM residency cap).
59
- for prompt_text in _FACTUAL_PROBES:
60
- ids = tokenizer.encode(prompt_text)
61
- x = torch.tensor([ids], device=device)
62
- with torch.no_grad(), autocast_ctx:
63
- logits = model(x)
64
- probs = torch.softmax(logits[0, -1].float(), dim=-1)
65
- top5 = torch.topk(probs, 5)
66
- completions = [tokenizer.decode([idx.item()]) for idx in top5.indices]
67
- probs_list = [f"{p:.4f}" for p in top5.values[:3].tolist()]
68
- print(f' "{prompt_text}" -> {completions[:3]} (p={probs_list})')
69
- print("--- End Factual Probes ---\n")
70
-
71
-
72
- # ---------------------------------------------------------------------------
73
- # Probe mode: single forward per prompt (Fix D)
74
- # ---------------------------------------------------------------------------
75
-
76
- def _run_factual_english_probe(model, tokenizer, max_seq_len: int):
77
- """Fast probe mode: for each (prompt, answers), encode prompt + each answer
78
- candidate as a single sequence, do ONE forward pass, and check if the model's
79
- argmax at the last prompt token matches the first answer token.
80
-
81
- Falls back to checking top-K predictions to be generous (same as gen mode
82
- which samples multiple temperatures).
83
- """
84
- print("---")
85
- print("factual_english_samples: (probe mode)")
86
- model.eval()
87
- hits = 0
88
-
89
- with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
90
- for prompt, answers in FACTUAL_EVAL:
91
- prompt_ids = tokenizer.encode(prompt)
92
- prompt_len = len(prompt_ids)
93
- x = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
94
- logits = model(x, targets=None)
95
- # logits shape: [1, seq_len, vocab] or [1, vocab]
96
- if logits.dim() == 3:
97
- last_logits = logits[0, -1, :]
98
- else:
99
- last_logits = logits[0]
100
-
101
- probs = torch.softmax(last_logits.float(), dim=-1)
102
- # Check top-K predictions (generous: K=20 to match multi-sample gen)
103
- top_k = min(20, probs.shape[-1])
104
- top_ids = torch.topk(probs, top_k).indices.tolist()
105
- top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]
106
-
107
- answers_lower = [a.lower() for a in answers]
108
- any_hit = any(
109
- any(a in tok for a in answers_lower)
110
- for tok in top_tokens
111
- )
112
- if any_hit:
113
- hits += 1
114
-
115
- best_completion = tokenizer.decode([top_ids[0]])
116
- print(f" prompt: {prompt!r}")
117
- print(f" output: {(prompt + best_completion).replace(chr(10), ' ')!r}")
118
- print(f" hit: {any_hit} (probe top-{top_k})")
119
-
120
- score = hits / len(FACTUAL_EVAL)
121
- print("---")
122
- print(f"factual_english_score: {score:.4f}")
123
- print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
124
- return score, hits, len(FACTUAL_EVAL)
125
-
126
-
127
- # ---------------------------------------------------------------------------
128
- # Gen mode: original autoregressive path (Fix F: batch decode)
129
- # ---------------------------------------------------------------------------
130
-
131
- def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
132
- """Original autoregressive generation path with batch decode optimization:
133
- all GPU work runs first, then all CPU decoding happens after."""
134
- print("---")
135
- print("factual_english_samples: (gen mode)")
136
- model.eval()
137
-
138
- num_samples = FACTUAL_SAMPLES
139
- batch = FACTUAL_BATCH
140
- gen_tokens = FACTUAL_GEN_TOKENS
141
- # Optional fast incremental decode path for recurrence-capable backbones.
142
- # If disabled, we preserve the original full-context re-forward behavior.
143
- incremental_decode = os.environ.get("HYDRA_FACTUAL_GEN_INCREMENTAL", "1") == "1"
144
- temps = [0.7, 0.9, 1.1]
145
- hits = 0
146
-
147
- with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
148
- for prompt, answers in FACTUAL_EVAL:
149
- ids = tokenizer.encode(prompt)
150
- answers_lower = [a.lower() for a in answers]
151
- # Collect all generated token sequences on GPU first
152
- all_rows: list[list[int]] = []
153
- samples_done = 0
154
- batch_idx = 0
155
- while samples_done < num_samples:
156
- b = min(batch, num_samples - samples_done)
157
- temp = temps[batch_idx % len(temps)]
158
- batch_idx += 1
159
- ctx = torch.tensor([ids] * b, device="cuda", dtype=torch.long)
160
- logits = model(ctx, targets=None)
161
- for _ in range(gen_tokens):
162
- next_logits = logits[:, -1, :] if logits.dim() == 3 else logits
163
- probs = torch.softmax(next_logits.float() / temp, dim=-1)
164
- next_id = torch.multinomial(probs, num_samples=1)
165
- ctx = torch.cat([ctx, next_id], dim=1)
166
- if ctx.size(1) >= max_seq_len:
167
- break
168
- if incremental_decode:
169
- logits = model(ctx[:, -1:], targets=None)
170
- else:
171
- logits = model(ctx, targets=None)
172
- # Transfer to CPU in one shot, no per-row sync
173
- all_rows.extend(ctx.cpu().tolist())
174
- samples_done += b
175
-
176
- # CPU-side batch decode no GPU sync between decodes
177
- any_hit = False
178
- first_gen = None
179
- hit_gen = None
180
- for row in all_rows:
181
- generated = tokenizer.decode(row)
182
- continuation = generated[len(prompt):].strip()
183
- _words = set(w.lower() for w in _re.findall(r"\b[\w'-]+\b", continuation))
184
- hit = any(a in _words for a in answers_lower)
185
- if first_gen is None:
186
- first_gen = generated
187
- if hit:
188
- any_hit = True
189
- if hit_gen is None:
190
- hit_gen = generated
191
- if any_hit:
192
- hits += 1
193
- print(f" prompt: {prompt!r}")
194
- print(f" output: {(first_gen or '').replace(chr(10), ' ')!r}")
195
- print(f" hit: {any_hit} (any of {num_samples} samples, temps={temps}, gen={gen_tokens}tok)")
196
- if hit_gen is not None and hit_gen != first_gen:
197
- print(f" hit_sample: {hit_gen.replace(chr(10), ' ')!r}")
198
-
199
- score = hits / len(FACTUAL_EVAL)
200
- print("---")
201
- print(f"factual_english_score: {score:.4f}")
202
- print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
203
- return score, hits, len(FACTUAL_EVAL)
204
-
205
-
206
- # ---------------------------------------------------------------------------
207
- # Public entry point
208
- # ---------------------------------------------------------------------------
209
-
210
- def run_factual_english(model, tokenizer, max_seq_len: int):
211
- """Dispatch to probe (fast, default) or gen (original) mode.
212
-
213
- Set HYDRA_FACTUAL_MODE=gen to use the autoregressive path.
214
- """
215
- if FACTUAL_MODE == "gen":
216
- return _run_factual_english_gen(model, tokenizer, max_seq_len)
217
- return _run_factual_english_probe(model, tokenizer, max_seq_len)
 
1
+ """Evaluation: factual probes + sampled factual English scoring.
2
+
3
+ Extracted from train.py (W1 modularization). Semantics unchanged.
4
+
5
+ Perf optimizations (eval_perf_fix):
6
+ - Probe mode: single forward per prompt instead of autoregressive gen
7
+ - Batch decode: all GPU work first, all CPU decode after
8
+ - Batched factual probes: single padded forward instead of N sequential
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import re as _re
15
+
16
+ import torch
17
+
18
+ from hydra.config import FACTUAL_SAMPLES, FACTUAL_BATCH, FACTUAL_GEN_TOKENS
19
+
20
+ # Default to probe mode (1 forward per prompt); set HYDRA_FACTUAL_MODE=gen for
21
+ # the original autoregressive generation path.
22
+ FACTUAL_MODE = os.environ.get("HYDRA_FACTUAL_MODE", "probe")
23
+
24
+ FACTUAL_EVAL = [
25
+ # Hard factual recall — requires specific knowledge memorization
26
+ ("The capital of France is", ["Paris", "paris"]),
27
+ ("Water boils at", ["100", "boiling"]),
28
+ ("The largest planet in our solar system is", ["Jupiter", "jupiter"]),
29
+ # Easier completions — common collocations / patterns the model may pick up
30
+ ("Once upon a", ["time"]),
31
+ ("Hello, my name", ["is", "'s"]),
32
+ ("The cat sat on the", ["mat", "floor", "rug", "table", "couch", "chair", "ground"]),
33
+ ("She opened the door and", ["walked", "saw", "found", "stepped", "looked", "went", "ran"]),
34
+ # Original hard ones kept for completeness
35
+ ("The speed of light is approximately", ["299", "300", "186,000", "light speed"]),
36
+ ("Two plus two equals", ["4", "four"]),
37
+ ]
38
+
39
+ _FACTUAL_PROBES = [
40
+ "The capital of France is",
41
+ "Water boils at",
42
+ "The largest planet in our solar system is",
43
+ "The speed of light is approximately",
44
+ "Shakespeare wrote",
45
+ ]
46
+
47
+
48
+ def run_factual_probes(model, tokenizer, device, autocast_ctx) -> None:
49
+ """Top-5 next-token predictions for canonical factual prompts.
50
+
51
+ Batched: pads all prompts into a single forward pass instead of N
52
+ sequential passes.
53
+ """
54
+ print("\n--- Factual Probes ---")
55
+ model.eval()
56
+
57
+ # Process probes one at a time to avoid cooperative launch limit
58
+ # (batched forward with B=len(probes) can exceed SM residency cap).
59
+ for prompt_text in _FACTUAL_PROBES:
60
+ ids = tokenizer.encode(prompt_text)
61
+ x = torch.tensor([ids], device=device)
62
+ with torch.no_grad(), autocast_ctx:
63
+ logits = model(x)
64
+ probs = torch.softmax(logits[0, -1].float(), dim=-1)
65
+ top5 = torch.topk(probs, 5)
66
+ completions = [tokenizer.decode([idx.item()]) for idx in top5.indices]
67
+ probs_list = [f"{p:.4f}" for p in top5.values[:3].tolist()]
68
+ print(f' "{prompt_text}" -> {completions[:3]} (p={probs_list})')
69
+ print("--- End Factual Probes ---\n")
70
+
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # Probe mode: single forward per prompt (Fix D)
74
+ # ---------------------------------------------------------------------------
75
+
76
+ def _run_factual_english_probe(model, tokenizer, max_seq_len: int):
77
+ """Fast probe mode: for each (prompt, answers), encode prompt + each answer
78
+ candidate as a single sequence, do ONE forward pass, and check if the model's
79
+ argmax at the last prompt token matches the first answer token.
80
+
81
+ Falls back to checking top-K predictions to be generous (same as gen mode
82
+ which samples multiple temperatures).
83
+ """
84
+ print("---")
85
+ print("factual_english_samples: (probe mode)")
86
+ model.eval()
87
+ hits = 0
88
+
89
+ with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
90
+ for prompt, answers in FACTUAL_EVAL:
91
+ prompt_ids = tokenizer.encode(prompt)
92
+ prompt_len = len(prompt_ids)
93
+ x = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
94
+ logits = model(x, targets=None)
95
+ # logits shape: [1, seq_len, vocab] or [1, vocab]
96
+ if logits.dim() == 3:
97
+ last_logits = logits[0, -1, :]
98
+ else:
99
+ last_logits = logits[0]
100
+
101
+ probs = torch.softmax(last_logits.float(), dim=-1)
102
+ # Check top-K predictions (generous: K=20 to match multi-sample gen)
103
+ top_k = min(20, probs.shape[-1])
104
+ top_ids = torch.topk(probs, top_k).indices.tolist()
105
+ top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]
106
+
107
+ answers_lower = [a.lower() for a in answers]
108
+ any_hit = any(
109
+ any(a in tok for a in answers_lower)
110
+ for tok in top_tokens
111
+ )
112
+ if any_hit:
113
+ hits += 1
114
+
115
+ best_completion = tokenizer.decode([top_ids[0]])
116
+ print(f" prompt: {prompt!r}")
117
+ print(f" output: {(prompt + best_completion).replace(chr(10), ' ')!r}")
118
+ print(f" hit: {any_hit} (probe top-{top_k})")
119
+
120
+ score = hits / len(FACTUAL_EVAL)
121
+ print("---")
122
+ print(f"factual_english_score: {score:.4f}")
123
+ print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
124
+ return score, hits, len(FACTUAL_EVAL)
125
+
126
+
127
+ # ---------------------------------------------------------------------------
128
+ # Gen mode: original autoregressive path (Fix F: batch decode)
129
+ # ---------------------------------------------------------------------------
130
+
131
+ def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
132
+ """Original autoregressive generation path with batch decode optimization:
133
+ all GPU work runs first, then all CPU decoding happens after."""
134
+ print("---")
135
+ print("factual_english_samples: (gen mode)")
136
+ model.eval()
137
+
138
+ num_samples = FACTUAL_SAMPLES
139
+ batch = FACTUAL_BATCH
140
+ gen_tokens = FACTUAL_GEN_TOKENS
141
+ temps = [0.7, 0.9, 1.1]
142
+ hits = 0
143
+
144
+ with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
145
+ for prompt, answers in FACTUAL_EVAL:
146
+ ids = tokenizer.encode(prompt)
147
+ answers_lower = [a.lower() for a in answers]
148
+ # Collect all generated token sequences on GPU first
149
+ all_rows: list[list[int]] = []
150
+ samples_done = 0
151
+ batch_idx = 0
152
+ while samples_done < num_samples:
153
+ b = min(batch, num_samples - samples_done)
154
+ temp = temps[batch_idx % len(temps)]
155
+ batch_idx += 1
156
+ ctx = torch.tensor([ids] * b, device="cuda", dtype=torch.long)
157
+ for _ in range(gen_tokens):
158
+ logits = model(ctx, targets=None)
159
+ next_logits = logits[:, -1, :] if logits.dim() == 3 else logits
160
+ probs = torch.softmax(next_logits.float() / temp, dim=-1)
161
+ next_id = torch.multinomial(probs, num_samples=1)
162
+ ctx = torch.cat([ctx, next_id], dim=1)
163
+ if ctx.size(1) >= max_seq_len:
164
+ break
165
+ # Transfer to CPU in one shot, no per-row sync
166
+ all_rows.extend(ctx.cpu().tolist())
167
+ samples_done += b
168
+
169
+ # CPU-side batch decode — no GPU sync between decodes
170
+ any_hit = False
171
+ first_gen = None
172
+ hit_gen = None
173
+ for row in all_rows:
174
+ generated = tokenizer.decode(row)
175
+ continuation = generated[len(prompt):].strip()
176
+ _words = set(w.lower() for w in _re.findall(r"\b[\w'-]+\b", continuation))
177
+ hit = any(a in _words for a in answers_lower)
178
+ if first_gen is None:
179
+ first_gen = generated
180
+ if hit:
181
+ any_hit = True
182
+ if hit_gen is None:
183
+ hit_gen = generated
184
+ if any_hit:
185
+ hits += 1
186
+ print(f" prompt: {prompt!r}")
187
+ print(f" output: {(first_gen or '').replace(chr(10), ' ')!r}")
188
+ print(f" hit: {any_hit} (any of {num_samples} samples, temps={temps}, gen={gen_tokens}tok)")
189
+ if hit_gen is not None and hit_gen != first_gen:
190
+ print(f" hit_sample: {hit_gen.replace(chr(10), ' ')!r}")
191
+
192
+ score = hits / len(FACTUAL_EVAL)
193
+ print("---")
194
+ print(f"factual_english_score: {score:.4f}")
195
+ print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
196
+ return score, hits, len(FACTUAL_EVAL)
197
+
198
+
199
+ # ---------------------------------------------------------------------------
200
+ # Public entry point
201
+ # ---------------------------------------------------------------------------
202
+
203
+ def run_factual_english(model, tokenizer, max_seq_len: int):
204
+ """Dispatch to probe (fast, default) or gen (original) mode.
205
+
206
+ Set HYDRA_FACTUAL_MODE=gen to use the autoregressive path.
207
+ """
208
+ if FACTUAL_MODE == "gen":
209
+ return _run_factual_english_gen(model, tokenizer, max_seq_len)
210
+ return _run_factual_english_probe(model, tokenizer, max_seq_len)
 
 
 
 
 
 
 
overlay/hydra/gdn_block.py CHANGED
@@ -1,126 +1,126 @@
1
- """GDNBlock — Gated Delta Net block, drop-in shape-compatible with Mamba3Block and HyenaBlock.
2
-
3
- GatedDeltaNet (GDN) reference: arXiv:2412.06464 (ICLR 2025, NVLabs).
4
- Implementation: flash-linear-attention (fla) library, Triton kernels, sm86-compatible.
5
-
6
- Interface contract (MUST match how Mamba3/Hyena are called in hydra/model.py):
7
- block = GDNBlock(d_model, ...)
8
- y = block(x) # x: [B, T, d_model] -> y: [B, T, d_model]
9
-
10
- The surrounding mHC layer does NOT pre-norm before calling this block (the
11
- raw hidden state is passed in); the block itself applies no input normalization,
12
- same as HyenaBlock. We return the raw operator output; the mHC layer adds it
13
- as a residual stream contribution.
14
-
15
- NO attention, NO softmax-over-sequence-dim. All state is stateless between
16
- .forward() calls by default (use_cache=False, past_key_values=None).
17
- """
18
-
19
- from __future__ import annotations
20
-
21
- try:
22
- from fla.layers.gated_deltanet import GatedDeltaNet as _GatedDeltaNet
23
- except ImportError as _fla_err:
24
- raise ImportError(
25
- "flash-linear-attention (fla) is required for GDNBlock but could not be imported. "
26
- "Install it with:\n"
27
- " pip install flash-linear-attention\n"
28
- "or from source:\n"
29
- " pip install git+https://github.com/fla-org/flash-linear-attention.git\n"
30
- f"Original error: {_fla_err}"
31
- ) from _fla_err
32
-
33
- import torch
34
- import torch.nn as nn
35
-
36
-
37
- class GDNBlock(nn.Module):
38
- """Gated Delta Net block, drop-in shape-compatible with HYDRA's Mamba3Block and HyenaBlock.
39
-
40
- Wraps `fla.layers.GatedDeltaNet` with the same external API that
41
- `hydra.hyena_block.HyenaBlock` exposes:
42
-
43
- forward(x: Tensor[B, T, d_model]) -> Tensor[B, T, d_model]
44
-
45
- Internal GatedDeltaNet.forward returns a 3-tuple
46
- (hidden_states, attn_weights, past_key_values); we extract [0] and
47
- return only the hidden states, keeping the residual stream unchanged.
48
-
49
- GDN outperforms Mamba-2 on in-context retrieval benchmarks (MQAR, etc.)
50
- at equal or faster compute, making it a targeted fix for HYDRA's factual
51
- plateau.
52
-
53
- Parameter counts are deliberately kept within 2x of a Mamba3 block at the
54
- same d_model/n_heads to be drop-in affordable.
55
- """
56
-
57
- def __init__(
58
- self,
59
- d_model: int,
60
- n_heads: int = 6,
61
- mode: str = "chunk", # 'chunk' for training, 'fused_recurrent' for inference
62
- expand_v: float = 2.0, # value-projection expansion; controls KV memory
63
- use_short_conv: bool = True,
64
- conv_size: int = 4,
65
- ):
66
- super().__init__()
67
- self.d_model = d_model
68
- self.n_heads = n_heads
69
- self.mode = mode
70
-
71
- # head_dim must divide d_model. GDN uses separate q/k head_dim from v;
72
- # we set head_dim for q/k such that n_heads * head_dim == d_model.
73
- if d_model % n_heads != 0:
74
- raise ValueError(
75
- f"d_model={d_model} must be divisible by n_heads={n_heads} "
76
- "so that head_dim = d_model // n_heads is an integer."
77
- )
78
- head_dim = d_model // n_heads
79
-
80
- self.gdn = _GatedDeltaNet(
81
- hidden_size=d_model,
82
- expand_v=expand_v,
83
- head_dim=head_dim,
84
- num_heads=n_heads,
85
- mode=mode,
86
- use_gate=True, # gating is the key architectural feature of GDN
87
- use_short_conv=use_short_conv,
88
- conv_size=conv_size,
89
- layer_idx=None, # no KV-cache layer indexing; we manage state ourselves
90
- )
91
-
92
- # ------------------------------------------------------------------
93
- # Forward
94
- # ------------------------------------------------------------------
95
-
96
- def forward(self, x: torch.Tensor) -> torch.Tensor:
97
- """x: [B, T, d_model] -> y: [B, T, d_model].
98
-
99
- Passes through GatedDeltaNet with use_cache=False so no recurrent
100
- state leaks between independent forward() calls (important for
101
- gradient-accumulation loops and eval).
102
- """
103
- # GatedDeltaNet.forward signature:
104
- # (hidden_states, attention_mask=None, past_key_values=None,
105
- # use_cache=False, output_attentions=False)
106
- # Returns: tuple(hidden_states, attn_weights|None, past_kv|None)
107
- out, _, _ = self.gdn(
108
- hidden_states=x,
109
- attention_mask=None,
110
- past_key_values=None,
111
- use_cache=False,
112
- output_attentions=False,
113
- )
114
- return out
115
-
116
- # ------------------------------------------------------------------
117
- # API parity with HyenaBlock and Mamba3Block
118
- # ------------------------------------------------------------------
119
-
120
- def invalidate_caches(self) -> None:
121
- """No-op — GDNBlock holds no persistent filter cache.
122
-
123
- Provided for API parity with HyenaBlock, which invalidates its
124
- Hyena filter cache here. Calling this is always safe.
125
- """
126
- pass
 
1
+ """GDNBlock — Gated Delta Net block, drop-in shape-compatible with Mamba3Block and HyenaBlock.
2
+
3
+ GatedDeltaNet (GDN) reference: arXiv:2412.06464 (ICLR 2025, NVLabs).
4
+ Implementation: flash-linear-attention (fla) library, Triton kernels, sm86-compatible.
5
+
6
+ Interface contract (MUST match how Mamba3/Hyena are called in hydra/model.py):
7
+ block = GDNBlock(d_model, ...)
8
+ y = block(x) # x: [B, T, d_model] -> y: [B, T, d_model]
9
+
10
+ The surrounding mHC layer does NOT pre-norm before calling this block (the
11
+ raw hidden state is passed in); the block itself applies no input normalization,
12
+ same as HyenaBlock. We return the raw operator output; the mHC layer adds it
13
+ as a residual stream contribution.
14
+
15
+ NO attention, NO softmax-over-sequence-dim. All state is stateless between
16
+ .forward() calls by default (use_cache=False, past_key_values=None).
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ try:
22
+ from fla.layers.gated_deltanet import GatedDeltaNet as _GatedDeltaNet
23
+ except ImportError as _fla_err:
24
+ raise ImportError(
25
+ "flash-linear-attention (fla) is required for GDNBlock but could not be imported. "
26
+ "Install it with:\n"
27
+ " pip install flash-linear-attention\n"
28
+ "or from source:\n"
29
+ " pip install git+https://github.com/fla-org/flash-linear-attention.git\n"
30
+ f"Original error: {_fla_err}"
31
+ ) from _fla_err
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+
36
+
37
+ class GDNBlock(nn.Module):
38
+ """Gated Delta Net block, drop-in shape-compatible with HYDRA's Mamba3Block and HyenaBlock.
39
+
40
+ Wraps `fla.layers.GatedDeltaNet` with the same external API that
41
+ `hydra.hyena_block.HyenaBlock` exposes:
42
+
43
+ forward(x: Tensor[B, T, d_model]) -> Tensor[B, T, d_model]
44
+
45
+ Internal GatedDeltaNet.forward returns a 3-tuple
46
+ (hidden_states, attn_weights, past_key_values); we extract [0] and
47
+ return only the hidden states, keeping the residual stream unchanged.
48
+
49
+ GDN outperforms Mamba-2 on in-context retrieval benchmarks (MQAR, etc.)
50
+ at equal or faster compute, making it a targeted fix for HYDRA's factual
51
+ plateau.
52
+
53
+ Parameter counts are deliberately kept within 2x of a Mamba3 block at the
54
+ same d_model/n_heads to be drop-in affordable.
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ d_model: int,
60
+ n_heads: int = 6,
61
+ mode: str = "chunk", # 'chunk' for training, 'fused_recurrent' for inference
62
+ expand_v: float = 2.0, # value-projection expansion; controls KV memory
63
+ use_short_conv: bool = True,
64
+ conv_size: int = 4,
65
+ ):
66
+ super().__init__()
67
+ self.d_model = d_model
68
+ self.n_heads = n_heads
69
+ self.mode = mode
70
+
71
+ # head_dim must divide d_model. GDN uses separate q/k head_dim from v;
72
+ # we set head_dim for q/k such that n_heads * head_dim == d_model.
73
+ if d_model % n_heads != 0:
74
+ raise ValueError(
75
+ f"d_model={d_model} must be divisible by n_heads={n_heads} "
76
+ "so that head_dim = d_model // n_heads is an integer."
77
+ )
78
+ head_dim = d_model // n_heads
79
+
80
+ self.gdn = _GatedDeltaNet(
81
+ hidden_size=d_model,
82
+ expand_v=expand_v,
83
+ head_dim=head_dim,
84
+ num_heads=n_heads,
85
+ mode=mode,
86
+ use_gate=True, # gating is the key architectural feature of GDN
87
+ use_short_conv=use_short_conv,
88
+ conv_size=conv_size,
89
+ layer_idx=None, # no KV-cache layer indexing; we manage state ourselves
90
+ )
91
+
92
+ # ------------------------------------------------------------------
93
+ # Forward
94
+ # ------------------------------------------------------------------
95
+
96
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
97
+ """x: [B, T, d_model] -> y: [B, T, d_model].
98
+
99
+ Passes through GatedDeltaNet with use_cache=False so no recurrent
100
+ state leaks between independent forward() calls (important for
101
+ gradient-accumulation loops and eval).
102
+ """
103
+ # GatedDeltaNet.forward signature:
104
+ # (hidden_states, attention_mask=None, past_key_values=None,
105
+ # use_cache=False, output_attentions=False)
106
+ # Returns: tuple(hidden_states, attn_weights|None, past_kv|None)
107
+ out, _, _ = self.gdn(
108
+ hidden_states=x,
109
+ attention_mask=None,
110
+ past_key_values=None,
111
+ use_cache=False,
112
+ output_attentions=False,
113
+ )
114
+ return out
115
+
116
+ # ------------------------------------------------------------------
117
+ # API parity with HyenaBlock and Mamba3Block
118
+ # ------------------------------------------------------------------
119
+
120
+ def invalidate_caches(self) -> None:
121
+ """No-op — GDNBlock holds no persistent filter cache.
122
+
123
+ Provided for API parity with HyenaBlock, which invalidates its
124
+ Hyena filter cache here. Calling this is always safe.
125
+ """
126
+ pass
overlay/hydra/hyena_block.py CHANGED
@@ -1,68 +1,68 @@
1
- """HyenaBlock — drop-in block for HYDRA, supplement to Mamba3.
2
-
3
- Wraps `subsystems.hyena_pure.HyenaOperator` with a pre-norm + residual scheme
4
- consistent with how the mHC stack wraps Mamba3 in `hydra/model.py`.
5
-
6
- Interface contract (MUST match how Mamba3 is called in model.py):
7
- block = HyenaBlock(d_model, seq_len)
8
- y = block(x) # x: [B, T, d_model] -> y: [B, T, d_model]
9
-
10
- The surrounding mHC layer does the pre-norm (`norm(h)`) BEFORE calling the
11
- block, so the block itself should NOT re-normalize at input — same as Mamba3
12
- in the current model. We return the raw operator output; the mHC layer then
13
- adds it as a residual stream contribution.
14
-
15
- NO attention, NO softmax-over-sequence-dim, NO KV-cache. All forbidden
16
- imports enumerated in tests/test_hyena.py (test #7) are absent.
17
- """
18
-
19
- from __future__ import annotations
20
-
21
- import os
22
-
23
- import torch
24
- import torch.nn as nn
25
-
26
- from subsystems.hyena_pure import HyenaOperator
27
-
28
-
29
- class HyenaBlock(nn.Module):
30
- """Single Hyena block, shape-compatible with Mamba3 in HYDRA."""
31
-
32
- def __init__(
33
- self,
34
- d_model: int,
35
- seq_len: int,
36
- order: int | None = None,
37
- filter_order: int | None = None,
38
- dropout: float = 0.0,
39
- filter_dropout: float = 0.0,
40
- short_filter_order: int = 3,
41
- activation: str = "id",
42
- ):
43
- super().__init__()
44
- # Env overrides (documented in hydra/config.py).
45
- if order is None:
46
- order = int(os.environ.get("HYDRA_HYENA_ORDER", "2"))
47
- if filter_order is None:
48
- filter_order = int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64"))
49
-
50
- self.d_model = d_model
51
- self.seq_len = seq_len
52
- self.order = order
53
- self.filter_order = filter_order
54
-
55
- self.operator = HyenaOperator(
56
- d_model=d_model,
57
- l_max=seq_len,
58
- order=order,
59
- filter_order=filter_order,
60
- dropout=dropout,
61
- filter_dropout=filter_dropout,
62
- short_filter_order=short_filter_order,
63
- activation=activation,
64
- )
65
-
66
- def forward(self, x: torch.Tensor) -> torch.Tensor:
67
- """x: [B, T, d_model] -> y: [B, T, d_model]."""
68
- return self.operator(x)
 
1
+ """HyenaBlock — drop-in block for HYDRA, supplement to Mamba3.
2
+
3
+ Wraps `subsystems.hyena_pure.HyenaOperator` with a pre-norm + residual scheme
4
+ consistent with how the mHC stack wraps Mamba3 in `hydra/model.py`.
5
+
6
+ Interface contract (MUST match how Mamba3 is called in model.py):
7
+ block = HyenaBlock(d_model, seq_len)
8
+ y = block(x) # x: [B, T, d_model] -> y: [B, T, d_model]
9
+
10
+ The surrounding mHC layer does the pre-norm (`norm(h)`) BEFORE calling the
11
+ block, so the block itself should NOT re-normalize at input — same as Mamba3
12
+ in the current model. We return the raw operator output; the mHC layer then
13
+ adds it as a residual stream contribution.
14
+
15
+ NO attention, NO softmax-over-sequence-dim, NO KV-cache. All forbidden
16
+ imports enumerated in tests/test_hyena.py (test #7) are absent.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import os
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+
26
+ from subsystems.hyena_pure import HyenaOperator
27
+
28
+
29
+ class HyenaBlock(nn.Module):
30
+ """Single Hyena block, shape-compatible with Mamba3 in HYDRA."""
31
+
32
+ def __init__(
33
+ self,
34
+ d_model: int,
35
+ seq_len: int,
36
+ order: int | None = None,
37
+ filter_order: int | None = None,
38
+ dropout: float = 0.0,
39
+ filter_dropout: float = 0.0,
40
+ short_filter_order: int = 3,
41
+ activation: str = "id",
42
+ ):
43
+ super().__init__()
44
+ # Env overrides (documented in hydra/config.py).
45
+ if order is None:
46
+ order = int(os.environ.get("HYDRA_HYENA_ORDER", "2"))
47
+ if filter_order is None:
48
+ filter_order = int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64"))
49
+
50
+ self.d_model = d_model
51
+ self.seq_len = seq_len
52
+ self.order = order
53
+ self.filter_order = filter_order
54
+
55
+ self.operator = HyenaOperator(
56
+ d_model=d_model,
57
+ l_max=seq_len,
58
+ order=order,
59
+ filter_order=filter_order,
60
+ dropout=dropout,
61
+ filter_dropout=filter_dropout,
62
+ short_filter_order=short_filter_order,
63
+ activation=activation,
64
+ )
65
+
66
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
67
+ """x: [B, T, d_model] -> y: [B, T, d_model]."""
68
+ return self.operator(x)
overlay/hydra/lightning_module.py CHANGED
@@ -1,326 +1,326 @@
1
- """LightningModule wrapping PostSemClawModel.
2
-
3
- Thin adapter. The model and the MuonAdamW optimizer are unchanged. This
4
- module implements:
5
-
6
- • configure_optimizers — returns the existing MuonAdamW (subclass of
7
- torch.optim.Optimizer) built by model.setup_optimizer. Lightning accepts
8
- this directly.
9
- • training_step — splits (B, T+1) batches into (x, y), forwards through
10
- the model, logs loss / bpb / tps / mfu / vram. Preserves the
11
- sampled-softmax path inside PostSemClawModel (no changes there).
12
- • optimizer_step — before each step we update LR + muon momentum + WD
13
- using the same time-progress schedule as hydra/training.py
14
- (get_lr_multiplier / get_muon_momentum / get_weight_decay). Lightning
15
- handles grad accumulation via Trainer(accumulate_grad_batches=N).
16
-
17
- The SDR SOM update and Hestia QAT snap are called at the same cadence as
18
- the legacy loop, but inline on the main thread (Lightning provides its own
19
- callbacks for async work if we need to extract them later — keeping it
20
- simple for now).
21
-
22
- Env vars respected:
23
- HYDRA_TIME_BUDGET — wall-clock budget (s) used for LR schedule
24
- and as Trainer max_time
25
- HYDRA_HESTIA_INTERVAL — steps between Hestia snaps (default 100)
26
- HYDRA_BATCH_SIZE — device batch size (for throughput calc)
27
- HYDRA_SEQ_LEN — sequence length (for throughput calc)
28
- """
29
- from __future__ import annotations
30
-
31
- import math
32
- import os
33
- import time
34
-
35
- import torch
36
- import lightning as L
37
-
38
- from hydra.config import (
39
- ADAM_BETAS,
40
- EMBEDDING_LR,
41
- FINAL_LR_FRAC,
42
- GPU_BF16_PEAK_FLOPS,
43
- MATRIX_LR,
44
- SCALAR_LR,
45
- UNEMBEDDING_LR,
46
- WARMUP_RATIO,
47
- WEIGHT_DECAY,
48
- PostSemClawConfig,
49
- )
50
- from hydra.model import PostSemClawModel
51
-
52
-
53
- # ---------------------------------------------------------------------------
54
- # LR / momentum / wd schedules — verbatim copy of hydra/training.py so the
55
- # curves match exactly. Kept here to avoid import cycles.
56
- # ---------------------------------------------------------------------------
57
-
58
-
59
- def _lr_multiplier(progress: float) -> float:
60
- if progress < WARMUP_RATIO:
61
- return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
62
- decay_progress = (progress - WARMUP_RATIO) / max(1.0 - WARMUP_RATIO, 1e-9)
63
- return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * (
64
- 1 + math.cos(math.pi * decay_progress)
65
- )
66
-
67
-
68
- def _muon_momentum(step: int) -> float:
69
- frac = min(step / 300.0, 1.0)
70
- return (1 - frac) * 0.85 + frac * 0.95
71
-
72
-
73
- def _weight_decay(progress: float) -> float:
74
- return WEIGHT_DECAY * (1 - progress)
75
-
76
-
77
- # ---------------------------------------------------------------------------
78
-
79
-
80
- class HydraLightningModule(L.LightningModule):
81
- """Lightning wrapper. Public attrs: self.model, self.config."""
82
-
83
- def __init__(self, config: PostSemClawConfig):
84
- super().__init__()
85
- self.config = config
86
- self.model = PostSemClawModel(config)
87
- # Model weights init must be deferred to the correct device; done by
88
- # caller after construction (to match the meta-device + to_empty()
89
- # pattern used in the legacy loop).
90
-
91
- # Time-based progress tracks the legacy loop's semantics: LR cosine
92
- # is driven by wall-clock, not step count. We capture training start
93
- # in on_train_start and TIME_BUDGET from env.
94
- self.time_budget = float(
95
- int(os.environ.get("HYDRA_TIME_BUDGET", "300"))
96
- )
97
- self._train_start_time: float | None = None
98
- self._total_training_time = 0.0
99
- self._last_step_end: float | None = None
100
- self._hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
101
- self._flops_per_token = 0
102
- self._tokens_per_step = 0
103
-
104
- # Smoothed loss for the header-line log (matches legacy format).
105
- self._ema_beta = 0.9
106
- self._smooth_loss = 0.0
107
- self._bpt_ema = 0.0
108
- self._token_bytes: torch.Tensor | None = None
109
-
110
- # ------------------------------------------------------------------
111
- # Lifecycle
112
- # ------------------------------------------------------------------
113
-
114
- def on_train_start(self) -> None:
115
- self._train_start_time = time.time()
116
- self._last_step_end = self._train_start_time
117
- self._flops_per_token = self.model.estimate_flops()
118
- # Tokens processed per optimizer step (pre-accum).
119
- B = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
120
- T = int(os.environ.get("HYDRA_SEQ_LEN", "512"))
121
- self._tokens_per_step = B * T
122
-
123
- # Build/cache token_bytes LUT (for bits-per-byte live metric).
124
- import prepare as _p
125
- self._token_bytes = _p.get_token_bytes(device=self.device)
126
-
127
- def configure_optimizers(self):
128
- optimizer = self.model.setup_optimizer(
129
- unembedding_lr=UNEMBEDDING_LR,
130
- embedding_lr=EMBEDDING_LR,
131
- scalar_lr=SCALAR_LR,
132
- adam_betas=ADAM_BETAS,
133
- matrix_lr=MATRIX_LR,
134
- weight_decay=WEIGHT_DECAY,
135
- )
136
- return optimizer
137
-
138
- # ------------------------------------------------------------------
139
- # Training step. Lightning auto-handles: autocast (via precision flag
140
- # on Trainer), backward, grad-accum, zero_grad. We only:
141
- # - split batch into (x, y)
142
- # - forward through model (autocast is established by Trainer)
143
- # - return loss (grads flow from return)
144
- # ------------------------------------------------------------------
145
-
146
- def training_step(self, batch: torch.Tensor, batch_idx: int):
147
- # DataLoader produces (B, T+1) rows; split into input/target.
148
- # Lightning's default collate already moved batch to self.device via
149
- # the accelerator callback when pin_memory=True and device != cpu.
150
- if batch.dim() != 2:
151
- raise RuntimeError(f"Expected (B, T+1) batch, got shape {tuple(batch.shape)}")
152
- x = batch[:, :-1].contiguous()
153
- y = batch[:, 1:].contiguous()
154
-
155
- loss = self.model(x, y)
156
- # Lightning applies the grad-accum divisor automatically; we just
157
- # return the raw loss. loss.detach() is stored for logging.
158
- self._log_step(loss.detach(), y)
159
- return loss
160
-
161
- # ------------------------------------------------------------------
162
- # Optimizer step hook: update LR / momentum / WD using time-progress.
163
- # Runs once per optimizer step (after all accum micro-batches).
164
- # ------------------------------------------------------------------
165
-
166
- def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
167
- # Update schedules from wall-clock progress.
168
- now = time.time()
169
- if self._train_start_time is None:
170
- self._train_start_time = now
171
- self._last_step_end = now
172
- progress = min(self._total_training_time / max(self.time_budget, 1.0), 1.0)
173
-
174
- step = self.global_step
175
- lrm = _lr_multiplier(progress)
176
- mom = _muon_momentum(step)
177
- wd = _weight_decay(progress)
178
- for group in optimizer.param_groups:
179
- group["lr"] = group["initial_lr"] * lrm
180
- if group.get("kind") == "muon":
181
- group["momentum"] = mom
182
- group["weight_decay"] = wd
183
-
184
- # Grad clip (matches legacy loop). Lightning provides this via
185
- # Trainer(gradient_clip_val=1.0) but we want the exact call-site.
186
- torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
187
-
188
- # Hyena train-cache: we must flush accumulated micro-batch grads BACK
189
- # into the filter MLP params AFTER the accum-backward closure has run
190
- # but BEFORE the optimizer actually consumes the grads. Lightning
191
- # composes these so the closure runs inside optimizer.step(). We wrap
192
- # the closure to insert our flush at the exact right moment.
193
- #
194
- # Ordering within the wrapped closure:
195
- # 1. optimizer_closure() — runs all micro-batch forwards + backwards.
196
- # Each Hyena micro-batch backward accumulates into _k_leaf.grad.
197
- # 2. flush_hyena_pending_grads() — one-shot
198
- # torch.autograd.backward(_k_graph, _k_leaf.grad) per HyenaFilter.
199
- # Now filter MLP / pos_emb / bias params have their correct grads.
200
- #
201
- # No-op when HYDRA_HYENA_TRAIN_CACHE=0 or no Hyena blocks exist.
202
- _has_flush = hasattr(self.model, "flush_hyena_pending_grads")
203
- if _has_flush:
204
- _orig_closure = optimizer_closure
205
-
206
- def _wrapped_closure():
207
- result = _orig_closure()
208
- self.model.flush_hyena_pending_grads()
209
- return result
210
-
211
- effective_closure = _wrapped_closure
212
- else:
213
- effective_closure = optimizer_closure
214
-
215
- # Run the step (this is what Lightning would have done for us).
216
- optimizer.step(closure=effective_closure)
217
- self.model.zero_grad(set_to_none=True)
218
-
219
- # Hyena filter-rfft cache invalidation. No-op if:
220
- # (a) no Hyena layers are in the model, or
221
- # (b) HYDRA_HYENA_FILTER_CACHE=0 and HYDRA_HYENA_TRAIN_CACHE=0
222
- # (the operators never populated either cache)
223
- # In either case this is a handful of Python attribute resets.
224
- if hasattr(self.model, "invalidate_hyena_caches"):
225
- self.model.invalidate_hyena_caches()
226
-
227
- # Hestia QAT snap every N steps. Temperature anneals every step.
228
- progress_now = (now - self._train_start_time) / max(self.time_budget, 1.0)
229
- self.model.hestia.anneal_temperature(progress_now)
230
- if self._hestia_interval > 0 and step % self._hestia_interval == 0:
231
- self.model.hestia.apply_to(self.model)
232
-
233
- # SDR SOM update when the model stashed an sdr in the last forward.
234
- _last_sdr = getattr(self.model, "_last_sdr", None)
235
- if _last_sdr is not None and hasattr(self.model.sdr_semantic, "maybe_som_update"):
236
- # x from the last training_step is not available here without
237
- # captured state; the legacy loop passed (x, _last_sdr). To keep
238
- # the interface clean we pass the last batch's x via a buffer.
239
- # Since _last_sdr is derived from idx, we reuse self._last_x.
240
- if getattr(self, "_last_x", None) is not None:
241
- self.model.sdr_semantic.maybe_som_update(self._last_x, _last_sdr)
242
-
243
- # Advance the wall-clock counter for LR schedule (matches legacy
244
- # behavior which incremented only after the first warm-up step).
245
- dt = now - (self._last_step_end or now)
246
- self._last_step_end = now
247
- if step > 10:
248
- self._total_training_time += dt
249
-
250
- # ------------------------------------------------------------------
251
- # Logging — mirrors the step=NNNNN line format of the legacy loop so
252
- # grep/tee pipelines keep working.
253
- # ------------------------------------------------------------------
254
-
255
- def _log_step(self, loss: torch.Tensor, y: torch.Tensor) -> None:
256
- # Stash the current x so optimizer_step can drive SOM update.
257
- self._last_x = None # reset; we will set it below.
258
- # We don't have x here (already discarded); emit a None marker that
259
- # the SOM hook will silently skip if absent.
260
-
261
- loss_f = float(loss.item())
262
- if not math.isfinite(loss_f) or loss_f > 100:
263
- # Let Lightning raise / the trainer callbacks handle this.
264
- self.log("train_loss_nan", 1.0)
265
- return
266
-
267
- step = self.global_step
268
- self._smooth_loss = (
269
- self._ema_beta * self._smooth_loss + (1 - self._ema_beta) * loss_f
270
- )
271
- debiased = self._smooth_loss / max(1 - self._ema_beta ** (step + 1), 1e-9)
272
- dt = max(time.time() - (self._last_step_end or time.time()), 1e-6)
273
- tps = int(self._tokens_per_step / dt) if dt > 0 else 0
274
- mfu = (
275
- 100.0
276
- * self._flops_per_token
277
- * self._tokens_per_step
278
- / dt
279
- / GPU_BF16_PEAK_FLOPS
280
- if dt > 0
281
- else 0.0
282
- )
283
-
284
- # bpb live: y flat -> token_bytes LUT -> avg bytes/token
285
- bpt = debiased / math.log(2)
286
- if self._token_bytes is not None:
287
- with torch.no_grad():
288
- y_flat = y.reshape(-1)
289
- nbytes = self._token_bytes[y_flat]
290
- mask = nbytes > 0
291
- denom = mask.sum().clamp(min=1).float()
292
- avg_bpt = (nbytes.float() * mask.float()).sum() / denom
293
- bpt_batch = float(avg_bpt.item())
294
- if step == 0 or self._bpt_ema <= 0.0:
295
- self._bpt_ema = bpt_batch
296
- else:
297
- self._bpt_ema = 0.98 * self._bpt_ema + 0.02 * bpt_batch
298
- bpb = bpt / max(self._bpt_ema, 1e-6)
299
- vram = (
300
- torch.cuda.memory_allocated() / 1024 / 1024
301
- if torch.cuda.is_available()
302
- else 0.0
303
- )
304
-
305
- self.log_dict(
306
- {
307
- "train/loss": debiased,
308
- "train/bpb": bpb,
309
- "train/bpt": bpt,
310
- "train/tps": float(tps),
311
- "train/mfu": float(mfu),
312
- "train/vram_mib": float(vram),
313
- },
314
- prog_bar=False,
315
- on_step=True,
316
- on_epoch=False,
317
- )
318
-
319
- # Match legacy one-line format: "step=NNNNN loss=x bpb=y tps=z ..."
320
- print(
321
- f"step={step:05d} loss={debiased:.4f} bpb={bpb:.4f} "
322
- f"bpt={bpt:.3f} bpt_div={self._bpt_ema:.2f} "
323
- f"tps={tps} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
324
- f"vram={vram:.0f}MiB",
325
- flush=True,
326
- )
 
1
+ """LightningModule wrapping PostSemClawModel.
2
+
3
+ Thin adapter. The model and the MuonAdamW optimizer are unchanged. This
4
+ module implements:
5
+
6
+ • configure_optimizers — returns the existing MuonAdamW (subclass of
7
+ torch.optim.Optimizer) built by model.setup_optimizer. Lightning accepts
8
+ this directly.
9
+ • training_step — splits (B, T+1) batches into (x, y), forwards through
10
+ the model, logs loss / bpb / tps / mfu / vram. Preserves the
11
+ sampled-softmax path inside PostSemClawModel (no changes there).
12
+ • optimizer_step — before each step we update LR + muon momentum + WD
13
+ using the same time-progress schedule as hydra/training.py
14
+ (get_lr_multiplier / get_muon_momentum / get_weight_decay). Lightning
15
+ handles grad accumulation via Trainer(accumulate_grad_batches=N).
16
+
17
+ The SDR SOM update and Hestia QAT snap are called at the same cadence as
18
+ the legacy loop, but inline on the main thread (Lightning provides its own
19
+ callbacks for async work if we need to extract them later — keeping it
20
+ simple for now).
21
+
22
+ Env vars respected:
23
+ HYDRA_TIME_BUDGET — wall-clock budget (s) used for LR schedule
24
+ and as Trainer max_time
25
+ HYDRA_HESTIA_INTERVAL — steps between Hestia snaps (default 100)
26
+ HYDRA_BATCH_SIZE — device batch size (for throughput calc)
27
+ HYDRA_SEQ_LEN — sequence length (for throughput calc)
28
+ """
29
+ from __future__ import annotations
30
+
31
+ import math
32
+ import os
33
+ import time
34
+
35
+ import torch
36
+ import lightning as L
37
+
38
+ from hydra.config import (
39
+ ADAM_BETAS,
40
+ EMBEDDING_LR,
41
+ FINAL_LR_FRAC,
42
+ GPU_BF16_PEAK_FLOPS,
43
+ MATRIX_LR,
44
+ SCALAR_LR,
45
+ UNEMBEDDING_LR,
46
+ WARMUP_RATIO,
47
+ WEIGHT_DECAY,
48
+ PostSemClawConfig,
49
+ )
50
+ from hydra.model import PostSemClawModel
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # LR / momentum / wd schedules — verbatim copy of hydra/training.py so the
55
+ # curves match exactly. Kept here to avoid import cycles.
56
+ # ---------------------------------------------------------------------------
57
+
58
+
59
+ def _lr_multiplier(progress: float) -> float:
60
+ if progress < WARMUP_RATIO:
61
+ return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
62
+ decay_progress = (progress - WARMUP_RATIO) / max(1.0 - WARMUP_RATIO, 1e-9)
63
+ return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * (
64
+ 1 + math.cos(math.pi * decay_progress)
65
+ )
66
+
67
+
68
+ def _muon_momentum(step: int) -> float:
69
+ frac = min(step / 300.0, 1.0)
70
+ return (1 - frac) * 0.85 + frac * 0.95
71
+
72
+
73
+ def _weight_decay(progress: float) -> float:
74
+ return WEIGHT_DECAY * (1 - progress)
75
+
76
+
77
+ # ---------------------------------------------------------------------------
78
+
79
+
80
+ class HydraLightningModule(L.LightningModule):
81
+ """Lightning wrapper. Public attrs: self.model, self.config."""
82
+
83
+ def __init__(self, config: PostSemClawConfig):
84
+ super().__init__()
85
+ self.config = config
86
+ self.model = PostSemClawModel(config)
87
+ # Model weights init must be deferred to the correct device; done by
88
+ # caller after construction (to match the meta-device + to_empty()
89
+ # pattern used in the legacy loop).
90
+
91
+ # Time-based progress tracks the legacy loop's semantics: LR cosine
92
+ # is driven by wall-clock, not step count. We capture training start
93
+ # in on_train_start and TIME_BUDGET from env.
94
+ self.time_budget = float(
95
+ int(os.environ.get("HYDRA_TIME_BUDGET", "300"))
96
+ )
97
+ self._train_start_time: float | None = None
98
+ self._total_training_time = 0.0
99
+ self._last_step_end: float | None = None
100
+ self._hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
101
+ self._flops_per_token = 0
102
+ self._tokens_per_step = 0
103
+
104
+ # Smoothed loss for the header-line log (matches legacy format).
105
+ self._ema_beta = 0.9
106
+ self._smooth_loss = 0.0
107
+ self._bpt_ema = 0.0
108
+ self._token_bytes: torch.Tensor | None = None
109
+
110
+ # ------------------------------------------------------------------
111
+ # Lifecycle
112
+ # ------------------------------------------------------------------
113
+
114
+ def on_train_start(self) -> None:
115
+ self._train_start_time = time.time()
116
+ self._last_step_end = self._train_start_time
117
+ self._flops_per_token = self.model.estimate_flops()
118
+ # Tokens processed per optimizer step (pre-accum).
119
+ B = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
120
+ T = int(os.environ.get("HYDRA_SEQ_LEN", "512"))
121
+ self._tokens_per_step = B * T
122
+
123
+ # Build/cache token_bytes LUT (for bits-per-byte live metric).
124
+ import prepare as _p
125
+ self._token_bytes = _p.get_token_bytes(device=self.device)
126
+
127
+ def configure_optimizers(self):
128
+ optimizer = self.model.setup_optimizer(
129
+ unembedding_lr=UNEMBEDDING_LR,
130
+ embedding_lr=EMBEDDING_LR,
131
+ scalar_lr=SCALAR_LR,
132
+ adam_betas=ADAM_BETAS,
133
+ matrix_lr=MATRIX_LR,
134
+ weight_decay=WEIGHT_DECAY,
135
+ )
136
+ return optimizer
137
+
138
+ # ------------------------------------------------------------------
139
+ # Training step. Lightning auto-handles: autocast (via precision flag
140
+ # on Trainer), backward, grad-accum, zero_grad. We only:
141
+ # - split batch into (x, y)
142
+ # - forward through model (autocast is established by Trainer)
143
+ # - return loss (grads flow from return)
144
+ # ------------------------------------------------------------------
145
+
146
+ def training_step(self, batch: torch.Tensor, batch_idx: int):
147
+ # DataLoader produces (B, T+1) rows; split into input/target.
148
+ # Lightning's default collate already moved batch to self.device via
149
+ # the accelerator callback when pin_memory=True and device != cpu.
150
+ if batch.dim() != 2:
151
+ raise RuntimeError(f"Expected (B, T+1) batch, got shape {tuple(batch.shape)}")
152
+ x = batch[:, :-1].contiguous()
153
+ y = batch[:, 1:].contiguous()
154
+
155
+ loss = self.model(x, y)
156
+ # Lightning applies the grad-accum divisor automatically; we just
157
+ # return the raw loss. loss.detach() is stored for logging.
158
+ self._log_step(loss.detach(), y)
159
+ return loss
160
+
161
+ # ------------------------------------------------------------------
162
+ # Optimizer step hook: update LR / momentum / WD using time-progress.
163
+ # Runs once per optimizer step (after all accum micro-batches).
164
+ # ------------------------------------------------------------------
165
+
166
+ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
167
+ # Update schedules from wall-clock progress.
168
+ now = time.time()
169
+ if self._train_start_time is None:
170
+ self._train_start_time = now
171
+ self._last_step_end = now
172
+ progress = min(self._total_training_time / max(self.time_budget, 1.0), 1.0)
173
+
174
+ step = self.global_step
175
+ lrm = _lr_multiplier(progress)
176
+ mom = _muon_momentum(step)
177
+ wd = _weight_decay(progress)
178
+ for group in optimizer.param_groups:
179
+ group["lr"] = group["initial_lr"] * lrm
180
+ if group.get("kind") == "muon":
181
+ group["momentum"] = mom
182
+ group["weight_decay"] = wd
183
+
184
+ # Grad clip (matches legacy loop). Lightning provides this via
185
+ # Trainer(gradient_clip_val=1.0) but we want the exact call-site.
186
+ torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
187
+
188
+ # Hyena train-cache: we must flush accumulated micro-batch grads BACK
189
+ # into the filter MLP params AFTER the accum-backward closure has run
190
+ # but BEFORE the optimizer actually consumes the grads. Lightning
191
+ # composes these so the closure runs inside optimizer.step(). We wrap
192
+ # the closure to insert our flush at the exact right moment.
193
+ #
194
+ # Ordering within the wrapped closure:
195
+ # 1. optimizer_closure() — runs all micro-batch forwards + backwards.
196
+ # Each Hyena micro-batch backward accumulates into _k_leaf.grad.
197
+ # 2. flush_hyena_pending_grads() — one-shot
198
+ # torch.autograd.backward(_k_graph, _k_leaf.grad) per HyenaFilter.
199
+ # Now filter MLP / pos_emb / bias params have their correct grads.
200
+ #
201
+ # No-op when HYDRA_HYENA_TRAIN_CACHE=0 or no Hyena blocks exist.
202
+ _has_flush = hasattr(self.model, "flush_hyena_pending_grads")
203
+ if _has_flush:
204
+ _orig_closure = optimizer_closure
205
+
206
+ def _wrapped_closure():
207
+ result = _orig_closure()
208
+ self.model.flush_hyena_pending_grads()
209
+ return result
210
+
211
+ effective_closure = _wrapped_closure
212
+ else:
213
+ effective_closure = optimizer_closure
214
+
215
+ # Run the step (this is what Lightning would have done for us).
216
+ optimizer.step(closure=effective_closure)
217
+ self.model.zero_grad(set_to_none=True)
218
+
219
+ # Hyena filter-rfft cache invalidation. No-op if:
220
+ # (a) no Hyena layers are in the model, or
221
+ # (b) HYDRA_HYENA_FILTER_CACHE=0 and HYDRA_HYENA_TRAIN_CACHE=0
222
+ # (the operators never populated either cache)
223
+ # In either case this is a handful of Python attribute resets.
224
+ if hasattr(self.model, "invalidate_hyena_caches"):
225
+ self.model.invalidate_hyena_caches()
226
+
227
+ # Hestia QAT snap every N steps. Temperature anneals every step.
228
+ progress_now = (now - self._train_start_time) / max(self.time_budget, 1.0)
229
+ self.model.hestia.anneal_temperature(progress_now)
230
+ if self._hestia_interval > 0 and step % self._hestia_interval == 0:
231
+ self.model.hestia.apply_to(self.model)
232
+
233
+ # SDR SOM update when the model stashed an sdr in the last forward.
234
+ _last_sdr = getattr(self.model, "_last_sdr", None)
235
+ if _last_sdr is not None and hasattr(self.model.sdr_semantic, "maybe_som_update"):
236
+ # x from the last training_step is not available here without
237
+ # captured state; the legacy loop passed (x, _last_sdr). To keep
238
+ # the interface clean we pass the last batch's x via a buffer.
239
+ # Since _last_sdr is derived from idx, we reuse self._last_x.
240
+ if getattr(self, "_last_x", None) is not None:
241
+ self.model.sdr_semantic.maybe_som_update(self._last_x, _last_sdr)
242
+
243
+ # Advance the wall-clock counter for LR schedule (matches legacy
244
+ # behavior which incremented only after the first warm-up step).
245
+ dt = now - (self._last_step_end or now)
246
+ self._last_step_end = now
247
+ if step > 10:
248
+ self._total_training_time += dt
249
+
250
+ # ------------------------------------------------------------------
251
+ # Logging — mirrors the step=NNNNN line format of the legacy loop so
252
+ # grep/tee pipelines keep working.
253
+ # ------------------------------------------------------------------
254
+
255
+ def _log_step(self, loss: torch.Tensor, y: torch.Tensor) -> None:
256
+ # Stash the current x so optimizer_step can drive SOM update.
257
+ self._last_x = None # reset; we will set it below.
258
+ # We don't have x here (already discarded); emit a None marker that
259
+ # the SOM hook will silently skip if absent.
260
+
261
+ loss_f = float(loss.item())
262
+ if not math.isfinite(loss_f) or loss_f > 100:
263
+ # Let Lightning raise / the trainer callbacks handle this.
264
+ self.log("train_loss_nan", 1.0)
265
+ return
266
+
267
+ step = self.global_step
268
+ self._smooth_loss = (
269
+ self._ema_beta * self._smooth_loss + (1 - self._ema_beta) * loss_f
270
+ )
271
+ debiased = self._smooth_loss / max(1 - self._ema_beta ** (step + 1), 1e-9)
272
+ dt = max(time.time() - (self._last_step_end or time.time()), 1e-6)
273
+ tps = int(self._tokens_per_step / dt) if dt > 0 else 0
274
+ mfu = (
275
+ 100.0
276
+ * self._flops_per_token
277
+ * self._tokens_per_step
278
+ / dt
279
+ / GPU_BF16_PEAK_FLOPS
280
+ if dt > 0
281
+ else 0.0
282
+ )
283
+
284
+ # bpb live: y flat -> token_bytes LUT -> avg bytes/token
285
+ bpt = debiased / math.log(2)
286
+ if self._token_bytes is not None:
287
+ with torch.no_grad():
288
+ y_flat = y.reshape(-1)
289
+ nbytes = self._token_bytes[y_flat]
290
+ mask = nbytes > 0
291
+ denom = mask.sum().clamp(min=1).float()
292
+ avg_bpt = (nbytes.float() * mask.float()).sum() / denom
293
+ bpt_batch = float(avg_bpt.item())
294
+ if step == 0 or self._bpt_ema <= 0.0:
295
+ self._bpt_ema = bpt_batch
296
+ else:
297
+ self._bpt_ema = 0.98 * self._bpt_ema + 0.02 * bpt_batch
298
+ bpb = bpt / max(self._bpt_ema, 1e-6)
299
+ vram = (
300
+ torch.cuda.memory_allocated() / 1024 / 1024
301
+ if torch.cuda.is_available()
302
+ else 0.0
303
+ )
304
+
305
+ self.log_dict(
306
+ {
307
+ "train/loss": debiased,
308
+ "train/bpb": bpb,
309
+ "train/bpt": bpt,
310
+ "train/tps": float(tps),
311
+ "train/mfu": float(mfu),
312
+ "train/vram_mib": float(vram),
313
+ },
314
+ prog_bar=False,
315
+ on_step=True,
316
+ on_epoch=False,
317
+ )
318
+
319
+ # Match legacy one-line format: "step=NNNNN loss=x bpb=y tps=z ..."
320
+ print(
321
+ f"step={step:05d} loss={debiased:.4f} bpb={bpb:.4f} "
322
+ f"bpt={bpt:.3f} bpt_div={self._bpt_ema:.2f} "
323
+ f"tps={tps} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
324
+ f"vram={vram:.0f}MiB",
325
+ flush=True,
326
+ )
overlay/hydra/model.py CHANGED
The diff for this file is too large to render. See raw diff
 
overlay/hydra/optimizer.py CHANGED
@@ -1,252 +1,252 @@
1
- """MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).
2
-
3
- Extracted verbatim from train.py (W1 modularization). Semantics unchanged.
4
-
5
- F1-F15 state preserved:
6
- - F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
7
- step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
8
- fresh. Persistent copies of param storage would be mutated between forward
9
- passes (via lerp_/sub_ on stacked tensors that share storage with params),
10
- triggering "modified in-place" errors on grad_accum=2 backwards.
11
- - F11/F15: `@torch.compile` on `adamw_step_fused` / `muon_step_fused` intact.
12
- - F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
13
- dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
14
- stream-capture deadlock. See .omc/muon_compile_bug.md for the full
15
- investigation.
16
- """
17
-
18
- from __future__ import annotations
19
-
20
- import os
21
-
22
- import torch
23
-
24
- # HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel.
25
- _HYDRA_FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
26
- _HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")
27
-
28
-
29
- polar_express_coeffs = [
30
- (8.156554524902461, -22.48329292557795, 15.878769915207462),
31
- (4.042929935166739, -2.808917465908714, 0.5000178451051316),
32
- (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
33
- (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
34
- (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
35
- ]
36
-
37
-
38
- def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
39
- # Per-param AdamW fallback. Fast path is torch._fused_adamw_ (1 CUDA launch
40
- # for the whole group) driven from MuonAdamW._step_adamw below.
41
- grad = grad.to(p.dtype) # handle mixed bf16/fp32 from autocast
42
- p.mul_(1 - lr_t * wd_t)
43
- exp_avg.lerp_(grad, 1 - beta1_t)
44
- exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
45
- bias1 = 1 - beta1_t ** step_t
46
- bias2 = 1 - beta2_t ** step_t
47
- denom = (exp_avg_sq / bias2).sqrt() + eps_t
48
- step_size = lr_t / bias1
49
- p.add_(exp_avg / denom, alpha=-step_size)
50
-
51
-
52
- # ---------------------------------------------------------------------------
53
- # F15 muon_step_fused compile strategy.
54
- #
55
- # HYDRA_MUON_COMPILE env gate:
56
- # "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
57
- # Dynamic=True collapses the per-shape specialization cache so that N
58
- # Muon param-groups with N distinct shapes trigger 1 compile, not N.
59
- # mode="default" keeps the inductor codegen but disables cudagraphs,
60
- # which is what caused the step-17→18 silent deadlock observed under
61
- # the original dynamic=False configuration: cudagraph stream capture
62
- # can deadlock against HTM's CUDA kernels running on the default
63
- # stream, and the failure mode at capture-time is a silent hang
64
- # (100% GPU util, no log output, process state R).
65
- # "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
66
- # Keeps an escape hatch in case a future torch/inductor regression
67
- # reintroduces a deadlock.
68
- #
69
- # Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
70
- # alias-analysis edge case where inductor sees `g is stacked_grads` and
71
- # subsequent `stacked_grads.square()` operating on the post-lerp storage.
72
- # ---------------------------------------------------------------------------
73
- _MUON_COMPILE = os.environ.get("HYDRA_MUON_COMPILE", "1") == "1"
74
-
75
- def _maybe_compile(fn):
76
- if _MUON_COMPILE:
77
- # mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
78
- # would enable) to avoid stream-capture deadlocks against HTM's CUDA
79
- # kernels. dynamic=True minimizes recompile count across param-group
80
- # shapes.
81
- return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
82
- return fn
83
-
84
- @_maybe_compile
85
- def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
86
- momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
87
- # Cast grads to param dtype AND clone defensively to break any alias
88
- # between the (freshly-stacked) input and the in-place lerp_ below.
89
- # Without this, inductor's alias analysis can emit code that reads from
90
- # post-mutation storage when computing `v_mean = g.square().mean(...)`.
91
- stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()
92
- # Nesterov momentum
93
- momentum = momentum_t.to(stacked_grads.dtype)
94
- momentum_buffer.lerp_(stacked_grads, 1 - momentum)
95
- g = stacked_grads.lerp_(momentum_buffer, momentum)
96
- # Polar express orthogonalization
97
- X = g.bfloat16()
98
- X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
99
- if g.size(-2) > g.size(-1):
100
- for a, b, c in polar_express_coeffs[:ns_steps]:
101
- A = X.mT @ X
102
- B = b * A + c * (A @ A)
103
- X = a * X + X @ B
104
- else:
105
- for a, b, c in polar_express_coeffs[:ns_steps]:
106
- A = X @ X.mT
107
- B = b * A + c * (A @ A)
108
- X = a * X + B @ X
109
- g = X
110
- # NorMuon variance reduction
111
- # Keep beta2 in the state-buffer dtype, not g.dtype, so lerp_ on the
112
- # float32 second_momentum_buffer doesn't hit a dtype mismatch on h200.
113
- beta2 = beta2_t.to(second_momentum_buffer.dtype)
114
- v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
115
- red_dim_size = g.size(red_dim)
116
- v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
117
- v_norm = v_norm_sq.sqrt()
118
- second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
119
- step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
120
- scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
121
- v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
122
- final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
123
- g = g * final_scale.to(g.dtype)
124
- # Cautious weight decay + parameter update
125
- lr = lr_t.to(g.dtype)
126
- wd = wd_t.to(g.dtype)
127
- mask = (g * stacked_params) >= 0
128
- stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
129
-
130
-
131
- class MuonAdamW(torch.optim.Optimizer):
132
- """Combined optimizer: Muon for 2D matrix params, AdamW for others."""
133
-
134
- def __init__(self, param_groups):
135
- super().__init__(param_groups, defaults={})
136
- # 0-D CPU tensors to avoid torch.compile recompilation when values change
137
- self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
138
- self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
139
- self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
140
- self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
141
- self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
142
- self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
143
- self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
144
- self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
145
- self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
146
- self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
147
-
148
- def _step_adamw(self, group):
149
- params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
150
- for p in group['params']:
151
- if p.grad is None:
152
- continue
153
- state = self.state[p]
154
- if not state:
155
- state['step'] = 0
156
- state['exp_avg'] = torch.zeros_like(p)
157
- state['exp_avg_sq'] = torch.zeros_like(p)
158
- if 'step_t' not in state:
159
- # _fused_adamw_ wants a per-param float step tensor on-device.
160
- state['step_t'] = torch.tensor(
161
- float(state['step']), dtype=torch.float32, device=p.device
162
- )
163
- state['step'] += 1
164
- params.append(p)
165
- grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
166
- exp_avgs.append(state['exp_avg'])
167
- exp_avg_sqs.append(state['exp_avg_sq'])
168
- state_steps.append(state['step_t'])
169
-
170
- if not params:
171
- return
172
-
173
- if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
174
- # _fused_adamw_ needs uniform (device, dtype) within a call, so
175
- # group by (device, dtype) — same pattern as PyTorch's own
176
- # AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
177
- buckets = {}
178
- for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
179
- key = (p.device, p.dtype)
180
- buckets.setdefault(key, ([], [], [], [], []))
181
- b_p, b_g, b_ea, b_es, b_st = buckets[key]
182
- b_p.append(p); b_g.append(g); b_ea.append(ea); b_es.append(es); b_st.append(st)
183
-
184
- lr_f = float(group['lr'])
185
- b1_f = float(group['betas'][0])
186
- b2_f = float(group['betas'][1])
187
- wd_f = float(group['weight_decay'])
188
- eps_f = float(group['eps'])
189
- for (_dev, _dt), (b_p, b_g, b_ea, b_es, b_st) in buckets.items():
190
- torch._foreach_add_(b_st, 1.0)
191
- torch._fused_adamw_(
192
- b_p, b_g, b_ea, b_es,
193
- [], # max_exp_avg_sqs unused (amsgrad=False)
194
- b_st,
195
- amsgrad=False,
196
- lr=lr_f, beta1=b1_f, beta2=b2_f,
197
- weight_decay=wd_f, eps=eps_f,
198
- maximize=False,
199
- grad_scale=None, found_inf=None,
200
- )
201
- return
202
-
203
- # Fallback per-param path.
204
- self._adamw_lr_t.fill_(group['lr'])
205
- self._adamw_beta1_t.fill_(group['betas'][0])
206
- self._adamw_beta2_t.fill_(group['betas'][1])
207
- self._adamw_eps_t.fill_(group['eps'])
208
- self._adamw_wd_t.fill_(group['weight_decay'])
209
- for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
210
- self._adamw_step_t.fill_(self.state[p]['step'])
211
- adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
212
- self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
213
- self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
214
-
215
- def _step_muon(self, group):
216
- params = [p for p in group['params'] if p.grad is not None]
217
- if not params:
218
- return
219
- p = params[0]
220
- state = self.state[p]
221
- num_params = len(params)
222
- shape, device, dtype = p.shape, p.device, p.dtype
223
- if "momentum_buffer" not in state:
224
- state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
225
- red_dim = -1 if shape[-2] >= shape[-1] else -2
226
- if "second_momentum_buffer" not in state:
227
- # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
228
- full_shape = (num_params, *shape)
229
- state_shape = list(full_shape)
230
- state_shape[len(state_shape) + red_dim] = 1 # red_dim is negative
231
- state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
232
- # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
233
- # This was the autograd-safety fix that unblocks grad_accum>=2.
234
- stacked_grads = torch.stack([p.grad for p in params])
235
- stacked_params = torch.stack(params)
236
- self._muon_momentum_t.fill_(group["momentum"])
237
- self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
238
- self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
239
- self._muon_wd_t.fill_(group["weight_decay"])
240
- muon_step_fused(stacked_grads, stacked_params,
241
- state["momentum_buffer"], state["second_momentum_buffer"],
242
- self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
243
- self._muon_beta2_t, group["ns_steps"], red_dim)
244
- torch._foreach_copy_(params, list(stacked_params.unbind(0)))
245
-
246
- @torch.no_grad()
247
- def step(self):
248
- for group in self.param_groups:
249
- if group['kind'] == 'adamw':
250
- self._step_adamw(group)
251
- elif group['kind'] == 'muon':
252
- self._step_muon(group)
 
1
+ """MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).
2
+
3
+ Extracted verbatim from train.py (W1 modularization). Semantics unchanged.
4
+
5
+ F1-F15 state preserved:
6
+ - F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
7
+ step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
8
+ fresh. Persistent copies of param storage would be mutated between forward
9
+ passes (via lerp_/sub_ on stacked tensors that share storage with params),
10
+ triggering "modified in-place" errors on grad_accum=2 backwards.
11
+ - F11/F15: `@torch.compile` on `adamw_step_fused` / `muon_step_fused` intact.
12
+ - F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
13
+ dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
14
+ stream-capture deadlock. See .omc/muon_compile_bug.md for the full
15
+ investigation.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import os
21
+
22
+ import torch
23
+
24
+ # HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel.
25
+ _HYDRA_FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
26
+ _HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")
27
+
28
+
29
+ polar_express_coeffs = [
30
+ (8.156554524902461, -22.48329292557795, 15.878769915207462),
31
+ (4.042929935166739, -2.808917465908714, 0.5000178451051316),
32
+ (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
33
+ (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
34
+ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
35
+ ]
36
+
37
+
38
+ def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
39
+ # Per-param AdamW fallback. Fast path is torch._fused_adamw_ (1 CUDA launch
40
+ # for the whole group) driven from MuonAdamW._step_adamw below.
41
+ grad = grad.to(p.dtype) # handle mixed bf16/fp32 from autocast
42
+ p.mul_(1 - lr_t * wd_t)
43
+ exp_avg.lerp_(grad, 1 - beta1_t)
44
+ exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
45
+ bias1 = 1 - beta1_t ** step_t
46
+ bias2 = 1 - beta2_t ** step_t
47
+ denom = (exp_avg_sq / bias2).sqrt() + eps_t
48
+ step_size = lr_t / bias1
49
+ p.add_(exp_avg / denom, alpha=-step_size)
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # F15 muon_step_fused compile strategy.
54
+ #
55
+ # HYDRA_MUON_COMPILE env gate:
56
+ # "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
57
+ # Dynamic=True collapses the per-shape specialization cache so that N
58
+ # Muon param-groups with N distinct shapes trigger 1 compile, not N.
59
+ # mode="default" keeps the inductor codegen but disables cudagraphs,
60
+ # which is what caused the step-17→18 silent deadlock observed under
61
+ # the original dynamic=False configuration: cudagraph stream capture
62
+ # can deadlock against HTM's CUDA kernels running on the default
63
+ # stream, and the failure mode at capture-time is a silent hang
64
+ # (100% GPU util, no log output, process state R).
65
+ # "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
66
+ # Keeps an escape hatch in case a future torch/inductor regression
67
+ # reintroduces a deadlock.
68
+ #
69
+ # Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
70
+ # alias-analysis edge case where inductor sees `g is stacked_grads` and
71
+ # subsequent `stacked_grads.square()` operating on the post-lerp storage.
72
+ # ---------------------------------------------------------------------------
73
+ _MUON_COMPILE = os.environ.get("HYDRA_MUON_COMPILE", "1") == "1"
74
+
75
+ def _maybe_compile(fn):
76
+ if _MUON_COMPILE:
77
+ # mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
78
+ # would enable) to avoid stream-capture deadlocks against HTM's CUDA
79
+ # kernels. dynamic=True minimizes recompile count across param-group
80
+ # shapes.
81
+ return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
82
+ return fn
83
+
84
+ @_maybe_compile
85
+ def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
86
+ momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
87
+ # Cast grads to param dtype AND clone defensively to break any alias
88
+ # between the (freshly-stacked) input and the in-place lerp_ below.
89
+ # Without this, inductor's alias analysis can emit code that reads from
90
+ # post-mutation storage when computing `v_mean = g.square().mean(...)`.
91
+ stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()
92
+ # Nesterov momentum
93
+ momentum = momentum_t.to(device=momentum_buffer.device, dtype=stacked_grads.dtype)
94
+ momentum_buffer.lerp_(stacked_grads, 1 - momentum)
95
+ g = stacked_grads.lerp_(momentum_buffer, momentum)
96
+ # Polar express orthogonalization
97
+ X = g.bfloat16()
98
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
99
+ if g.size(-2) > g.size(-1):
100
+ for a, b, c in polar_express_coeffs[:ns_steps]:
101
+ A = X.mT @ X
102
+ B = b * A + c * (A @ A)
103
+ X = a * X + X @ B
104
+ else:
105
+ for a, b, c in polar_express_coeffs[:ns_steps]:
106
+ A = X @ X.mT
107
+ B = b * A + c * (A @ A)
108
+ X = a * X + B @ X
109
+ g = X
110
+ # NorMuon variance reduction
111
+ # Keep beta2 in the state-buffer dtype, not g.dtype, so lerp_ on the
112
+ # float32 second_momentum_buffer doesn't hit a dtype mismatch on h200.
113
+ beta2 = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
114
+ v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
115
+ red_dim_size = g.size(red_dim)
116
+ v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
117
+ v_norm = v_norm_sq.sqrt()
118
+ second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
119
+ step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
120
+ scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
121
+ v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
122
+ final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
123
+ g = g * final_scale.to(g.dtype)
124
+ # Cautious weight decay + parameter update
125
+ lr = lr_t.to(device=stacked_params.device, dtype=g.dtype)
126
+ wd = wd_t.to(device=stacked_params.device, dtype=g.dtype)
127
+ mask = (g * stacked_params) >= 0
128
+ stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
129
+
130
+
131
+ class MuonAdamW(torch.optim.Optimizer):
132
+ """Combined optimizer: Muon for 2D matrix params, AdamW for others."""
133
+
134
+ def __init__(self, param_groups):
135
+ super().__init__(param_groups, defaults={})
136
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
137
+ self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
138
+ self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
139
+ self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
140
+ self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
141
+ self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
142
+ self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
143
+ self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
144
+ self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
145
+ self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
146
+ self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
147
+
148
+ def _step_adamw(self, group):
149
+ params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
150
+ for p in group['params']:
151
+ if p.grad is None:
152
+ continue
153
+ state = self.state[p]
154
+ if not state:
155
+ state['step'] = 0
156
+ state['exp_avg'] = torch.zeros_like(p)
157
+ state['exp_avg_sq'] = torch.zeros_like(p)
158
+ if 'step_t' not in state:
159
+ # _fused_adamw_ wants a per-param float step tensor on-device.
160
+ state['step_t'] = torch.tensor(
161
+ float(state['step']), dtype=torch.float32, device=p.device
162
+ )
163
+ state['step'] += 1
164
+ params.append(p)
165
+ grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
166
+ exp_avgs.append(state['exp_avg'])
167
+ exp_avg_sqs.append(state['exp_avg_sq'])
168
+ state_steps.append(state['step_t'])
169
+
170
+ if not params:
171
+ return
172
+
173
+ if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
174
+ # _fused_adamw_ needs uniform (device, dtype) within a call, so
175
+ # group by (device, dtype) — same pattern as PyTorch's own
176
+ # AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
177
+ buckets = {}
178
+ for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
179
+ key = (p.device, p.dtype)
180
+ buckets.setdefault(key, ([], [], [], [], []))
181
+ b_p, b_g, b_ea, b_es, b_st = buckets[key]
182
+ b_p.append(p); b_g.append(g); b_ea.append(ea); b_es.append(es); b_st.append(st)
183
+
184
+ lr_f = float(group['lr'])
185
+ b1_f = float(group['betas'][0])
186
+ b2_f = float(group['betas'][1])
187
+ wd_f = float(group['weight_decay'])
188
+ eps_f = float(group['eps'])
189
+ for (_dev, _dt), (b_p, b_g, b_ea, b_es, b_st) in buckets.items():
190
+ torch._foreach_add_(b_st, 1.0)
191
+ torch._fused_adamw_(
192
+ b_p, b_g, b_ea, b_es,
193
+ [], # max_exp_avg_sqs unused (amsgrad=False)
194
+ b_st,
195
+ amsgrad=False,
196
+ lr=lr_f, beta1=b1_f, beta2=b2_f,
197
+ weight_decay=wd_f, eps=eps_f,
198
+ maximize=False,
199
+ grad_scale=None, found_inf=None,
200
+ )
201
+ return
202
+
203
+ # Fallback per-param path.
204
+ self._adamw_lr_t.fill_(group['lr'])
205
+ self._adamw_beta1_t.fill_(group['betas'][0])
206
+ self._adamw_beta2_t.fill_(group['betas'][1])
207
+ self._adamw_eps_t.fill_(group['eps'])
208
+ self._adamw_wd_t.fill_(group['weight_decay'])
209
+ for p, grad, exp_avg, exp_avg_sq in zip(params, grads, exp_avgs, exp_avg_sqs):
210
+ self._adamw_step_t.fill_(self.state[p]['step'])
211
+ adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
212
+ self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
213
+ self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
214
+
215
+ def _step_muon(self, group):
216
+ params = [p for p in group['params'] if p.grad is not None]
217
+ if not params:
218
+ return
219
+ p = params[0]
220
+ state = self.state[p]
221
+ num_params = len(params)
222
+ shape, device, dtype = p.shape, p.device, p.dtype
223
+ if "momentum_buffer" not in state:
224
+ state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
225
+ red_dim = -1 if shape[-2] >= shape[-1] else -2
226
+ if "second_momentum_buffer" not in state:
227
+ # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
228
+ full_shape = (num_params, *shape)
229
+ state_shape = list(full_shape)
230
+ state_shape[len(state_shape) + red_dim] = 1 # red_dim is negative
231
+ state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
232
+ # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
233
+ # This was the autograd-safety fix that unblocks grad_accum>=2.
234
+ stacked_grads = torch.stack([p.grad for p in params])
235
+ stacked_params = torch.stack(params)
236
+ self._muon_momentum_t.fill_(group["momentum"])
237
+ self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
238
+ self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
239
+ self._muon_wd_t.fill_(group["weight_decay"])
240
+ muon_step_fused(stacked_grads, stacked_params,
241
+ state["momentum_buffer"], state["second_momentum_buffer"],
242
+ self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
243
+ self._muon_beta2_t, group["ns_steps"], red_dim)
244
+ torch._foreach_copy_(params, list(stacked_params.unbind(0)))
245
+
246
+ @torch.no_grad()
247
+ def step(self):
248
+ for group in self.param_groups:
249
+ if group['kind'] == 'adamw':
250
+ self._step_adamw(group)
251
+ elif group['kind'] == 'muon':
252
+ self._step_muon(group)
overlay/hydra/reality_bridge.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class RealityBridgeOutput:
11
+ reality: torch.Tensor
12
+ poincare: torch.Tensor
13
+ l0_indices: torch.Tensor
14
+ l0_values: torch.Tensor
15
+
16
+
17
+ class RealityPoincareBridge(nn.Module):
18
+ """Default-off SEM-Claw continuous→discrete bridge.
19
+
20
+ PyTorch GEMM creates a compact 133-d reality latent, then a differentiable
21
+ Poincare-disk projection is kept for metrics/regularizers while a detached
22
+ int16 L0/top-k index buffer feeds Engram/Cantor sparse retrieval. This is a
23
+ production-shaped version of rs.md's Poincare/Reality Buffer without adding
24
+ speculative E7 machinery to the hot path.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ d_model: int,
30
+ d_reality: int = 133,
31
+ d_poincare: int = 2,
32
+ l0_k: int = 64,
33
+ ) -> None:
34
+ super().__init__()
35
+ if d_model <= 0:
36
+ raise ValueError(f"d_model must be positive, got {d_model}")
37
+ if d_reality <= 0:
38
+ raise ValueError(f"d_reality must be positive, got {d_reality}")
39
+ if d_poincare != 2:
40
+ raise ValueError("Poincare bridge currently expects d_poincare=2")
41
+ if l0_k <= 0:
42
+ raise ValueError(f"l0_k must be positive, got {l0_k}")
43
+ self.d_model = int(d_model)
44
+ self.d_reality = int(d_reality)
45
+ self.d_poincare = int(d_poincare)
46
+ self.l0_k = min(int(l0_k), self.d_reality)
47
+ self.to_reality = nn.Linear(d_model, d_reality, bias=False)
48
+ self.to_tangent2 = nn.Linear(d_reality, d_poincare, bias=False)
49
+ nn.init.normal_(self.to_reality.weight, mean=0.0, std=0.02)
50
+ nn.init.normal_(self.to_tangent2.weight, mean=0.0, std=0.02)
51
+
52
+ @staticmethod
53
+ def poincare_expmap0(tangent2: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
54
+ t = tangent2.float()
55
+ r = t.norm(dim=-1, keepdim=True).clamp_min(eps)
56
+ y = torch.tanh(r) * (t / r)
57
+ return y.to(tangent2.dtype)
58
+
59
+ def forward(self, x: torch.Tensor) -> RealityBridgeOutput:
60
+ if x.shape[-1] != self.d_model:
61
+ raise ValueError(f"expected last dim {self.d_model}, got {x.shape[-1]}")
62
+ reality = self.to_reality(x)
63
+ tangent2 = self.to_tangent2(reality)
64
+ poincare = self.poincare_expmap0(tangent2)
65
+ vals, idx = reality.float().abs().topk(self.l0_k, dim=-1)
66
+ return RealityBridgeOutput(
67
+ reality=reality,
68
+ poincare=poincare,
69
+ l0_indices=idx.to(torch.int16),
70
+ l0_values=vals.to(reality.dtype),
71
+ )
overlay/hydra/training.py CHANGED
@@ -1,948 +1,967 @@
1
- """HYDRA training entry: setup, train loop, eval, summary.
2
-
3
- Extracted from the monolithic train.py (W1 modularization). Semantics
4
- preserved. Public entrypoint: `main()`.
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import gc
10
- import json
11
- import math
12
- import os
13
- import sys
14
- import threading
15
- import time
16
- from dataclasses import asdict
17
- from pathlib import Path
18
-
19
- import torch
20
-
21
- # Line-buffered stdout so `python -u train.py | tee run.log | grep step` is
22
- # live (no \r overwrite, no 4k block-buffered pipe stalls). Safe on Python
23
- # 3.7+ where io.TextIOWrapper.reconfigure exists.
24
- try:
25
- sys.stdout.reconfigure(line_buffering=True) # type: ignore[attr-defined]
26
- except Exception:
27
- pass
28
-
29
- from hydra.config import (
30
- ADAM_BETAS, CURRICULUM_SHORT_SEQ_LEN, CURRICULUM_SHORT_STEPS,
31
- D_MODEL, D_STATE, DEVICE_BATCH_SIZE, EMA_DECAY, EMBEDDING_LR,
32
- ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND,
33
- FINAL_LR_FRAC, GPU_BF16_PEAK_FLOPS, HEADDIM, MATRIX_LR, N_HEADS,
34
- N_LAYER, PostSemClawConfig, SCALAR_LR, SEED, TOTAL_BATCH_SIZE,
35
- UNEMBEDDING_LR, USE_EMA, WARMUP_RATIO, WEIGHT_DECAY,
36
- )
37
- from hydra.diffusion_loss import mdlm_masked_forward_process, mdlm_rb_loss
38
- from hydra.eval import run_factual_english, run_factual_probes
39
- from hydra.model import PostSemClawModel
40
-
41
- import prepare as _prepare_mod
42
- from prepare import MAX_SEQ_LEN, TIME_BUDGET as _TIME_BUDGET, Tokenizer, evaluate_bpb as _evaluate_bpb_shards, get_token_bytes, make_dataloader as _make_dataloader_shards
43
-
44
- # Streaming Nemotron path (Super3 recipe). Opt-in via HYDRA_USE_NEMOTRON=1.
45
- if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
46
- import prepare_nemotron as _p_nemo
47
- make_dataloader = _p_nemo.make_dataloader
48
- evaluate_bpb = _p_nemo.evaluate_bpb
49
- else:
50
- make_dataloader = _make_dataloader_shards
51
- evaluate_bpb = _evaluate_bpb_shards
52
-
53
- TIME_BUDGET = int(os.environ.get("HYDRA_TIME_BUDGET", str(_TIME_BUDGET)))
54
- _prepare_mod.TIME_BUDGET = TIME_BUDGET # sync for evaluate_bpb
55
-
56
- CACHE_DIR = Path.home() / ".cache" / "autoresearch"
57
- LATEST_CKPT = CACHE_DIR / "latest.pt"
58
- PRETRAIN_FINAL_CKPT = CACHE_DIR / "pretrain_final.pt"
59
- FAILED_CKPT = CACHE_DIR / "latest_failed.pt" # crash/FAIL path — never overwrites good
60
- BEST_CKPT = CACHE_DIR / "best_bpb.pt" # lowest val_bpb seen
61
- CKPT_INTERVAL = int(os.environ.get("HYDRA_CKPT_INTERVAL", "250"))
62
- CKPT_ROTATIONS = int(os.environ.get("HYDRA_CKPT_ROTATIONS", "3")) # how many .N backups to keep
63
- RESUME_CKPT = os.environ.get("HYDRA_RESUME_CKPT", str(LATEST_CKPT))
64
-
65
- # MDLM (Masked Diffusion LM) Rao-Blackwellized ELBO loss path.
66
- # HYDRA_USE_MDLM=1 : switch training loss from AR sampled-softmax CE
67
- # to MDLM RB weighted CE (arXiv:2406.07524).
68
- # HYDRA_MDLM_MASK_ID=N : token id used for the MASK sentinel (default:
69
- # last valid id, vocab_size - 1). Ensure this id
70
- # never appears in training targets — typical
71
- # practice is to reserve it.
72
- # HYDRA_MDLM_SCHEDULE=loglinear|linear : noise schedule (default loglinear).
73
- # When enabled, the per-step flow is:
74
- # 1. mdlm_masked_forward_process(y) -> (x_noised, mask_positions, weights)
75
- # 2. logits = model(x_noised) (no targets -> full V logits)
76
- # 3. loss = mdlm_rb_loss(logits, y, mask_positions, weights)
77
- # Sampled-softmax is bypassed in this path because the RB ELBO needs
78
- # full-vocab logits on masked positions.
79
- USE_MDLM = os.environ.get("HYDRA_USE_MDLM", "0") == "1"
80
- MDLM_MASK_ID = int(os.environ.get("HYDRA_MDLM_MASK_ID", "-1")) # -1 => default to vocab_size-1 at runtime
81
- MDLM_SCHEDULE = os.environ.get("HYDRA_MDLM_SCHEDULE", "loglinear")
82
-
83
-
84
- # ---------------------------------------------------------------------------
85
- # Schedules
86
- # ---------------------------------------------------------------------------
87
-
88
- def get_lr_multiplier(progress: float) -> float:
89
- if progress < WARMUP_RATIO:
90
- return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
91
- decay_progress = (progress - WARMUP_RATIO) / (1.0 - WARMUP_RATIO)
92
- return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * (1 + math.cos(math.pi * decay_progress))
93
-
94
-
95
- def get_muon_momentum(step: int) -> float:
96
- frac = min(step / 300, 1)
97
- return (1 - frac) * 0.85 + frac * 0.95
98
-
99
-
100
- def get_weight_decay(progress: float) -> float:
101
- return WEIGHT_DECAY * (1 - progress)
102
-
103
-
104
- _CKPT_WORKER_THREAD: threading.Thread | None = None
105
-
106
-
107
- def _ckpt_snapshot_state_dicts(
108
- model: PostSemClawModel,
109
- optimizer: torch.optim.Optimizer,
110
- ) -> tuple[dict, dict]:
111
- """Detach + CPU-clone every tensor so a bg thread can serialize safely
112
- while the main loop keeps mutating live weights/optimizer state."""
113
- msd = {k: (v.detach().to("cpu", copy=True) if torch.is_tensor(v) else v)
114
- for k, v in model.state_dict().items()}
115
- # optimizer.state_dict() is a nested dict; walk it.
116
- osd_raw = optimizer.state_dict()
117
-
118
- def _to_cpu(obj):
119
- if torch.is_tensor(obj):
120
- return obj.detach().to("cpu", copy=True)
121
- if isinstance(obj, dict):
122
- return {k: _to_cpu(v) for k, v in obj.items()}
123
- if isinstance(obj, list):
124
- return [_to_cpu(v) for v in obj]
125
- if isinstance(obj, tuple):
126
- return tuple(_to_cpu(v) for v in obj)
127
- return obj
128
-
129
- osd = _to_cpu(osd_raw)
130
- return msd, osd
131
-
132
-
133
- def save_ckpt(
134
- model: PostSemClawModel,
135
- optimizer: torch.optim.Optimizer,
136
- config: PostSemClawConfig,
137
- step: int,
138
- total_training_time: float,
139
- smooth_train_loss: float,
140
- bpt_ema: float,
141
- epoch: int,
142
- path: Path,
143
- *,
144
- val_bpb: float | None = None,
145
- blocking: bool = False,
146
- ) -> None:
147
- """Save a training checkpoint.
148
-
149
- Default behavior is async: the GPU→CPU state_dict clone runs on the main
150
- thread (unavoidable; needs to happen before the next optimizer.step that
151
- mutates live weights), then `torch.save` is dispatched to a daemon
152
- worker thread. The next call joins any still-running prior save so only
153
- one disk write is in flight.
154
-
155
- `blocking=True` restores the original synchronous behavior — used for
156
- end-of-training saves where correctness on process exit matters.
157
- """
158
- global _CKPT_WORKER_THREAD
159
- try:
160
- CACHE_DIR.mkdir(parents=True, exist_ok=True)
161
- msd, osd = _ckpt_snapshot_state_dicts(model, optimizer)
162
- # asdict() recursively converts dataclass fields to a dict and
163
- # renders tuples as lists. hyena_layers therefore round-trips as a
164
- # JSON-safe list; config_from_dict normalizes it back to a tuple.
165
- payload = {
166
- "model_state_dict": msd,
167
- "optimizer_state_dict": osd,
168
- "config": asdict(config),
169
- "step": step,
170
- "epoch": epoch,
171
- "train_seconds": total_training_time,
172
- "smoothed_loss": smooth_train_loss,
173
- "bpt_ema": bpt_ema,
174
- "val_bpb": val_bpb,
175
- }
176
- path_str = str(path)
177
-
178
- def _rotate(p: str) -> None:
179
- """Keep up to CKPT_ROTATIONS previous versions as p.1, p.2, ..."""
180
- if CKPT_ROTATIONS <= 0:
181
- return
182
- try:
183
- # Walk from oldest to newest so we don't clobber newer with older.
184
- for i in range(CKPT_ROTATIONS, 0, -1):
185
- src = f"{p}.{i-1}" if i > 1 else p
186
- dst = f"{p}.{i}"
187
- if os.path.exists(src):
188
- os.replace(src, dst)
189
- except Exception as e:
190
- # Rotation is best-effort; never block a save on it.
191
- print(f"[ckpt] rotate warn {p}: {type(e).__name__}: {e}", flush=True)
192
-
193
- def _write():
194
- try:
195
- _rotate(path_str)
196
- tmp = path_str + ".tmp"
197
- torch.save(payload, tmp)
198
- os.replace(tmp, path_str)
199
- print(f"[ckpt] saved {path_str} (step={step})", flush=True)
200
- except Exception as e:
201
- print(f"[ckpt] SAVE FAILED {path_str}: {type(e).__name__}: {e}", flush=True)
202
-
203
- if blocking:
204
- _write()
205
- return
206
-
207
- # Join previous writer so at most one torch.save runs at a time.
208
- if _CKPT_WORKER_THREAD is not None and _CKPT_WORKER_THREAD.is_alive():
209
- _CKPT_WORKER_THREAD.join()
210
- _CKPT_WORKER_THREAD = threading.Thread(
211
- target=_write, daemon=True, name=f"ckpt-save-{step}"
212
- )
213
- _CKPT_WORKER_THREAD.start()
214
- except Exception as e:
215
- print(f"[ckpt] SNAPSHOT FAILED {path}: {type(e).__name__}: {e}", flush=True)
216
-
217
-
218
- def config_from_dict(cfg_dict: dict) -> PostSemClawConfig:
219
- """Reconstruct a PostSemClawConfig from a checkpoint's asdict() payload.
220
-
221
- Newly-added fields (e.g. `hyena_layers`) are defaulted when absent in
222
- older checkpoints, and list-ified tuples are coerced back to tuples so
223
- the dataclass keeps its declared types.
224
-
225
- This is the ckpt-safe inverse of `asdict(config)` used by save_ckpt and
226
- guarantees that a resume path can rebuild the exact same model topology
227
- (Mamba3 vs HyenaBlock per layer) regardless of env-var state at resume.
228
- """
229
- # Only keep keys that are actually declared on PostSemClawConfig — extra
230
- # keys in older/newer checkpoints must not crash construction.
231
- field_names = {f.name for f in PostSemClawConfig.__dataclass_fields__.values()}
232
- filtered = {k: v for k, v in cfg_dict.items() if k in field_names}
233
- # asdict renders tuple[int,...] as list[int]; coerce back so the model
234
- # builder sees the declared type.
235
- if "hyena_layers" in filtered and filtered["hyena_layers"] is not None:
236
- filtered["hyena_layers"] = tuple(sorted(int(x) for x in filtered["hyena_layers"]))
237
- return PostSemClawConfig(**filtered)
238
-
239
-
240
- def _try_load_ckpt(path: Path, model, optimizer, device):
241
- """Attempt to load a single ckpt. Returns the tuple on success, None on any failure."""
242
- if not path.exists():
243
- return None
244
- ckpt = torch.load(str(path), map_location=device, weights_only=False)
245
- state = ckpt.get("model_state_dict", ckpt)
246
- missing, unexpected = model.load_state_dict(state, strict=False)
247
- if missing:
248
- print(f"[ckpt] {path.name} missing={len(missing)}", flush=True)
249
- if unexpected:
250
- print(f"[ckpt] {path.name} unexpected={len(unexpected)}", flush=True)
251
- optimizer_state = ckpt.get("optimizer_state_dict")
252
- if optimizer_state is not None:
253
- try:
254
- optimizer.load_state_dict(optimizer_state)
255
- except Exception as e:
256
- print(f"[ckpt] optimizer restore failed from {path.name}: {type(e).__name__}: {e}", flush=True)
257
- step = int(ckpt.get("step", 0))
258
- total_training_time = float(ckpt.get("train_seconds", 0.0))
259
- smooth_train_loss = float(ckpt.get("smoothed_loss", 0.0))
260
- bpt_ema = float(ckpt.get("bpt_ema", 0.0))
261
- epoch = int(ckpt.get("epoch", 0))
262
- print(
263
- f"[ckpt] resumed {path} step={step} train_seconds={total_training_time:.1f}",
264
- flush=True,
265
- )
266
- # Warn if resuming a schedule-exhausted ckpt — user is probably warm-starting.
267
- budget = float(os.environ.get("HYDRA_TIME_BUDGET", "0") or 0)
268
- if budget and total_training_time >= 0.99 * budget:
269
- print(
270
- f"[ckpt] WARNING: resumed ckpt used {total_training_time:.0f}s of {budget:.0f}s "
271
- f"budget. LR schedule is essentially exhausted. "
272
- f"Set HYDRA_WARMSTART=1 to reset optimizer + scheduler and keep only weights.",
273
- flush=True,
274
- )
275
- return step, total_training_time, smooth_train_loss, bpt_ema, epoch
276
-
277
-
278
- def maybe_resume_ckpt(
279
- model: PostSemClawModel,
280
- optimizer: torch.optim.Optimizer,
281
- device: torch.device,
282
- ) -> tuple[int, float, float, float, int]:
283
- if not RESUME_CKPT or RESUME_CKPT.lower() == "none":
284
- print("[ckpt] resume disabled; starting fresh", flush=True)
285
- return 0, 0.0, 0.0, 0.0, 0
286
-
287
- resume_path = Path(os.path.expanduser(RESUME_CKPT))
288
- # Try the primary path, then rotated backups. This is crucial because a
289
- # partial / killed torch.save on the primary path would leave a corrupt
290
- # file. If that fails we fall back to latest.pt.1, .2, .3 automatically.
291
- candidates: list[Path] = [resume_path]
292
- for i in range(1, CKPT_ROTATIONS + 1):
293
- candidates.append(Path(str(resume_path) + f".{i}"))
294
-
295
- for cand in candidates:
296
- if not cand.exists():
297
- continue
298
- try:
299
- result = _try_load_ckpt(cand, model, optimizer, device)
300
- if result is not None:
301
- if cand != resume_path:
302
- print(f"[ckpt] fell back to rotation {cand.name}", flush=True)
303
- return result
304
- except Exception as e:
305
- print(f"[ckpt] {cand.name} load failed: {type(e).__name__}: {e}", flush=True)
306
- continue
307
-
308
- print(f"[ckpt] no usable checkpoint in {resume_path} + rotations; starting fresh", flush=True)
309
- return 0, 0.0, 0.0, 0.0, 0
310
-
311
-
312
- # ---------------------------------------------------------------------------
313
- # Main entry
314
- # ---------------------------------------------------------------------------
315
-
316
- def main() -> None:
317
- t_start = time.time()
318
- torch.manual_seed(SEED)
319
- torch.cuda.manual_seed(SEED)
320
- # Precision / kernel-selection knobs for peak throughput on Ampere.
321
- # - high : matmul uses TF32 (Ampere's 10-bit mantissa accum) for fp32 ops
322
- # - allow_tf32 : explicit for both matmul + cudnn paths
323
- # - cudnn.benchmark : env-gated (HYDRA_CUDNN_BENCHMARK, default OFF).
324
- # TRUE can lock in a locally-better-but-globally-slower algorithm
325
- # after the autotune phase ends, causing tps to degrade 15-20%
326
- # over the first ~100 steps. Observed 2026-04-22 and confirmed by
327
- # differential profiling. Default is now FALSE; set =1 only if you
328
- # see a specific workload where benchmark helps sustained tps.
329
- torch.set_float32_matmul_precision("high")
330
- torch.backends.cuda.matmul.allow_tf32 = True
331
- torch.backends.cudnn.allow_tf32 = True
332
- torch.backends.cudnn.benchmark = os.environ.get("HYDRA_CUDNN_BENCHMARK", "0") == "1"
333
- device = torch.device("cuda")
334
- autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
335
-
336
- # Streaming path skips prepare.py (which normally trains the tokenizer
337
- # and builds the retina), so we must materialize both before model init.
 
 
 
 
 
 
338
  if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
339
  _p_nemo.ensure_tokenizer()
340
- if os.environ.get("HYDRA_THROUGHPUT_MODE", "0") != "1":
341
- # Retina: HF Hub cache hit for this (vocab, n_bits, target_active) combo
342
- # returns in seconds; otherwise build_retina streams Nemotron docs to
343
- # compute cooccurrence + train SOM, then uploads back to the cache.
344
- import subsystems.sdr_retina as _sdr_retina
345
- _sdr_retina.build_retina()
346
- tokenizer = Tokenizer.from_directory()
347
- vocab_size = tokenizer.get_vocab_size()
348
- print(f"Vocab size: {vocab_size:,}")
349
-
350
- config = PostSemClawConfig(
351
- sequence_len=MAX_SEQ_LEN,
352
- vocab_size=vocab_size,
353
- n_layer=N_LAYER,
354
- d_model=D_MODEL,
355
- d_state=D_STATE,
356
- headdim=HEADDIM,
357
- n_heads=N_HEADS,
358
- expand=EXPAND,
359
- engram_n_columns=ENGRAM_N_COLUMNS,
360
- engram_key_dim=ENGRAM_KEY_DIM,
361
- engram_layer_idx=ENGRAM_LAYER_IDX,
362
- )
363
- print(f"Model config: {asdict(config)}")
364
-
365
- with torch.device("meta"):
366
- model = PostSemClawModel(config)
367
- model.to_empty(device=device)
368
- model.init_weights()
369
-
370
- param_counts = model.num_scaling_params()
371
- print("Parameter counts:")
372
- for key, value in param_counts.items():
373
- print(f" {key:24s}: {value:,}")
374
- num_params = param_counts['total']
375
- num_flops_per_token = model.estimate_flops()
376
- print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
377
-
378
- tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
379
- assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
380
- grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
381
-
382
- optimizer = model.setup_optimizer(
383
- unembedding_lr=UNEMBEDDING_LR,
384
- embedding_lr=EMBEDDING_LR,
385
- scalar_lr=SCALAR_LR,
386
- adam_betas=ADAM_BETAS,
387
- matrix_lr=MATRIX_LR,
388
- weight_decay=WEIGHT_DECAY,
389
- )
390
-
391
- step, total_training_time, smooth_train_loss, bpt_ema, resume_epoch = maybe_resume_ckpt(
392
- model, optimizer, device,
393
- )
394
-
395
- # Learnability #4: inform the model of the BOS token id so it can mask
396
- # doc-separator positions in packed sequences. Always set (the mask only
397
- # fires when HYDRA_DOC_SEP_MASK=1 is also on).
398
- if hasattr(model, 'set_bos_token_id'):
399
- model.set_bos_token_id(tokenizer.get_bos_token_id())
400
-
401
- # Learnability #2: EMA shadow copy of weights. AveragedModel clones every
402
- # parameter; we update it after every optimizer step and save it at the
403
- # end alongside the raw checkpoint. Defaults OFF.
404
- ema_model = None
405
- if USE_EMA:
406
- try:
407
- from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
408
- # decay=EMA_DECAY; avg_fn uses get_ema_multi_avg_fn for numerical
409
- # stability across bf16/fp32 mixed parameter groups.
410
- ema_model = AveragedModel(
411
- model,
412
- multi_avg_fn=get_ema_multi_avg_fn(EMA_DECAY),
413
- )
414
- print(f"[EMA] enabled with decay={EMA_DECAY}")
415
- except Exception as _e:
416
- print(f"[EMA] disabled — AveragedModel init failed: {_e}")
417
- ema_model = None
418
-
419
- print("torch.compile: Muon step compiled; AdamW uses torch._fused_adamw_ (model blocks use native CUDA kernels)")
420
-
421
- # Learnability #7: curriculum short-then-long. If enabled, build the
422
- # initial dataloader at the short seq_len; we swap to full MAX_SEQ_LEN
423
- # after CURRICULUM_SHORT_STEPS optimizer steps (see loop below).
424
- _curriculum_active = CURRICULUM_SHORT_STEPS > 0 and CURRICULUM_SHORT_SEQ_LEN < MAX_SEQ_LEN
425
- _current_seq_len = CURRICULUM_SHORT_SEQ_LEN if _curriculum_active else MAX_SEQ_LEN
426
- if _curriculum_active:
427
- print(
428
- f"[CURRICULUM] starting at T={_current_seq_len} for "
429
- f"{CURRICULUM_SHORT_STEPS} steps, then switching to T={MAX_SEQ_LEN}"
430
- )
431
- train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, _current_seq_len, "train")
432
- x, y, epoch = next(train_loader) # prefetch first batch
433
- if resume_epoch > 0:
434
- epoch = max(epoch, resume_epoch)
435
-
436
- print(f"Time budget: {TIME_BUDGET}s")
437
- print(f"Gradient accumulation steps: {grad_accum_steps}")
438
-
439
- # Token→byte LUT for bits-per-byte computation. evaluate_bpb in prepare.py
440
- # uses total_nats / (ln(2) * total_bytes); our live metric needs to match.
441
- # Without this, `bpb = loss/ln(2)` is actually bits-per-TOKEN, which at
442
- # vocab=8192 scales by ~4 and makes live train bpb non-comparable with
443
- # val_bpb (champion 1.279 bpb vs train printing "8.04").
444
- token_bytes = get_token_bytes(device=device)
445
-
446
- # -----------------------------------------------------------------------
447
- # Training loop
448
- # -----------------------------------------------------------------------
449
-
450
- t_start_training = time.time()
451
-
452
- # Async postprocessing run SOM + Hestia on background threads so
453
- # the GPU doesn't idle during their CPU-bound work.
454
- _ASYNC_POSTPROCESS = os.environ.get("HYDRA_ASYNC_POSTPROCESS", "1") == "1"
455
- _som_thread: threading.Thread | None = None
456
- _hestia_thread: threading.Thread | None = None
457
- _hestia_stream: torch.cuda.Stream | None = (
458
- torch.cuda.Stream() if _ASYNC_POSTPROCESS else None
459
- )
460
-
461
- # HYDRA_PROFILE_STEPS=N prints a per-phase cpu/gpu time breakdown for the
462
- # first N steps (and every 100th step thereafter if N<0). Zero overhead
463
- # when disabled. Used to find what's eating CPU budget when GPU should
464
- # be the bottleneck.
465
- _profile_steps = int(os.environ.get("HYDRA_PROFILE_STEPS", "0"))
466
-
467
- while True:
468
- torch.cuda.synchronize()
469
- t0 = time.time()
470
- _prof = _profile_steps and (step < _profile_steps or (_profile_steps < 0 and step % 100 == 0))
471
- _gpu_ms = 0.0
472
- _data_ms = 0.0
473
- for micro_step in range(grad_accum_steps):
474
- if _prof:
475
- torch.cuda.synchronize(); _t_micro = time.time()
476
- if USE_MDLM:
477
- # MDLM path: corrupt y -> x_noised, run model to get full-V logits,
478
- # compute RB weighted CE on masked positions. x (original input) is
479
- # unused in this path the model only sees the noised version of y.
480
- _mask_id = MDLM_MASK_ID if MDLM_MASK_ID >= 0 else (vocab_size - 1)
481
- x_noised, mask_positions, loss_weights = mdlm_masked_forward_process(
482
- y, mask_token_id=_mask_id, alpha_schedule=MDLM_SCHEDULE,
483
- )
484
- with autocast_ctx:
485
- logits = model(x_noised) # targets=None -> (B, T, V) logits
486
- loss = mdlm_rb_loss(logits, y, mask_positions, loss_weights)
487
- else:
488
- with autocast_ctx:
489
- loss = model(x, y)
490
- train_loss = loss.detach()
491
- loss = loss / grad_accum_steps
492
- loss.backward()
493
- if _prof:
494
- torch.cuda.synchronize()
495
- _gpu_ms += (time.time() - _t_micro) * 1000
496
- _t_data = time.time()
497
- x, y, epoch = next(train_loader)
498
- if _prof:
499
- _data_ms += (time.time() - _t_data) * 1000
500
- if _prof:
501
- torch.cuda.synchronize(); _t_fb = time.time()
502
-
503
- # Progress and schedules
504
- progress = min(total_training_time / TIME_BUDGET, 1.0)
505
- lrm = get_lr_multiplier(progress)
506
- muon_momentum = get_muon_momentum(step)
507
- muon_weight_decay = get_weight_decay(progress)
508
- for group in optimizer.param_groups:
509
- group["lr"] = group["initial_lr"] * lrm
510
- if group['kind'] == 'muon':
511
- group["momentum"] = muon_momentum
512
- group["weight_decay"] = muon_weight_decay
513
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
514
- optimizer.step()
515
- if _prof:
516
- torch.cuda.synchronize(); _t_opt = time.time()
517
-
518
- # Learnability #2: EMA update after every optimizer step.
519
- if ema_model is not None:
520
- try:
521
- ema_model.update_parameters(model)
522
- except Exception as _e:
523
- print(f"[EMA] update failed at step {step}: {_e}", flush=True)
524
-
525
- # Learnability #7: curriculum transition. After
526
- # CURRICULUM_SHORT_STEPS optimizer steps, rebuild the dataloader at
527
- # MAX_SEQ_LEN. Done once, then the flag flips off.
528
- if _curriculum_active and step + 1 >= CURRICULUM_SHORT_STEPS:
529
- print(
530
- f"[CURRICULUM] step={step+1} — switching from T={_current_seq_len} "
531
- f"to T={MAX_SEQ_LEN}",
532
- flush=True,
533
- )
534
- _current_seq_len = MAX_SEQ_LEN
535
- _curriculum_active = False
536
- train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, _current_seq_len, "train")
537
- # Prefetch the next batch at the new seq_len so the following
538
- # loop iteration consumes fresh data.
539
- x, y, epoch = next(train_loader)
540
-
541
- # Online SOM update retina is now a plain Python attribute (not a
542
- # registered buffer) so mutations do not invalidate torch.compile guards.
543
- # Runs fully on CPU; safe to overlap with GPU forward pass.
544
- _last_sdr = getattr(model, "_last_sdr", None)
545
- if _last_sdr is not None:
546
- if _ASYNC_POSTPROCESS:
547
- if _som_thread is not None:
548
- _som_thread.join()
549
- # Clone tensors before next step overwrites them
550
- _som_x = x.clone()
551
- _som_sdr = _last_sdr.clone()
552
- _som_thread = threading.Thread(
553
- target=model.sdr_semantic.maybe_som_update,
554
- args=(_som_x, _som_sdr),
555
- daemon=True,
556
- )
557
- _som_thread.start()
558
- else:
559
- model.sdr_semantic.maybe_som_update(x, _last_sdr)
560
-
561
- # Hestia QAT anneal temperature every step, snap every N steps.
562
- # apply_to walks all Linear modules (CPU) then does .data.copy_ (GPU).
563
- # Background thread + separate CUDA stream lets this overlap with
564
- # the next forward pass on the default stream.
565
- _hestia_progress = (time.time() - t_start_training) / max(TIME_BUDGET, 1)
566
- _hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
567
- if step % _hestia_interval == 0:
568
- if _ASYNC_POSTPROCESS:
569
- if _hestia_thread is not None:
570
- _hestia_thread.join()
571
-
572
- def _hestia_bg(mdl: torch.nn.Module, prog: float) -> None:
573
- assert _hestia_stream is not None
574
- with torch.cuda.stream(_hestia_stream):
575
- mdl.hestia.anneal_temperature(prog)
576
- mdl.hestia.apply_to(mdl)
577
-
578
- _hestia_thread = threading.Thread(
579
- target=_hestia_bg,
580
- args=(model, _hestia_progress),
581
- daemon=True,
582
- )
583
- _hestia_thread.start()
584
- else:
585
- model.hestia.anneal_temperature(_hestia_progress)
586
- model.hestia.apply_to(model)
587
- else:
588
- # anneal_temperature is cheap (~1 us), keep inline
589
- model.hestia.anneal_temperature(_hestia_progress)
590
-
591
- model.zero_grad(set_to_none=True)
592
-
593
- train_loss_f = train_loss.item()
594
- if math.isnan(train_loss_f) or train_loss_f > 100:
595
- print("FAIL")
596
- # Save to a DIFFERENT file never clobber a good latest.pt with
597
- # a NaN/diverged state. The good ckpt from the last periodic save
598
- # is the right place to resume from.
599
- save_ckpt(
600
- model,
601
- optimizer,
602
- config,
603
- step,
604
- total_training_time,
605
- smooth_train_loss,
606
- bpt_ema,
607
- epoch,
608
- FAILED_CKPT,
609
- blocking=True,
610
- )
611
- raise SystemExit(1)
612
-
613
- torch.cuda.synchronize()
614
- t1 = time.time()
615
- dt = t1 - t0
616
-
617
- if _prof:
618
- fb = (_t_fb - t0) * 1000
619
- opt = (_t_opt - _t_fb) * 1000
620
- rest = (t1 - _t_opt) * 1000
621
- print(
622
- f"[PROF step={step:05d}] gpu={_gpu_ms:.0f}ms data_fetch={_data_ms:.0f}ms "
623
- f"(sum_fb={fb:.0f}) opt={opt:.0f}ms rest={rest:.0f}ms total={dt*1000:.0f}ms",
624
- flush=True,
625
- )
626
-
627
- if step > 10:
628
- total_training_time += dt
629
-
630
- ema_beta = 0.9
631
- smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
632
- debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1))
633
- pct_done = 100 * progress
634
- tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
635
- mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / GPU_BF16_PEAK_FLOPS
636
- remaining = max(0, TIME_BUDGET - total_training_time)
637
-
638
- # Bytes-per-token for the CURRENT batch. evaluate_bpb in prepare.py
639
- # computes bits-per-BYTE (total_nats / (ln2 * total_bytes)); to match
640
- # that semantics live, we EMA-smooth the per-batch bytes/token and
641
- # divide. Without this, the old `bpb = loss/ln2` was actually
642
- # bits-per-token ~4× larger than val_bpb at vocab=8192 and
643
- # therefore not comparable to the champion 1.279 bpb metric.
644
- with torch.no_grad():
645
- y_flat = y.view(-1)
646
- nbytes_batch = token_bytes[y_flat]
647
- mask = nbytes_batch > 0
648
- mask_count = mask.sum().clamp(min=1).float()
649
- avg_bytes_per_tok = (nbytes_batch.float() * mask.float()).sum() / mask_count
650
- bpt_batch = float(avg_bytes_per_tok.item())
651
- if step == 0 or bpt_ema <= 0.0:
652
- bpt_ema = bpt_batch
653
- else:
654
- bpt_ema = 0.98 * bpt_ema + 0.02 * bpt_batch
655
-
656
- # Dual metric: bpb (byte-normalized, comparable with val_bpb) AND
657
- # bpt (bits per token, the raw loss in bits). bpt_div exposes the
658
- # current avg bytes-per-token so the conversion is transparent.
659
- bpt = debiased_smooth_loss / math.log(2)
660
- bpb = bpt / max(bpt_ema, 1e-6)
661
- vram_mib = torch.cuda.memory_allocated() / 1024 / 1024
662
- current_lr = optimizer.param_groups[0]["lr"]
663
-
664
- # Per-step line-buffered log. NOT \r-overwritten so tee/grep see it.
665
- # Keep key=value pairs grep-friendly.
666
- ppl = 2.0 ** bpb # perplexity (byte-level)
667
- print(
668
- f"step={step:05d} loss={debiased_smooth_loss:.4f} bpb={bpb:.4f} ppl={ppl:.3f} "
669
- f"bpt={bpt:.3f} bpt_div={bpt_ema:.2f} "
670
- f"tps={tok_per_sec} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
671
- f"lr={current_lr:.2e} vram={vram_mib:.0f}MiB "
672
- f"pct={pct_done:.1f} epoch={epoch} remaining={remaining:.0f}s",
673
- flush=True,
674
- )
675
-
676
- if step == 0:
677
- gc.collect()
678
- gc.freeze()
679
- gc.disable()
680
- # No periodic gc.collect() we disabled+froze at step 0 on purpose,
681
- # so a manual collect every 5k steps just re-scans frozen objects
682
- # (burned ~900 ms/event in production) for no live-garbage reason.
683
-
684
- if CKPT_INTERVAL > 0 and step > 0 and step % CKPT_INTERVAL == 0:
685
- save_ckpt(
686
- model,
687
- optimizer,
688
- config,
689
- step,
690
- total_training_time,
691
- smooth_train_loss,
692
- bpt_ema,
693
- epoch,
694
- LATEST_CKPT,
695
- )
696
-
697
- # Periodic mid-training validation so we can see the model learning
698
- # English in real time (not just at the end). Small val batch so it
699
- # doesn't eat significant training time.
700
- mid_val_interval = int(os.environ.get("HYDRA_MID_VAL_INTERVAL", "500"))
701
- if mid_val_interval > 0 and step > 0 and step % mid_val_interval == 0:
702
- model.eval()
703
- try:
704
- # Defrag GPU memory before eval allocates fresh chunks
705
- # without this the eval path can OOM on 6GB cards even
706
- # though total usage fits, because the allocator's free
707
- # blocks are fragmented.
708
- torch.cuda.empty_cache()
709
- _orig_mid = _prepare_mod.EVAL_TOKENS
710
- _prepare_mod.EVAL_TOKENS = 262144 # ~260K tokens, fast
711
- with torch.no_grad():
712
- with autocast_ctx:
713
- mid_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
714
- _prepare_mod.EVAL_TOKENS = _orig_mid
715
- mid_ppl = 2.0 ** mid_bpb
716
- print(f"[MID_VAL] step={step} val_bpb={mid_bpb:.4f} val_ppl={mid_ppl:.3f}", flush=True)
717
-
718
- # Per-layer diagnostic panel. Only printed when HYDRA_LAYER_DIAGNOSTICS=1
719
- # is set (otherwise the layer_* keys are absent from _metrics).
720
- _diag_metrics = model.get_secondary_metrics()
721
- _layer_keys = sorted([k for k in _diag_metrics.keys() if k.startswith('layer_')])
722
- if _layer_keys:
723
- # Condense: one row per layer showing the four core signals.
724
- n_layers = len(model.blocks)
725
- print(f"[LAYER_DIAG] step={step}", flush=True)
726
- for li in range(n_layers):
727
- d_ratio = _diag_metrics.get(f'layer_{li}_delta_ratio', float('nan'))
728
- out_n = _diag_metrics.get(f'layer_{li}_out_norm', float('nan'))
729
- g_norm = _diag_metrics.get(f'layer_{li}_grad_norm', float('nan'))
730
- eff_r = _diag_metrics.get(f'layer_{li}_eff_rank', float('nan'))
731
- f_std = _diag_metrics.get(f'layer_{li}_feat_std', float('nan'))
732
- print(
733
- f"[LAYER_DIAG] L{li:02d} delta_ratio={d_ratio:.4f} "
734
- f"out_norm={out_n:.4f} grad_norm={g_norm:.3e} "
735
- f"eff_rank={eff_r:.1f} feat_std={f_std:.4f}",
736
- flush=True,
737
- )
738
- htm_proj_g = _diag_metrics.get('htm_proj_grad_norm', None)
739
- if htm_proj_g is not None:
740
- print(f"[LAYER_DIAG] htm_proj grad_norm={htm_proj_g:.3e}", flush=True)
741
- except Exception as e:
742
- print(f"[MID_VAL] failed: {e}", flush=True)
743
- model.train()
744
-
745
- step += 1
746
-
747
- if step > 10 and total_training_time >= TIME_BUDGET:
748
- break
749
-
750
- # Drain async postprocessing threads before eval
751
- if _som_thread is not None:
752
- _som_thread.join()
753
- if _hestia_thread is not None:
754
- _hestia_thread.join()
755
- if _hestia_stream is not None:
756
- _hestia_stream.synchronize()
757
-
758
- total_tokens = step * TOTAL_BATCH_SIZE
759
-
760
- # ----------------------------------------------------------------------
761
- # SAVE ORDER (critical):
762
- # 1. Save PRETRAIN_FINAL_CKPT with val_bpb=None (hedge against eval OOM)
763
- # 2. Save LATEST_CKPT with val_bpb=None (hedge against eval OOM)
764
- # 3. Run eval (may OOM on small GPUs; we survive it)
765
- # 4. Re-save both ckpts with val_bpb filled in
766
- # This way we NEVER lose the final trained weights to an eval crash.
767
- # Previous ordering put eval first, so an eval-time OOM destroyed the
768
- # only record of a 6h training run (2026-04-22 incident).
769
- # ----------------------------------------------------------------------
770
-
771
- save_ckpt(
772
- model, optimizer, config, step, total_training_time,
773
- smooth_train_loss, bpt_ema, epoch, PRETRAIN_FINAL_CKPT,
774
- val_bpb=None, blocking=True,
775
- )
776
- save_ckpt(
777
- model, optimizer, config, step, total_training_time,
778
- smooth_train_loss, bpt_ema, epoch, LATEST_CKPT,
779
- val_bpb=None, blocking=True,
780
- )
781
-
782
- # Now it's safe to eval — ckpts are on disk regardless of what happens here.
783
- # HYDRA_EVAL_BATCH overrides DEVICE_BATCH_SIZE (env-tunable; default halves
784
- # the training batch because eval holds activations for full sequence and
785
- # does not benefit from overlap with backward). HYDRA_EVAL_TOKENS controls
786
- # how many val tokens to sweep (default 2 M, short enough for autoresearch
787
- # 5-min budgets).
788
- val_bpb: float | None = None
789
- _eval_B = int(os.environ.get("HYDRA_EVAL_BATCH", str(max(1, DEVICE_BATCH_SIZE // 2))))
790
- _eval_tokens = int(os.environ.get("HYDRA_EVAL_TOKENS", str(2 * 524288)))
791
- try:
792
- # Aggressive VRAM reclaim for 6GB cards. Peak training VRAM = 5.1GB
793
- # which leaves < 1GB for the eval forward the driver can't satisfy
794
- # the allocation. Free EVERY tensor we don't strictly need:
795
- # - optimizer grads (set_to_none releases tensor)
796
- # - optimizer.state (fp32 Muon NS workspace, AdamW moments ~size-of-params each)
797
- # - model internal caches (HTM subsample cache, SDR stash)
798
- # After this, VRAM should be ~params only (bf16 ≈ 120MB at 60M params).
799
- optimizer.zero_grad(set_to_none=True)
800
- if hasattr(optimizer, 'state') and optimizer.state:
801
- for p, st in list(optimizer.state.items()):
802
- st.clear()
803
- optimizer.state.clear()
804
- for p in model.parameters():
805
- if p.grad is not None:
806
- p.grad = None
807
- if hasattr(model, '_htm_cache'):
808
- model._htm_cache = None
809
- if hasattr(model, '_last_sdr'):
810
- model._last_sdr = None
811
- import gc as _gc
812
- _gc.collect()
813
- torch.cuda.empty_cache()
814
- torch.cuda.synchronize()
815
- try:
816
- _free_mb = torch.cuda.mem_get_info()[0] / 1024 / 1024
817
- print(f"[VAL] free_vram_mb={_free_mb:.0f} (cleared optimizer state)", flush=True)
818
- except Exception:
819
- pass
820
- print(f"[VAL] running eval on {_eval_tokens} tokens at B={_eval_B}...", flush=True)
821
- model.eval()
822
- _orig = _prepare_mod.EVAL_TOKENS
823
- _prepare_mod.EVAL_TOKENS = _eval_tokens
824
- with autocast_ctx:
825
- val_bpb = evaluate_bpb(model, tokenizer, _eval_B)
826
- _prepare_mod.EVAL_TOKENS = _orig
827
- val_ppl = 2 ** val_bpb
828
- print(f"[VAL] step={step} val_bpb={val_bpb:.4f} val_ppl={val_ppl:.3f}", flush=True)
829
- except torch.cuda.OutOfMemoryError as e:
830
- print(f"[VAL] SKIPPED (OOM): {e}", flush=True)
831
- torch.cuda.empty_cache()
832
- except Exception as e:
833
- import traceback as _tb
834
- print(f"[VAL] SKIPPED ({type(e).__name__}): {e}", flush=True)
835
- _tb.print_exc()
836
- try:
837
- _free = torch.cuda.mem_get_info()[0] / 1024 / 1024
838
- print(f"[VAL] post-crash free_vram_mb={_free:.0f}", flush=True)
839
- except Exception:
840
- pass
841
-
842
- # Final ckpts with val_bpb filled in (if eval succeeded).
843
- save_ckpt(
844
- model, optimizer, config, step, total_training_time,
845
- smooth_train_loss, bpt_ema, epoch, LATEST_CKPT,
846
- val_bpb=val_bpb, blocking=True,
847
- )
848
- save_ckpt(
849
- model, optimizer, config, step, total_training_time,
850
- smooth_train_loss, bpt_ema, epoch, PRETRAIN_FINAL_CKPT,
851
- val_bpb=val_bpb, blocking=True,
852
- )
853
-
854
- # Learnability #2: persist EMA weights alongside the raw checkpoint.
855
- # latest_ema.pt contains ema_model.module (the Averaged params) so it
856
- # can be loaded by evaluation / inference code that expects the same
857
- # state_dict shape as the raw model.
858
- if ema_model is not None:
859
- try:
860
- ema_ckpt_path = CACHE_DIR / "latest_ema.pt"
861
- CACHE_DIR.mkdir(parents=True, exist_ok=True)
862
- torch.save({
863
- "model_state_dict": ema_model.module.state_dict(),
864
- "config": asdict(config),
865
- "step": step,
866
- "epoch": epoch,
867
- "train_seconds": total_training_time,
868
- "val_bpb": val_bpb,
869
- "ema_decay": EMA_DECAY,
870
- }, str(ema_ckpt_path))
871
- print(f"[EMA] saved {ema_ckpt_path} (step={step})", flush=True)
872
- except Exception as _e:
873
- print(f"[EMA] save failed: {_e}", flush=True)
874
-
875
- run_factual_probes(model, tokenizer, device, autocast_ctx)
876
-
877
- t_end = time.time()
878
- startup_time = t_start_training - t_start
879
- steady_state_mfu = (
880
- 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10)
881
- / total_training_time / GPU_BF16_PEAK_FLOPS
882
- if total_training_time > 0 else 0
883
- )
884
- peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
885
- metrics = model.get_secondary_metrics()
886
-
887
- print("---")
888
- print(f"val_bpb: {val_bpb:.6f}" if val_bpb is not None else "val_bpb: SKIPPED")
889
- print(f"training_seconds: {total_training_time:.1f}")
890
- print(f"total_seconds: {t_end - t_start:.1f}")
891
- print(f"peak_vram_mb: {peak_vram_mb:.1f}")
892
- print(f"mfu_percent: {steady_state_mfu:.2f}")
893
- print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
894
- print(f"num_steps: {step}")
895
- print(f"num_params_M: {num_params / 1e6:.1f}")
896
- print(f"n_layer: {N_LAYER}")
897
- print(f"d_model: {D_MODEL}")
898
- print(f"engram_hit_rate: {metrics.get('engram_hit_rate', 0.0):.4f}")
899
- print(f"sdr_active_bits: {metrics.get('sdr_active_bits', 0):.1f}")
900
- print(f"htm_anomaly: {metrics.get('htm_anomaly', 0):.4f}")
901
-
902
- # Per-layer summary panel — only printed when diagnostics were active.
903
- _layer_keys = sorted([k for k in metrics.keys() if k.startswith('layer_')])
904
- if _layer_keys:
905
- n_layers = len(model.blocks)
906
- print("--- per-layer diagnostic panel ---")
907
- for li in range(n_layers):
908
- d_ratio = metrics.get(f'layer_{li}_delta_ratio', float('nan'))
909
- out_n = metrics.get(f'layer_{li}_out_norm', float('nan'))
910
- g_norm = metrics.get(f'layer_{li}_grad_norm', float('nan'))
911
- eff_r = metrics.get(f'layer_{li}_eff_rank', float('nan'))
912
- f_std = metrics.get(f'layer_{li}_feat_std', float('nan'))
913
- print(
914
- f"L{li:02d} delta_ratio={d_ratio:.4f} out_norm={out_n:.4f} "
915
- f"grad_norm={g_norm:.3e} eff_rank={eff_r:.1f} feat_std={f_std:.4f}"
916
- )
917
-
918
- # Emit full metrics dictionary as JSON for sweep aggregation. Path from
919
- # HYDRA_METRICS_OUT env var; default=/tmp/hydra_run_metrics.json. Always
920
- # written (even without diagnostics) so the aggregator can compare runs.
921
- _metrics_out = os.environ.get("HYDRA_METRICS_OUT", "/tmp/hydra_run_metrics.json")
922
- try:
923
- _dump = dict(metrics)
924
- _dump.update({
925
- 'val_bpb': float(val_bpb),
926
- 'val_ppl': float(val_ppl),
927
- 'n_layer': int(N_LAYER),
928
- 'd_model': int(D_MODEL),
929
- 'num_params_M': float(num_params / 1e6),
930
- 'num_steps': int(step),
931
- 'total_tokens_M': float(total_tokens / 1e6),
932
- 'peak_vram_mb': float(peak_vram_mb),
933
- 'training_seconds': float(total_training_time),
934
- 'sdr_target_active': int(os.environ.get("HYDRA_SDR_TARGET_ACTIVE", "327")),
935
- })
936
- Path(_metrics_out).parent.mkdir(parents=True, exist_ok=True)
937
- with open(_metrics_out, 'w') as _f:
938
- json.dump(_dump, _f, indent=2, sort_keys=True)
939
- print(f"[METRICS] wrote {_metrics_out}", flush=True)
940
- # Also emit a single-line JSON to stdout so the sweep aggregator can
941
- # scrape it from HF Jobs logs without pulling files out of the container.
942
- print("[METRICS_JSON] " + json.dumps(_dump, sort_keys=True), flush=True)
943
- except Exception as _e:
944
- print(f"[METRICS] write failed: {_e}", flush=True)
945
-
946
- run_factual_english(model, tokenizer, MAX_SEQ_LEN)
947
- # startup_time is informative but not printed (preserve historical output)
948
- _ = startup_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HYDRA training entry: setup, train loop, eval, summary.
2
+
3
+ Extracted from the monolithic train.py (W1 modularization). Semantics
4
+ preserved. Public entrypoint: `main()`.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import gc
10
+ import json
11
+ import math
12
+ import os
13
+ import sys
14
+ import threading
15
+ import time
16
+ from dataclasses import asdict
17
+ from pathlib import Path
18
+
19
+ import torch
20
+
21
+ # Line-buffered stdout so `python -u train.py | tee run.log | grep step` is
22
+ # live (no \r overwrite, no 4k block-buffered pipe stalls). Safe on Python
23
+ # 3.7+ where io.TextIOWrapper.reconfigure exists.
24
+ try:
25
+ sys.stdout.reconfigure(line_buffering=True) # type: ignore[attr-defined]
26
+ except Exception:
27
+ pass
28
+
29
+ from hydra.config import (
30
+ ADAM_BETAS, CURRICULUM_SHORT_SEQ_LEN, CURRICULUM_SHORT_STEPS,
31
+ D_MODEL, D_STATE, DEVICE_BATCH_SIZE, EMA_DECAY, EMBEDDING_LR,
32
+ ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND,
33
+ FINAL_LR_FRAC, GPU_BF16_PEAK_FLOPS, HEADDIM, MATRIX_LR, N_HEADS,
34
+ N_LAYER, PostSemClawConfig, SCALAR_LR, SEED, TOTAL_BATCH_SIZE,
35
+ UNEMBEDDING_LR, USE_EMA, WARMUP_RATIO, WEIGHT_DECAY,
36
+ )
37
+ from hydra.diffusion_loss import mdlm_masked_forward_process, mdlm_rb_loss
38
+ from hydra.eval import run_factual_english, run_factual_probes
39
+ from hydra.model import PostSemClawModel
40
+
41
+ import prepare as _prepare_mod
42
+ from prepare import MAX_SEQ_LEN, TIME_BUDGET as _TIME_BUDGET, Tokenizer, evaluate_bpb as _evaluate_bpb_shards, get_token_bytes, make_dataloader as _make_dataloader_shards
43
+
44
+ # Streaming Nemotron path (Super3 recipe). Opt-in via HYDRA_USE_NEMOTRON=1.
45
+ if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
46
+ import prepare_nemotron as _p_nemo
47
+ make_dataloader = _p_nemo.make_dataloader
48
+ evaluate_bpb = _p_nemo.evaluate_bpb
49
+ else:
50
+ make_dataloader = _make_dataloader_shards
51
+ evaluate_bpb = _evaluate_bpb_shards
52
+
53
+ TIME_BUDGET = int(os.environ.get("HYDRA_TIME_BUDGET", str(_TIME_BUDGET)))
54
+ _prepare_mod.TIME_BUDGET = TIME_BUDGET # sync for evaluate_bpb
55
+
56
+ CACHE_DIR = Path.home() / ".cache" / "autoresearch"
57
+ LATEST_CKPT = CACHE_DIR / "latest.pt"
58
+ PRETRAIN_FINAL_CKPT = CACHE_DIR / "pretrain_final.pt"
59
+ FAILED_CKPT = CACHE_DIR / "latest_failed.pt" # crash/FAIL path — never overwrites good
60
+ BEST_CKPT = CACHE_DIR / "best_bpb.pt" # lowest val_bpb seen
61
+ CKPT_INTERVAL = int(os.environ.get("HYDRA_CKPT_INTERVAL", "250"))
62
+ CKPT_ROTATIONS = int(os.environ.get("HYDRA_CKPT_ROTATIONS", "3")) # how many .N backups to keep
63
+ RESUME_CKPT = os.environ.get("HYDRA_RESUME_CKPT", str(LATEST_CKPT))
64
+
65
+ # MDLM (Masked Diffusion LM) Rao-Blackwellized ELBO loss path.
66
+ # HYDRA_USE_MDLM=1 : switch training loss from AR sampled-softmax CE
67
+ # to MDLM RB weighted CE (arXiv:2406.07524).
68
+ # HYDRA_MDLM_MASK_ID=N : token id used for the MASK sentinel (default:
69
+ # last valid id, vocab_size - 1). Ensure this id
70
+ # never appears in training targets — typical
71
+ # practice is to reserve it.
72
+ # HYDRA_MDLM_SCHEDULE=loglinear|linear : noise schedule (default loglinear).
73
+ # When enabled, the per-step flow is:
74
+ # 1. mdlm_masked_forward_process(y) -> (x_noised, mask_positions, weights)
75
+ # 2. logits = model(x_noised) (no targets -> full V logits)
76
+ # 3. loss = mdlm_rb_loss(logits, y, mask_positions, weights)
77
+ # Sampled-softmax is bypassed in this path because the RB ELBO needs
78
+ # full-vocab logits on masked positions.
79
+ USE_MDLM = os.environ.get("HYDRA_USE_MDLM", "0") == "1"
80
+ MDLM_MASK_ID = int(os.environ.get("HYDRA_MDLM_MASK_ID", "-1")) # -1 => default to vocab_size-1 at runtime
81
+ MDLM_SCHEDULE = os.environ.get("HYDRA_MDLM_SCHEDULE", "loglinear")
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Schedules
86
+ # ---------------------------------------------------------------------------
87
+
88
+ def get_lr_multiplier(progress: float) -> float:
89
+ if progress < WARMUP_RATIO:
90
+ return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
91
+ decay_progress = (progress - WARMUP_RATIO) / (1.0 - WARMUP_RATIO)
92
+ return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * (1 + math.cos(math.pi * decay_progress))
93
+
94
+
95
+ def get_muon_momentum(step: int) -> float:
96
+ frac = min(step / 300, 1)
97
+ return (1 - frac) * 0.85 + frac * 0.95
98
+
99
+
100
+ def get_weight_decay(progress: float) -> float:
101
+ return WEIGHT_DECAY * (1 - progress)
102
+
103
+
104
+ _CKPT_WORKER_THREAD: threading.Thread | None = None
105
+
106
+
107
+ def _ckpt_snapshot_state_dicts(
108
+ model: PostSemClawModel,
109
+ optimizer: torch.optim.Optimizer,
110
+ ) -> tuple[dict, dict]:
111
+ """Detach + CPU-clone every tensor so a bg thread can serialize safely
112
+ while the main loop keeps mutating live weights/optimizer state."""
113
+ msd = {k: (v.detach().to("cpu", copy=True) if torch.is_tensor(v) else v)
114
+ for k, v in model.state_dict().items()}
115
+ # optimizer.state_dict() is a nested dict; walk it.
116
+ osd_raw = optimizer.state_dict()
117
+
118
+ def _to_cpu(obj):
119
+ if torch.is_tensor(obj):
120
+ return obj.detach().to("cpu", copy=True)
121
+ if isinstance(obj, dict):
122
+ return {k: _to_cpu(v) for k, v in obj.items()}
123
+ if isinstance(obj, list):
124
+ return [_to_cpu(v) for v in obj]
125
+ if isinstance(obj, tuple):
126
+ return tuple(_to_cpu(v) for v in obj)
127
+ return obj
128
+
129
+ osd = _to_cpu(osd_raw)
130
+ return msd, osd
131
+
132
+
133
+ def save_ckpt(
134
+ model: PostSemClawModel,
135
+ optimizer: torch.optim.Optimizer,
136
+ config: PostSemClawConfig,
137
+ step: int,
138
+ total_training_time: float,
139
+ smooth_train_loss: float,
140
+ bpt_ema: float,
141
+ epoch: int,
142
+ path: Path,
143
+ *,
144
+ val_bpb: float | None = None,
145
+ blocking: bool = False,
146
+ ) -> None:
147
+ """Save a training checkpoint.
148
+
149
+ Default behavior is async: the GPU→CPU state_dict clone runs on the main
150
+ thread (unavoidable; needs to happen before the next optimizer.step that
151
+ mutates live weights), then `torch.save` is dispatched to a daemon
152
+ worker thread. The next call joins any still-running prior save so only
153
+ one disk write is in flight.
154
+
155
+ `blocking=True` restores the original synchronous behavior — used for
156
+ end-of-training saves where correctness on process exit matters.
157
+ """
158
+ global _CKPT_WORKER_THREAD
159
+ try:
160
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
161
+ msd, osd = _ckpt_snapshot_state_dicts(model, optimizer)
162
+ # asdict() recursively converts dataclass fields to a dict and
163
+ # renders tuples as lists. hyena_layers therefore round-trips as a
164
+ # JSON-safe list; config_from_dict normalizes it back to a tuple.
165
+ payload = {
166
+ "model_state_dict": msd,
167
+ "optimizer_state_dict": osd,
168
+ "config": asdict(config),
169
+ "step": step,
170
+ "epoch": epoch,
171
+ "train_seconds": total_training_time,
172
+ "smoothed_loss": smooth_train_loss,
173
+ "bpt_ema": bpt_ema,
174
+ "val_bpb": val_bpb,
175
+ }
176
+ path_str = str(path)
177
+
178
+ def _rotate(p: str) -> None:
179
+ """Keep up to CKPT_ROTATIONS previous versions as p.1, p.2, ..."""
180
+ if CKPT_ROTATIONS <= 0:
181
+ return
182
+ try:
183
+ # Walk from oldest to newest so we don't clobber newer with older.
184
+ for i in range(CKPT_ROTATIONS, 0, -1):
185
+ src = f"{p}.{i-1}" if i > 1 else p
186
+ dst = f"{p}.{i}"
187
+ if os.path.exists(src):
188
+ os.replace(src, dst)
189
+ except Exception as e:
190
+ # Rotation is best-effort; never block a save on it.
191
+ print(f"[ckpt] rotate warn {p}: {type(e).__name__}: {e}", flush=True)
192
+
193
+ def _write():
194
+ try:
195
+ _rotate(path_str)
196
+ tmp = path_str + ".tmp"
197
+ torch.save(payload, tmp)
198
+ os.replace(tmp, path_str)
199
+ print(f"[ckpt] saved {path_str} (step={step})", flush=True)
200
+ except Exception as e:
201
+ print(f"[ckpt] SAVE FAILED {path_str}: {type(e).__name__}: {e}", flush=True)
202
+
203
+ if blocking:
204
+ _write()
205
+ return
206
+
207
+ # Join previous writer so at most one torch.save runs at a time.
208
+ if _CKPT_WORKER_THREAD is not None and _CKPT_WORKER_THREAD.is_alive():
209
+ _CKPT_WORKER_THREAD.join()
210
+ _CKPT_WORKER_THREAD = threading.Thread(
211
+ target=_write, daemon=True, name=f"ckpt-save-{step}"
212
+ )
213
+ _CKPT_WORKER_THREAD.start()
214
+ # Non-default checkpoint paths are usually tests or one-off utilities that
215
+ # expect save_ckpt() to be durable when it returns. Keep the hot training
216
+ # path async for CACHE_DIR checkpoints, but make explicit custom paths
217
+ # deterministic.
218
+ if path.parent.resolve() != CACHE_DIR.resolve():
219
+ _CKPT_WORKER_THREAD.join()
220
+ except Exception as e:
221
+ print(f"[ckpt] SNAPSHOT FAILED {path}: {type(e).__name__}: {e}", flush=True)
222
+
223
+
224
+ def config_from_dict(cfg_dict: dict) -> PostSemClawConfig:
225
+ """Reconstruct a PostSemClawConfig from a checkpoint's asdict() payload.
226
+
227
+ Newly-added fields (e.g. `hyena_layers`) are defaulted when absent in
228
+ older checkpoints, and list-ified tuples are coerced back to tuples so
229
+ the dataclass keeps its declared types.
230
+
231
+ This is the ckpt-safe inverse of `asdict(config)` used by save_ckpt and
232
+ guarantees that a resume path can rebuild the exact same model topology
233
+ (Mamba3 vs HyenaBlock per layer) regardless of env-var state at resume.
234
+ """
235
+ # Only keep keys that are actually declared on PostSemClawConfig — extra
236
+ # keys in older/newer checkpoints must not crash construction.
237
+ field_names = {f.name for f in PostSemClawConfig.__dataclass_fields__.values()}
238
+ filtered = {k: v for k, v in cfg_dict.items() if k in field_names}
239
+ # asdict renders tuple[int,...] as list[int]; coerce back so the model
240
+ # builder sees the declared type.
241
+ if "hyena_layers" in filtered and filtered["hyena_layers"] is not None:
242
+ filtered["hyena_layers"] = tuple(sorted(int(x) for x in filtered["hyena_layers"]))
243
+ return PostSemClawConfig(**filtered)
244
+
245
+
246
+ def _try_load_ckpt(path: Path, model, optimizer, device):
247
+ """Attempt to load a single ckpt. Returns the tuple on success, None on any failure."""
248
+ if not path.exists():
249
+ return None
250
+ ckpt = torch.load(str(path), map_location=device, weights_only=False)
251
+ state = ckpt.get("model_state_dict", ckpt)
252
+ missing, unexpected = model.load_state_dict(state, strict=False)
253
+ if missing:
254
+ print(f"[ckpt] {path.name} missing={len(missing)}", flush=True)
255
+ if unexpected:
256
+ print(f"[ckpt] {path.name} unexpected={len(unexpected)}", flush=True)
257
+ optimizer_state = ckpt.get("optimizer_state_dict")
258
+ if optimizer_state is not None:
259
+ try:
260
+ optimizer.load_state_dict(optimizer_state)
261
+ except Exception as e:
262
+ print(f"[ckpt] optimizer restore failed from {path.name}: {type(e).__name__}: {e}", flush=True)
263
+ step = int(ckpt.get("step", 0))
264
+ total_training_time = float(ckpt.get("train_seconds", 0.0))
265
+ smooth_train_loss = float(ckpt.get("smoothed_loss", 0.0))
266
+ bpt_ema = float(ckpt.get("bpt_ema", 0.0))
267
+ epoch = int(ckpt.get("epoch", 0))
268
+ print(
269
+ f"[ckpt] resumed {path} step={step} train_seconds={total_training_time:.1f}",
270
+ flush=True,
271
+ )
272
+ # Warn if resuming a schedule-exhausted ckpt user is probably warm-starting.
273
+ budget = float(os.environ.get("HYDRA_TIME_BUDGET", "0") or 0)
274
+ if budget and total_training_time >= 0.99 * budget:
275
+ print(
276
+ f"[ckpt] WARNING: resumed ckpt used {total_training_time:.0f}s of {budget:.0f}s "
277
+ f"budget. LR schedule is essentially exhausted. "
278
+ f"Set HYDRA_WARMSTART=1 to reset optimizer + scheduler and keep only weights.",
279
+ flush=True,
280
+ )
281
+ return step, total_training_time, smooth_train_loss, bpt_ema, epoch
282
+
283
+
284
+ def maybe_resume_ckpt(
285
+ model: PostSemClawModel,
286
+ optimizer: torch.optim.Optimizer,
287
+ device: torch.device,
288
+ ) -> tuple[int, float, float, float, int]:
289
+ if not RESUME_CKPT or RESUME_CKPT.lower() == "none":
290
+ print("[ckpt] resume disabled; starting fresh", flush=True)
291
+ return 0, 0.0, 0.0, 0.0, 0
292
+
293
+ resume_path = Path(os.path.expanduser(RESUME_CKPT))
294
+ # Try the primary path, then rotated backups. This is crucial because a
295
+ # partial / killed torch.save on the primary path would leave a corrupt
296
+ # file. If that fails we fall back to latest.pt.1, .2, .3 automatically.
297
+ candidates: list[Path] = [resume_path]
298
+ for i in range(1, CKPT_ROTATIONS + 1):
299
+ candidates.append(Path(str(resume_path) + f".{i}"))
300
+
301
+ for cand in candidates:
302
+ if not cand.exists():
303
+ continue
304
+ try:
305
+ result = _try_load_ckpt(cand, model, optimizer, device)
306
+ if result is not None:
307
+ if cand != resume_path:
308
+ print(f"[ckpt] fell back to rotation {cand.name}", flush=True)
309
+ return result
310
+ except Exception as e:
311
+ print(f"[ckpt] {cand.name} load failed: {type(e).__name__}: {e}", flush=True)
312
+ continue
313
+
314
+ print(f"[ckpt] no usable checkpoint in {resume_path} + rotations; starting fresh", flush=True)
315
+ return 0, 0.0, 0.0, 0.0, 0
316
+
317
+
318
+ # ---------------------------------------------------------------------------
319
+ # Main entry
320
+ # ---------------------------------------------------------------------------
321
+
322
+ def main() -> None:
323
+ t_start = time.time()
324
+ torch.manual_seed(SEED)
325
+ torch.cuda.manual_seed(SEED)
326
+ # Precision / kernel-selection knobs for peak throughput on Ampere.
327
+ # - high : matmul uses TF32 (Ampere's 10-bit mantissa accum) for fp32 ops
328
+ # - allow_tf32 : explicit for both matmul + cudnn paths
329
+ # - cudnn.benchmark : env-gated (HYDRA_CUDNN_BENCHMARK, default OFF).
330
+ # TRUE can lock in a locally-better-but-globally-slower algorithm
331
+ # after the autotune phase ends, causing tps to degrade 15-20%
332
+ # over the first ~100 steps. Observed 2026-04-22 and confirmed by
333
+ # differential profiling. Default is now FALSE; set =1 only if you
334
+ # see a specific workload where benchmark helps sustained tps.
335
+ torch.set_float32_matmul_precision("high")
336
+ torch.backends.cuda.matmul.allow_tf32 = True
337
+ torch.backends.cudnn.allow_tf32 = True
338
+ torch.backends.cudnn.benchmark = os.environ.get("HYDRA_CUDNN_BENCHMARK", "0") == "1"
339
+ device = torch.device("cuda")
340
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
341
+
342
+ # Streaming path skips prepare.py (which normally trains the tokenizer
343
+ # and builds the retina), so we must materialize both before model init.
344
  if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
345
  _p_nemo.ensure_tokenizer()
346
+ # Retina: HF Hub cache hit for this (vocab, n_bits, target_active) combo
347
+ # returns in seconds; otherwise build_retina streams Nemotron docs to
348
+ # compute cooccurrence + train SOM, then uploads back to the cache.
349
+ import subsystems.sdr_retina as _sdr_retina
350
+ _sdr_retina.build_retina()
351
+ tokenizer = Tokenizer.from_directory()
352
+ vocab_size = tokenizer.get_vocab_size()
353
+ print(f"Vocab size: {vocab_size:,}")
354
+
355
+ config = PostSemClawConfig(
356
+ sequence_len=MAX_SEQ_LEN,
357
+ vocab_size=vocab_size,
358
+ n_layer=N_LAYER,
359
+ d_model=D_MODEL,
360
+ d_state=D_STATE,
361
+ headdim=HEADDIM,
362
+ n_heads=N_HEADS,
363
+ expand=EXPAND,
364
+ engram_n_columns=ENGRAM_N_COLUMNS,
365
+ engram_key_dim=ENGRAM_KEY_DIM,
366
+ engram_layer_idx=ENGRAM_LAYER_IDX,
367
+ )
368
+ print(f"Model config: {asdict(config)}")
369
+
370
+ with torch.device("meta"):
371
+ model = PostSemClawModel(config)
372
+ model.to_empty(device=device)
373
+ model.init_weights()
374
+
375
+ param_counts = model.num_scaling_params()
376
+ print("Parameter counts:")
377
+ for key, value in param_counts.items():
378
+ print(f" {key:24s}: {value:,}")
379
+ num_params = param_counts['total']
380
+ num_flops_per_token = model.estimate_flops()
381
+ print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
382
+
383
+ tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
384
+ assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
385
+ grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
386
+
387
+ optimizer = model.setup_optimizer(
388
+ unembedding_lr=UNEMBEDDING_LR,
389
+ embedding_lr=EMBEDDING_LR,
390
+ scalar_lr=SCALAR_LR,
391
+ adam_betas=ADAM_BETAS,
392
+ matrix_lr=MATRIX_LR,
393
+ weight_decay=WEIGHT_DECAY,
394
+ )
395
+
396
+ step, total_training_time, smooth_train_loss, bpt_ema, resume_epoch = maybe_resume_ckpt(
397
+ model, optimizer, device,
398
+ )
399
+
400
+ # Learnability #4: inform the model of the BOS token id so it can mask
401
+ # doc-separator positions in packed sequences. Always set (the mask only
402
+ # fires when HYDRA_DOC_SEP_MASK=1 is also on).
403
+ if hasattr(model, 'set_bos_token_id'):
404
+ model.set_bos_token_id(tokenizer.get_bos_token_id())
405
+
406
+ # Learnability #2: EMA shadow copy of weights. AveragedModel clones every
407
+ # parameter; we update it after every optimizer step and save it at the
408
+ # end alongside the raw checkpoint. Defaults OFF.
409
+ ema_model = None
410
+ if USE_EMA:
411
+ try:
412
+ from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
413
+ # decay=EMA_DECAY; avg_fn uses get_ema_multi_avg_fn for numerical
414
+ # stability across bf16/fp32 mixed parameter groups.
415
+ ema_model = AveragedModel(
416
+ model,
417
+ multi_avg_fn=get_ema_multi_avg_fn(EMA_DECAY),
418
+ )
419
+ print(f"[EMA] enabled with decay={EMA_DECAY}")
420
+ except Exception as _e:
421
+ print(f"[EMA] disabled AveragedModel init failed: {_e}")
422
+ ema_model = None
423
+
424
+ print("torch.compile: Muon step compiled; AdamW uses torch._fused_adamw_ (model blocks use native CUDA kernels)")
425
+
426
+ # Learnability #7: curriculum short-then-long. If enabled, build the
427
+ # initial dataloader at the short seq_len; we swap to full MAX_SEQ_LEN
428
+ # after CURRICULUM_SHORT_STEPS optimizer steps (see loop below).
429
+ _curriculum_active = CURRICULUM_SHORT_STEPS > 0 and CURRICULUM_SHORT_SEQ_LEN < MAX_SEQ_LEN
430
+ _current_seq_len = CURRICULUM_SHORT_SEQ_LEN if _curriculum_active else MAX_SEQ_LEN
431
+ if _curriculum_active:
432
+ print(
433
+ f"[CURRICULUM] starting at T={_current_seq_len} for "
434
+ f"{CURRICULUM_SHORT_STEPS} steps, then switching to T={MAX_SEQ_LEN}"
435
+ )
436
+ train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, _current_seq_len, "train")
437
+ x, y, epoch = next(train_loader) # prefetch first batch
438
+ if resume_epoch > 0:
439
+ epoch = max(epoch, resume_epoch)
440
+
441
+ print(f"Time budget: {TIME_BUDGET}s")
442
+ print(f"Gradient accumulation steps: {grad_accum_steps}")
443
+
444
+ # Token→byte LUT for bits-per-byte computation. evaluate_bpb in prepare.py
445
+ # uses total_nats / (ln(2) * total_bytes); our live metric needs to match.
446
+ # Without this, `bpb = loss/ln(2)` is actually bits-per-TOKEN, which at
447
+ # vocab=8192 scales by ~4 and makes live train bpb non-comparable with
448
+ # val_bpb (champion 1.279 bpb vs train printing "8.04").
449
+ token_bytes = get_token_bytes(device=device)
450
+
451
+ # -----------------------------------------------------------------------
452
+ # Training loop
453
+ # -----------------------------------------------------------------------
454
+
455
+ t_start_training = time.time()
456
+
457
+ # Async postprocessing — run SOM + Hestia on background threads so
458
+ # the GPU doesn't idle during their CPU-bound work.
459
+ _ASYNC_POSTPROCESS = os.environ.get("HYDRA_ASYNC_POSTPROCESS", "1") == "1"
460
+ _som_thread: threading.Thread | None = None
461
+ _hestia_thread: threading.Thread | None = None
462
+ _hestia_stream: torch.cuda.Stream | None = (
463
+ torch.cuda.Stream() if _ASYNC_POSTPROCESS else None
464
+ )
465
+
466
+ # HYDRA_PROFILE_STEPS=N prints a per-phase cpu/gpu time breakdown for the
467
+ # first N steps (and every 100th step thereafter if N<0). Zero overhead
468
+ # when disabled. Used to find what's eating CPU budget when GPU should
469
+ # be the bottleneck.
470
+ _profile_steps = int(os.environ.get("HYDRA_PROFILE_STEPS", "0"))
471
+
472
+ while True:
473
+ torch.cuda.synchronize()
474
+ t0 = time.time()
475
+ _prof = _profile_steps and (step < _profile_steps or (_profile_steps < 0 and step % 100 == 0))
476
+ _gpu_ms = 0.0
477
+ _data_ms = 0.0
478
+ for micro_step in range(grad_accum_steps):
479
+ if _prof:
480
+ torch.cuda.synchronize(); _t_micro = time.time()
481
+ if USE_MDLM:
482
+ # MDLM path: corrupt y -> x_noised, run model to get full-V logits,
483
+ # compute RB weighted CE on masked positions. x (original input) is
484
+ # unused in this path the model only sees the noised version of y.
485
+ _mask_id = MDLM_MASK_ID if MDLM_MASK_ID >= 0 else (vocab_size - 1)
486
+ x_noised, mask_positions, loss_weights = mdlm_masked_forward_process(
487
+ y, mask_token_id=_mask_id, alpha_schedule=MDLM_SCHEDULE,
488
+ )
489
+ with autocast_ctx:
490
+ logits = model(x_noised) # targets=None -> (B, T, V) logits
491
+ loss = mdlm_rb_loss(logits, y, mask_positions, loss_weights)
492
+ else:
493
+ with autocast_ctx:
494
+ loss = model(x, y)
495
+ train_loss = loss.detach()
496
+ loss = loss / grad_accum_steps
497
+ loss.backward()
498
+ if _prof:
499
+ torch.cuda.synchronize()
500
+ _gpu_ms += (time.time() - _t_micro) * 1000
501
+ _t_data = time.time()
502
+ x, y, epoch = next(train_loader)
503
+ if _prof:
504
+ _data_ms += (time.time() - _t_data) * 1000
505
+ if _prof:
506
+ torch.cuda.synchronize(); _t_fb = time.time()
507
+
508
+ # Progress and schedules
509
+ progress = min(total_training_time / TIME_BUDGET, 1.0)
510
+ lrm = get_lr_multiplier(progress)
511
+ muon_momentum = get_muon_momentum(step)
512
+ muon_weight_decay = get_weight_decay(progress)
513
+ for group in optimizer.param_groups:
514
+ group["lr"] = group["initial_lr"] * lrm
515
+ if group['kind'] == 'muon':
516
+ group["momentum"] = muon_momentum
517
+ group["weight_decay"] = muon_weight_decay
518
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
519
+ optimizer.step()
520
+ if _prof:
521
+ torch.cuda.synchronize(); _t_opt = time.time()
522
+
523
+ # Learnability #2: EMA update after every optimizer step.
524
+ if ema_model is not None:
525
+ try:
526
+ ema_model.update_parameters(model)
527
+ except Exception as _e:
528
+ print(f"[EMA] update failed at step {step}: {_e}", flush=True)
529
+
530
+ # Learnability #7: curriculum transition. After
531
+ # CURRICULUM_SHORT_STEPS optimizer steps, rebuild the dataloader at
532
+ # MAX_SEQ_LEN. Done once, then the flag flips off.
533
+ if _curriculum_active and step + 1 >= CURRICULUM_SHORT_STEPS:
534
+ print(
535
+ f"[CURRICULUM] step={step+1} — switching from T={_current_seq_len} "
536
+ f"to T={MAX_SEQ_LEN}",
537
+ flush=True,
538
+ )
539
+ _current_seq_len = MAX_SEQ_LEN
540
+ _curriculum_active = False
541
+ train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, _current_seq_len, "train")
542
+ # Prefetch the next batch at the new seq_len so the following
543
+ # loop iteration consumes fresh data.
544
+ x, y, epoch = next(train_loader)
545
+
546
+ # Online SOM update — retina is now a plain Python attribute (not a
547
+ # registered buffer) so mutations do not invalidate torch.compile guards.
548
+ # Runs fully on CPU; safe to overlap with GPU forward pass.
549
+ _last_sdr = getattr(model, "_last_sdr", None)
550
+ if _last_sdr is not None:
551
+ if _ASYNC_POSTPROCESS:
552
+ if _som_thread is not None:
553
+ _som_thread.join()
554
+ # Clone tensors before next step overwrites them
555
+ _som_x = x.clone()
556
+ _som_sdr = _last_sdr.clone()
557
+ _som_thread = threading.Thread(
558
+ target=model.sdr_semantic.maybe_som_update,
559
+ args=(_som_x, _som_sdr),
560
+ daemon=True,
561
+ )
562
+ _som_thread.start()
563
+ else:
564
+ model.sdr_semantic.maybe_som_update(x, _last_sdr)
565
+
566
+ # Hestia QAT — anneal temperature every step, snap every N steps.
567
+ # apply_to walks all Linear modules (CPU) then does .data.copy_ (GPU).
568
+ # Background thread + separate CUDA stream lets this overlap with
569
+ # the next forward pass on the default stream.
570
+ _hestia_progress = (time.time() - t_start_training) / max(TIME_BUDGET, 1)
571
+ _hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
572
+ if step % _hestia_interval == 0:
573
+ if _ASYNC_POSTPROCESS:
574
+ if _hestia_thread is not None:
575
+ _hestia_thread.join()
576
+
577
+ def _hestia_bg(mdl: torch.nn.Module, prog: float) -> None:
578
+ assert _hestia_stream is not None
579
+ with torch.cuda.stream(_hestia_stream):
580
+ mdl.hestia.anneal_temperature(prog)
581
+ mdl.hestia.apply_to(mdl)
582
+
583
+ _hestia_thread = threading.Thread(
584
+ target=_hestia_bg,
585
+ args=(model, _hestia_progress),
586
+ daemon=True,
587
+ )
588
+ _hestia_thread.start()
589
+ else:
590
+ model.hestia.anneal_temperature(_hestia_progress)
591
+ model.hestia.apply_to(model)
592
+ else:
593
+ # anneal_temperature is cheap (~1 us), keep inline
594
+ model.hestia.anneal_temperature(_hestia_progress)
595
+
596
+ model.zero_grad(set_to_none=True)
597
+
598
+ train_loss_f = train_loss.item()
599
+ if math.isnan(train_loss_f) or train_loss_f > 100:
600
+ print("FAIL")
601
+ # Save to a DIFFERENT file — never clobber a good latest.pt with
602
+ # a NaN/diverged state. The good ckpt from the last periodic save
603
+ # is the right place to resume from.
604
+ save_ckpt(
605
+ model,
606
+ optimizer,
607
+ config,
608
+ step,
609
+ total_training_time,
610
+ smooth_train_loss,
611
+ bpt_ema,
612
+ epoch,
613
+ FAILED_CKPT,
614
+ blocking=True,
615
+ )
616
+ raise SystemExit(1)
617
+
618
+ torch.cuda.synchronize()
619
+ t1 = time.time()
620
+ dt = t1 - t0
621
+
622
+ if _prof:
623
+ fb = (_t_fb - t0) * 1000
624
+ opt = (_t_opt - _t_fb) * 1000
625
+ rest = (t1 - _t_opt) * 1000
626
+ print(
627
+ f"[PROF step={step:05d}] gpu={_gpu_ms:.0f}ms data_fetch={_data_ms:.0f}ms "
628
+ f"(sum_fb={fb:.0f}) opt={opt:.0f}ms rest={rest:.0f}ms total={dt*1000:.0f}ms",
629
+ flush=True,
630
+ )
631
+
632
+ if step > 10:
633
+ total_training_time += dt
634
+
635
+ ema_beta = 0.9
636
+ smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
637
+ debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1))
638
+ pct_done = 100 * progress
639
+ tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
640
+ mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / GPU_BF16_PEAK_FLOPS
641
+ remaining = max(0, TIME_BUDGET - total_training_time)
642
+
643
+ # Bytes-per-token for the CURRENT batch. evaluate_bpb in prepare.py
644
+ # computes bits-per-BYTE (total_nats / (ln2 * total_bytes)); to match
645
+ # that semantics live, we EMA-smooth the per-batch bytes/token and
646
+ # divide. Without this, the old `bpb = loss/ln2` was actually
647
+ # bits-per-token ~4× larger than val_bpb at vocab=8192 and
648
+ # therefore not comparable to the champion 1.279 bpb metric.
649
+ with torch.no_grad():
650
+ y_flat = y.view(-1)
651
+ nbytes_batch = token_bytes[y_flat]
652
+ mask = nbytes_batch > 0
653
+ mask_count = mask.sum().clamp(min=1).float()
654
+ avg_bytes_per_tok = (nbytes_batch.float() * mask.float()).sum() / mask_count
655
+ bpt_batch = float(avg_bytes_per_tok.item())
656
+ if step == 0 or bpt_ema <= 0.0:
657
+ bpt_ema = bpt_batch
658
+ else:
659
+ bpt_ema = 0.98 * bpt_ema + 0.02 * bpt_batch
660
+
661
+ # Dual metric: bpb (byte-normalized, comparable with val_bpb) AND
662
+ # bpt (bits per token, the raw loss in bits). bpt_div exposes the
663
+ # current avg bytes-per-token so the conversion is transparent.
664
+ bpt = debiased_smooth_loss / math.log(2)
665
+ bpb = bpt / max(bpt_ema, 1e-6)
666
+ vram_mib = torch.cuda.memory_allocated() / 1024 / 1024
667
+ current_lr = optimizer.param_groups[0]["lr"]
668
+
669
+ # Per-step line-buffered log. NOT \r-overwritten so tee/grep see it.
670
+ # Keep key=value pairs grep-friendly.
671
+ ppl = 2.0 ** bpb # perplexity (byte-level)
672
+ print(
673
+ f"step={step:05d} loss={debiased_smooth_loss:.4f} bpb={bpb:.4f} ppl={ppl:.3f} "
674
+ f"bpt={bpt:.3f} bpt_div={bpt_ema:.2f} "
675
+ f"tps={tok_per_sec} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
676
+ f"lr={current_lr:.2e} vram={vram_mib:.0f}MiB "
677
+ f"pct={pct_done:.1f} epoch={epoch} remaining={remaining:.0f}s",
678
+ flush=True,
679
+ )
680
+
681
+ if step == 0:
682
+ gc.collect()
683
+ gc.freeze()
684
+ gc.disable()
685
+ # No periodic gc.collect() — we disabled+froze at step 0 on purpose,
686
+ # so a manual collect every 5k steps just re-scans frozen objects
687
+ # (burned ~900 ms/event in production) for no live-garbage reason.
688
+
689
+ if CKPT_INTERVAL > 0 and step > 0 and step % CKPT_INTERVAL == 0:
690
+ save_ckpt(
691
+ model,
692
+ optimizer,
693
+ config,
694
+ step,
695
+ total_training_time,
696
+ smooth_train_loss,
697
+ bpt_ema,
698
+ epoch,
699
+ LATEST_CKPT,
700
+ )
701
+
702
+ # Periodic mid-training validation so we can see the model learning
703
+ # English in real time (not just at the end). Small val batch so it
704
+ # doesn't eat significant training time.
705
+ mid_val_interval = int(os.environ.get("HYDRA_MID_VAL_INTERVAL", "500"))
706
+ if mid_val_interval > 0 and step > 0 and step % mid_val_interval == 0:
707
+ model.eval()
708
+ try:
709
+ # Defrag GPU memory before eval allocates fresh chunks —
710
+ # without this the eval path can OOM on 6GB cards even
711
+ # though total usage fits, because the allocator's free
712
+ # blocks are fragmented.
713
+ torch.cuda.empty_cache()
714
+ _orig_mid = _prepare_mod.EVAL_TOKENS
715
+ _prepare_mod.EVAL_TOKENS = 262144 # ~260K tokens, fast
716
+ with torch.no_grad():
717
+ with autocast_ctx:
718
+ mid_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
719
+ _prepare_mod.EVAL_TOKENS = _orig_mid
720
+ mid_ppl = 2.0 ** mid_bpb
721
+ print(f"[MID_VAL] step={step} val_bpb={mid_bpb:.4f} val_ppl={mid_ppl:.3f}", flush=True)
722
+
723
+ # Per-layer diagnostic panel. Only printed when HYDRA_LAYER_DIAGNOSTICS=1
724
+ # is set (otherwise the layer_* keys are absent from _metrics).
725
+ _diag_metrics = model.get_secondary_metrics()
726
+ _layer_keys = sorted([k for k in _diag_metrics.keys() if k.startswith('layer_')])
727
+ if _layer_keys:
728
+ # Condense: one row per layer showing the four core signals.
729
+ n_layers = len(model.blocks)
730
+ print(f"[LAYER_DIAG] step={step}", flush=True)
731
+ for li in range(n_layers):
732
+ d_ratio = _diag_metrics.get(f'layer_{li}_delta_ratio', float('nan'))
733
+ out_n = _diag_metrics.get(f'layer_{li}_out_norm', float('nan'))
734
+ g_norm = _diag_metrics.get(f'layer_{li}_grad_norm', float('nan'))
735
+ eff_r = _diag_metrics.get(f'layer_{li}_eff_rank', float('nan'))
736
+ f_std = _diag_metrics.get(f'layer_{li}_feat_std', float('nan'))
737
+ print(
738
+ f"[LAYER_DIAG] L{li:02d} delta_ratio={d_ratio:.4f} "
739
+ f"out_norm={out_n:.4f} grad_norm={g_norm:.3e} "
740
+ f"eff_rank={eff_r:.1f} feat_std={f_std:.4f}",
741
+ flush=True,
742
+ )
743
+ htm_proj_g = _diag_metrics.get('htm_proj_grad_norm', None)
744
+ if htm_proj_g is not None:
745
+ print(f"[LAYER_DIAG] htm_proj grad_norm={htm_proj_g:.3e}", flush=True)
746
+ except Exception as e:
747
+ print(f"[MID_VAL] failed: {e}", flush=True)
748
+ model.train()
749
+
750
+ step += 1
751
+
752
+ if step > 10 and total_training_time >= TIME_BUDGET:
753
+ break
754
+
755
+ # Drain async postprocessing threads before eval
756
+ if _som_thread is not None:
757
+ _som_thread.join()
758
+ if _hestia_thread is not None:
759
+ _hestia_thread.join()
760
+ if _hestia_stream is not None:
761
+ _hestia_stream.synchronize()
762
+
763
+ total_tokens = step * TOTAL_BATCH_SIZE
764
+
765
+ # ----------------------------------------------------------------------
766
+ # SAVE ORDER (critical):
767
+ # 1. Save PRETRAIN_FINAL_CKPT with val_bpb=None (hedge against eval OOM)
768
+ # 2. Save LATEST_CKPT with val_bpb=None (hedge against eval OOM)
769
+ # 3. Run eval (may OOM on small GPUs; we survive it)
770
+ # 4. Re-save both ckpts with val_bpb filled in
771
+ # This way we NEVER lose the final trained weights to an eval crash.
772
+ # Previous ordering put eval first, so an eval-time OOM destroyed the
773
+ # only record of a 6h training run (2026-04-22 incident).
774
+ # ----------------------------------------------------------------------
775
+
776
+ save_ckpt(
777
+ model, optimizer, config, step, total_training_time,
778
+ smooth_train_loss, bpt_ema, epoch, PRETRAIN_FINAL_CKPT,
779
+ val_bpb=None, blocking=True,
780
+ )
781
+ save_ckpt(
782
+ model, optimizer, config, step, total_training_time,
783
+ smooth_train_loss, bpt_ema, epoch, LATEST_CKPT,
784
+ val_bpb=None, blocking=True,
785
+ )
786
+
787
+ # Now it's safe to eval — ckpts are on disk regardless of what happens here.
788
+ # HYDRA_EVAL_BATCH overrides DEVICE_BATCH_SIZE (env-tunable; default halves
789
+ # the training batch because eval holds activations for full sequence and
790
+ # does not benefit from overlap with backward). HYDRA_EVAL_TOKENS controls
791
+ # how many val tokens to sweep (default 2 M, short enough for autoresearch
792
+ # 5-min budgets).
793
+ val_bpb: float | None = None
794
+ # Eval batch: default to 4 on cloud GPUs (enough freed VRAM after optimizer
795
+ # clear), fall back to DEVICE_BATCH_SIZE//2 on tiny cards. Env-overridable.
796
+ _eval_B = int(os.environ.get("HYDRA_EVAL_BATCH",
797
+ str(max(1, DEVICE_BATCH_SIZE // 2) if DEVICE_BATCH_SIZE <= 8 else 4)))
798
+ # Eval tokens: default 1M (1,048,576) gives statistically meaningful BPB
799
+ # (256 forward passes at B=4, seq=1024). Env-overridable for fast/slow sweeps.
800
+ _eval_tokens = int(os.environ.get("HYDRA_EVAL_TOKENS", str(1048576)))
801
+ try:
802
+ # Aggressive VRAM reclaim for 6GB cards. Peak training VRAM = 5.1GB
803
+ # which leaves < 1GB for the eval forward — the driver can't satisfy
804
+ # the allocation. Free EVERY tensor we don't strictly need:
805
+ # - optimizer grads (set_to_none releases tensor)
806
+ # - optimizer.state (fp32 Muon NS workspace, AdamW moments — ~size-of-params each)
807
+ # - model internal caches (HTM subsample cache, SDR stash)
808
+ # After this, VRAM should be ~params only (bf16 ≈ 120MB at 60M params).
809
+ optimizer.zero_grad(set_to_none=True)
810
+ if hasattr(optimizer, 'state') and optimizer.state:
811
+ for p, st in list(optimizer.state.items()):
812
+ st.clear()
813
+ optimizer.state.clear()
814
+ for p in model.parameters():
815
+ if p.grad is not None:
816
+ p.grad = None
817
+ if hasattr(model, '_htm_cache'):
818
+ model._htm_cache = None
819
+ if hasattr(model, '_last_sdr'):
820
+ model._last_sdr = None
821
+ import gc as _gc
822
+ _gc.collect()
823
+ torch.cuda.empty_cache()
824
+ torch.cuda.synchronize()
825
+ try:
826
+ _free_mb = torch.cuda.mem_get_info()[0] / 1024 / 1024
827
+ print(f"[VAL] free_vram_mb={_free_mb:.0f} (cleared optimizer state)", flush=True)
828
+ except Exception:
829
+ pass
830
+ print(f"[VAL] running eval on {_eval_tokens} tokens at B={_eval_B}...", flush=True)
831
+ model.eval()
832
+ _orig = _prepare_mod.EVAL_TOKENS
833
+ _prepare_mod.EVAL_TOKENS = _eval_tokens
834
+ # Nemotron path reads HYDRA_STREAM_EVAL_TOKENS env var directly,
835
+ # not _prepare_mod.EVAL_TOKENS. Sync both so eval budget is
836
+ # respected regardless of which dataloader path is active.
837
+ _orig_stream = os.environ.get("HYDRA_STREAM_EVAL_TOKENS")
838
+ os.environ["HYDRA_STREAM_EVAL_TOKENS"] = str(_eval_tokens)
839
+ with autocast_ctx:
840
+ val_bpb = evaluate_bpb(model, tokenizer, _eval_B)
841
+ _prepare_mod.EVAL_TOKENS = _orig
842
+ if _orig_stream is not None:
843
+ os.environ["HYDRA_STREAM_EVAL_TOKENS"] = _orig_stream
844
+ else:
845
+ os.environ.pop("HYDRA_STREAM_EVAL_TOKENS", None)
846
+ val_ppl = 2 ** val_bpb
847
+ print(f"[VAL] step={step} val_bpb={val_bpb:.4f} val_ppl={val_ppl:.3f}", flush=True)
848
+ except torch.cuda.OutOfMemoryError as e:
849
+ print(f"[VAL] SKIPPED (OOM): {e}", flush=True)
850
+ torch.cuda.empty_cache()
851
+ except Exception as e:
852
+ import traceback as _tb
853
+ print(f"[VAL] SKIPPED ({type(e).__name__}): {e}", flush=True)
854
+ _tb.print_exc()
855
+ try:
856
+ _free = torch.cuda.mem_get_info()[0] / 1024 / 1024
857
+ print(f"[VAL] post-crash free_vram_mb={_free:.0f}", flush=True)
858
+ except Exception:
859
+ pass
860
+
861
+ # Final ckpts with val_bpb filled in (if eval succeeded).
862
+ save_ckpt(
863
+ model, optimizer, config, step, total_training_time,
864
+ smooth_train_loss, bpt_ema, epoch, LATEST_CKPT,
865
+ val_bpb=val_bpb, blocking=True,
866
+ )
867
+ save_ckpt(
868
+ model, optimizer, config, step, total_training_time,
869
+ smooth_train_loss, bpt_ema, epoch, PRETRAIN_FINAL_CKPT,
870
+ val_bpb=val_bpb, blocking=True,
871
+ )
872
+
873
+ # Learnability #2: persist EMA weights alongside the raw checkpoint.
874
+ # latest_ema.pt contains ema_model.module (the Averaged params) so it
875
+ # can be loaded by evaluation / inference code that expects the same
876
+ # state_dict shape as the raw model.
877
+ if ema_model is not None:
878
+ try:
879
+ ema_ckpt_path = CACHE_DIR / "latest_ema.pt"
880
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
881
+ torch.save({
882
+ "model_state_dict": ema_model.module.state_dict(),
883
+ "config": asdict(config),
884
+ "step": step,
885
+ "epoch": epoch,
886
+ "train_seconds": total_training_time,
887
+ "val_bpb": val_bpb,
888
+ "ema_decay": EMA_DECAY,
889
+ }, str(ema_ckpt_path))
890
+ print(f"[EMA] saved {ema_ckpt_path} (step={step})", flush=True)
891
+ except Exception as _e:
892
+ print(f"[EMA] save failed: {_e}", flush=True)
893
+
894
+ run_factual_probes(model, tokenizer, device, autocast_ctx)
895
+
896
+ t_end = time.time()
897
+ startup_time = t_start_training - t_start
898
+ steady_state_mfu = (
899
+ 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10)
900
+ / total_training_time / GPU_BF16_PEAK_FLOPS
901
+ if total_training_time > 0 else 0
902
+ )
903
+ peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
904
+ metrics = model.get_secondary_metrics()
905
+
906
+ print("---")
907
+ print(f"val_bpb: {val_bpb:.6f}" if val_bpb is not None else "val_bpb: SKIPPED")
908
+ print(f"training_seconds: {total_training_time:.1f}")
909
+ print(f"total_seconds: {t_end - t_start:.1f}")
910
+ print(f"peak_vram_mb: {peak_vram_mb:.1f}")
911
+ print(f"mfu_percent: {steady_state_mfu:.2f}")
912
+ print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
913
+ print(f"num_steps: {step}")
914
+ print(f"num_params_M: {num_params / 1e6:.1f}")
915
+ print(f"n_layer: {N_LAYER}")
916
+ print(f"d_model: {D_MODEL}")
917
+ print(f"engram_hit_rate: {metrics.get('engram_hit_rate', 0.0):.4f}")
918
+ print(f"sdr_active_bits: {metrics.get('sdr_active_bits', 0):.1f}")
919
+ print(f"htm_anomaly: {metrics.get('htm_anomaly', 0):.4f}")
920
+
921
+ # Per-layer summary panel — only printed when diagnostics were active.
922
+ _layer_keys = sorted([k for k in metrics.keys() if k.startswith('layer_')])
923
+ if _layer_keys:
924
+ n_layers = len(model.blocks)
925
+ print("--- per-layer diagnostic panel ---")
926
+ for li in range(n_layers):
927
+ d_ratio = metrics.get(f'layer_{li}_delta_ratio', float('nan'))
928
+ out_n = metrics.get(f'layer_{li}_out_norm', float('nan'))
929
+ g_norm = metrics.get(f'layer_{li}_grad_norm', float('nan'))
930
+ eff_r = metrics.get(f'layer_{li}_eff_rank', float('nan'))
931
+ f_std = metrics.get(f'layer_{li}_feat_std', float('nan'))
932
+ print(
933
+ f"L{li:02d} delta_ratio={d_ratio:.4f} out_norm={out_n:.4f} "
934
+ f"grad_norm={g_norm:.3e} eff_rank={eff_r:.1f} feat_std={f_std:.4f}"
935
+ )
936
+
937
+ # Emit full metrics dictionary as JSON for sweep aggregation. Path from
938
+ # HYDRA_METRICS_OUT env var; default=/tmp/hydra_run_metrics.json. Always
939
+ # written (even without diagnostics) so the aggregator can compare runs.
940
+ _metrics_out = os.environ.get("HYDRA_METRICS_OUT", "/tmp/hydra_run_metrics.json")
941
+ try:
942
+ _dump = dict(metrics)
943
+ _dump.update({
944
+ 'val_bpb': (float(val_bpb) if val_bpb is not None else None),
945
+ 'val_ppl': (float(val_ppl) if val_ppl is not None else None),
946
+ 'n_layer': int(N_LAYER),
947
+ 'd_model': int(D_MODEL),
948
+ 'num_params_M': float(num_params / 1e6),
949
+ 'num_steps': int(step),
950
+ 'total_tokens_M': float(total_tokens / 1e6),
951
+ 'peak_vram_mb': float(peak_vram_mb),
952
+ 'training_seconds': float(total_training_time),
953
+ 'sdr_target_active': int(os.environ.get("HYDRA_SDR_TARGET_ACTIVE", "327")),
954
+ })
955
+ Path(_metrics_out).parent.mkdir(parents=True, exist_ok=True)
956
+ with open(_metrics_out, 'w') as _f:
957
+ json.dump(_dump, _f, indent=2, sort_keys=True)
958
+ print(f"[METRICS] wrote {_metrics_out}", flush=True)
959
+ # Also emit a single-line JSON to stdout so the sweep aggregator can
960
+ # scrape it from HF Jobs logs without pulling files out of the container.
961
+ print("[METRICS_JSON] " + json.dumps(_dump, sort_keys=True), flush=True)
962
+ except Exception as _e:
963
+ print(f"[METRICS] write failed: {_e}", flush=True)
964
+
965
+ run_factual_english(model, tokenizer, MAX_SEQ_LEN)
966
+ # startup_time is informative but not printed (preserve historical output)
967
+ _ = startup_time
overlay/kernels/cuda/decode_kernels.cu CHANGED
@@ -1,10 +1,10 @@
1
- /*
2
- * CuTe DSL decode kernels for Mamba-3 autoregressive generation.
3
- *
4
- * Phase 2: Optimized single-token SSM step for inference.
5
- * Phase 1: Not needed (training only, no generation).
6
- *
7
- * Fuses: input_proj + conv_step + ssm_step + output_proj
8
- * into a single kernel launch for minimal latency.
9
- */
10
- // Stub: Phase 2 implementation
 
1
+ /*
2
+ * CuTe DSL decode kernels for Mamba-3 autoregressive generation.
3
+ *
4
+ * Phase 2: Optimized single-token SSM step for inference.
5
+ * Phase 1: Not needed (training only, no generation).
6
+ *
7
+ * Fuses: input_proj + conv_step + ssm_step + output_proj
8
+ * into a single kernel launch for minimal latency.
9
+ */
10
+ // Stub: Phase 2 implementation
overlay/kernels/cuda/flashfftconv/LICENSE CHANGED
@@ -1,201 +1,201 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright [yyyy] [name of copyright owner]
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
overlay/kernels/cuda/flashfftconv/README.md CHANGED
@@ -1,57 +1,57 @@
1
- # flashfftconv (vendored)
2
-
3
- Vendored from https://github.com/HazyResearch/flash-fft-conv (Apache 2.0 license).
4
-
5
- **Upstream commit:** see `UPSTREAM_COMMIT`.
6
-
7
- ## What this is
8
-
9
- HazyResearch's Monarch-matrix-decomposition FFT convolution CUDA kernel. Provides a
10
- drop-in replacement for `torch.fft.rfft + complex-mult + irfft` that runs ~2-3x
11
- faster than cuFFT for the specific power-of-two lengths it supports (256, 512,
12
- 1024, 2048, 4096, 8192, ..., up to 4M).
13
-
14
- In HYDRA, we use it to accelerate `subsystems/hyena_pure.fftconv_ref`. The
15
- accelerated path is opt-in via `HYDRA_HYENA_FLASH_FFT=1`; default behavior is
16
- unchanged (pure PyTorch fallback).
17
-
18
- ## How to build
19
-
20
- The vendored tree contains:
21
- - `flashfftconv/` — pure-Python wrappers (imports `monarch_cuda` CUDA extension)
22
- - `csrc/` — CUDA source files and setup.py for the native extension
23
-
24
- Build instructions:
25
-
26
- ```bash
27
- cd /home/mikeb/work/feather/kernels/cuda/flashfftconv/csrc
28
-
29
- # Edit `csrc/setup.py` first: change the cc_flag line to match your GPU arch
30
- # (RTX 3060 = 8.6, A100 = 8.0, H100 = 9.0). Example for RTX 3060:
31
- # cc_flag = ['--generate-code=arch=compute_86,code=compute_86']
32
-
33
- # Build with the local CUDA toolchain (must match your torch.version.cuda):
34
- CUDA_HOME=/usr/local/cuda-12.1 .venv/bin/pip install -e .
35
- ```
36
-
37
- Then install the Python wrappers:
38
-
39
- ```bash
40
- cd /home/mikeb/work/feather/kernels/cuda/flashfftconv
41
- .venv/bin/pip install -e .
42
- ```
43
-
44
- ## Runtime usage
45
-
46
- Once installed, set `HYDRA_HYENA_FLASH_FFT=1` and training will use it.
47
- `subsystems/hyena_pure.fftconv_ref` auto-detects via `try: import flashfftconv`
48
- and falls back to pure PyTorch on import failure.
49
-
50
- ## Known caveats
51
-
52
- - Seqlen must be a power of 2 AND in the supported set: {256, 512, 1024, 2048,
53
- 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304}.
54
- For HYDRA, `fft_size = 2 * seq_len` → seq_len in {128, 256, 512, 1024, 2048, ...}.
55
- - dtype must be fp16 or bf16 (fp32 not supported).
56
- - GPU arch must be compiled into the extension (see setup.py cc_flag).
57
- - CUDA toolchain major.minor should match `torch.version.cuda` major (12.x ↔ 12.x).
 
1
+ # flashfftconv (vendored)
2
+
3
+ Vendored from https://github.com/HazyResearch/flash-fft-conv (Apache 2.0 license).
4
+
5
+ **Upstream commit:** see `UPSTREAM_COMMIT`.
6
+
7
+ ## What this is
8
+
9
+ HazyResearch's Monarch-matrix-decomposition FFT convolution CUDA kernel. Provides a
10
+ drop-in replacement for `torch.fft.rfft + complex-mult + irfft` that runs ~2-3x
11
+ faster than cuFFT for the specific power-of-two lengths it supports (256, 512,
12
+ 1024, 2048, 4096, 8192, ..., up to 4M).
13
+
14
+ In HYDRA, we use it to accelerate `subsystems/hyena_pure.fftconv_ref`. The
15
+ accelerated path is opt-in via `HYDRA_HYENA_FLASH_FFT=1`; default behavior is
16
+ unchanged (pure PyTorch fallback).
17
+
18
+ ## How to build
19
+
20
+ The vendored tree contains:
21
+ - `flashfftconv/` — pure-Python wrappers (imports `monarch_cuda` CUDA extension)
22
+ - `csrc/` — CUDA source files and setup.py for the native extension
23
+
24
+ Build instructions:
25
+
26
+ ```bash
27
+ cd /home/mikeb/work/feather/kernels/cuda/flashfftconv/csrc
28
+
29
+ # Edit `csrc/setup.py` first: change the cc_flag line to match your GPU arch
30
+ # (RTX 3060 = 8.6, A100 = 8.0, H100 = 9.0). Example for RTX 3060:
31
+ # cc_flag = ['--generate-code=arch=compute_86,code=compute_86']
32
+
33
+ # Build with the local CUDA toolchain (must match your torch.version.cuda):
34
+ CUDA_HOME=/usr/local/cuda-12.1 .venv/bin/pip install -e .
35
+ ```
36
+
37
+ Then install the Python wrappers:
38
+
39
+ ```bash
40
+ cd /home/mikeb/work/feather/kernels/cuda/flashfftconv
41
+ .venv/bin/pip install -e .
42
+ ```
43
+
44
+ ## Runtime usage
45
+
46
+ Once installed, set `HYDRA_HYENA_FLASH_FFT=1` and training will use it.
47
+ `subsystems/hyena_pure.fftconv_ref` auto-detects via `try: import flashfftconv`
48
+ and falls back to pure PyTorch on import failure.
49
+
50
+ ## Known caveats
51
+
52
+ - Seqlen must be a power of 2 AND in the supported set: {256, 512, 1024, 2048,
53
+ 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304}.
54
+ For HYDRA, `fft_size = 2 * seq_len` → seq_len in {128, 256, 512, 1024, 2048, ...}.
55
+ - dtype must be fp16 or bf16 (fp32 not supported).
56
+ - GPU arch must be compiled into the extension (see setup.py cc_flag).
57
+ - CUDA toolchain major.minor should match `torch.version.cuda` major (12.x ↔ 12.x).
overlay/kernels/cuda/flashfftconv/UPSTREAM_COMMIT CHANGED
@@ -1 +1 @@
1
- b8771028717f46d5b22cbb8e12833f35033d621b
 
1
+ b8771028717f46d5b22cbb8e12833f35033d621b
overlay/kernels/cuda/flashfftconv/csrc/.gitignore CHANGED
@@ -1,10 +1,10 @@
1
- *.npy
2
- *.json
3
- *.png
4
-
5
- */*.npy
6
- */*.json
7
- */*.png
8
-
9
- *.DS_Store
10
  */*.DS_Store
 
1
+ *.npy
2
+ *.json
3
+ *.png
4
+
5
+ */*.npy
6
+ */*.json
7
+ */*.png
8
+
9
+ *.DS_Store
10
  */*.DS_Store
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly.h CHANGED
@@ -1,374 +1,374 @@
1
- // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
-
3
- #include <torch/extension.h>
4
-
5
- #include <vector>
6
-
7
- #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
8
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
- #define CHECK_IS_HALF_OR_BFLOAT(x) TORCH_CHECK(x.dtype() == torch::kFloat16 || x.dtype() == torch::kBFloat16, #x " must be float16 or bfloat16")
10
- #define CHECK_INPUT(x) \
11
- CHECK_CUDA(x); \
12
- CHECK_CONTIGUOUS(x); \
13
- CHECK_IS_HALF_OR_BFLOAT(x)
14
- #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
15
-
16
-
17
- std::vector<torch::Tensor> butterfly_cuda(
18
- torch::Tensor x,
19
- torch::Tensor d_f_T,
20
- torch::Tensor twiddle_factors_real,
21
- torch::Tensor twiddle_factors_imag,
22
- std::optional<at::Tensor> x_gate = std::nullopt
23
- );
24
-
25
-
26
- std::vector<torch::Tensor> butterfly_bf16_cuda(
27
- torch::Tensor x,
28
- torch::Tensor d_f_T_real,
29
- torch::Tensor d_f_T_imag,
30
- torch::Tensor twiddle_factors_real,
31
- torch::Tensor twiddle_factors_imag,
32
- std::optional<at::Tensor> out_gate = std::nullopt
33
- );
34
-
35
-
36
- std::vector<torch::Tensor> butterfly_padded_cuda(
37
- torch::Tensor x,
38
- torch::Tensor d_f_T,
39
- torch::Tensor twiddle_factors_real,
40
- torch::Tensor twiddle_factors_imag,
41
- int M,
42
- std::optional<at::Tensor> x_gate = std::nullopt
43
- );
44
-
45
-
46
- std::vector<torch::Tensor> butterfly_padded_bf16_cuda(
47
- torch::Tensor x,
48
- torch::Tensor d_f_T_real,
49
- torch::Tensor d_f_T_imag,
50
- torch::Tensor twiddle_factors_real,
51
- torch::Tensor twiddle_factors_imag,
52
- int M,
53
- std::optional<at::Tensor> x_gate = std::nullopt
54
- );
55
-
56
- torch::Tensor butterfly_ifft_cuda(
57
- torch::Tensor x_real,
58
- torch::Tensor x_imag,
59
- torch::Tensor d_f_T,
60
- torch::Tensor twiddle_factors_real,
61
- torch::Tensor twiddle_factors_imag,
62
- std::optional<at::Tensor> out_gate = std::nullopt
63
- );
64
-
65
- torch::Tensor butterfly_ifft_bf16_cuda(
66
- torch::Tensor x_real,
67
- torch::Tensor x_imag,
68
- torch::Tensor d_f_real,
69
- torch::Tensor d_f_imag,
70
- torch::Tensor twiddle_factors_real,
71
- torch::Tensor twiddle_factors_imag,
72
- std::optional<at::Tensor> x_gate = std::nullopt
73
- );
74
-
75
- torch::Tensor butterfly_ifft_padded_cuda(
76
- torch::Tensor x_real,
77
- torch::Tensor x_imag,
78
- torch::Tensor d_f,
79
- torch::Tensor twiddle_factors_real,
80
- torch::Tensor twiddle_factors_imag,
81
- int N,
82
- std::optional<at::Tensor> out_gate = std::nullopt
83
- );
84
-
85
-
86
- torch::Tensor butterfly_ifft_padded_bf16_cuda(
87
- torch::Tensor x_real,
88
- torch::Tensor x_imag,
89
- torch::Tensor d_f_real,
90
- torch::Tensor d_f_imag,
91
- torch::Tensor twiddle_factors_real,
92
- torch::Tensor twiddle_factors_imag,
93
- int N,
94
- std::optional<at::Tensor> out_gate = std::nullopt
95
- );
96
-
97
- std::vector<torch::Tensor> butterfly(
98
- torch::Tensor x,
99
- torch::Tensor d_f_T,
100
- torch::Tensor twiddle_factors_real,
101
- torch::Tensor twiddle_factors_imag
102
- ){
103
- CHECK_INPUT(x);
104
- CHECK_INPUT(twiddle_factors_real);
105
- CHECK_INPUT(twiddle_factors_imag);
106
-
107
-
108
- return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag);
109
- }
110
-
111
- std::vector<torch::Tensor> butterfly_gated(
112
- torch::Tensor x,
113
- torch::Tensor d_f_T,
114
- torch::Tensor twiddle_factors_real,
115
- torch::Tensor twiddle_factors_imag,
116
- torch::Tensor x_gate
117
- ){
118
- CHECK_INPUT(x);
119
- CHECK_INPUT(twiddle_factors_real);
120
- CHECK_INPUT(twiddle_factors_imag);
121
-
122
- CHECK_INPUT(x_gate);
123
-
124
- return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, x_gate);
125
- }
126
-
127
- std::vector<torch::Tensor> butterfly_bf16(
128
- torch::Tensor x,
129
- torch::Tensor d_f_T_real,
130
- torch::Tensor d_f_T_imag,
131
- torch::Tensor twiddle_factors_real,
132
- torch::Tensor twiddle_factors_imag
133
- ){
134
- CHECK_INPUT(x);
135
- CHECK_INPUT(twiddle_factors_real);
136
- CHECK_INPUT(twiddle_factors_imag);
137
- CHECK_INPUT(d_f_T_real);
138
- CHECK_INPUT(d_f_T_imag);
139
-
140
-
141
- return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag);
142
- }
143
-
144
- std::vector<torch::Tensor> butterfly_gated_bf16(
145
- torch::Tensor x,
146
- torch::Tensor d_f_T_real,
147
- torch::Tensor d_f_T_imag,
148
- torch::Tensor twiddle_factors_real,
149
- torch::Tensor twiddle_factors_imag,
150
- torch::Tensor x_gate
151
- ){
152
- CHECK_INPUT(x);
153
- CHECK_INPUT(twiddle_factors_real);
154
- CHECK_INPUT(twiddle_factors_imag);
155
- CHECK_INPUT(d_f_T_real);
156
- CHECK_INPUT(d_f_T_imag);
157
- CHECK_INPUT(x_gate);
158
-
159
-
160
- return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, x_gate);
161
- }
162
-
163
- torch::Tensor butterfly_ifft(
164
- torch::Tensor x_real,
165
- torch::Tensor x_imag,
166
- torch::Tensor d_f_T,
167
- torch::Tensor twiddle_factors_real,
168
- torch::Tensor twiddle_factors_imag
169
- ){
170
- CHECK_INPUT(x_real);
171
- CHECK_INPUT(x_imag);
172
- CHECK_INPUT(twiddle_factors_real);
173
- CHECK_INPUT(twiddle_factors_imag);
174
-
175
- return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag);
176
- }
177
-
178
-
179
- torch::Tensor butterfly_ifft_gated(
180
- torch::Tensor x_real,
181
- torch::Tensor x_imag,
182
- torch::Tensor d_f_T,
183
- torch::Tensor twiddle_factors_real,
184
- torch::Tensor twiddle_factors_imag,
185
- torch::Tensor out_gate
186
- ){
187
- CHECK_INPUT(x_real);
188
- CHECK_INPUT(x_imag);
189
- CHECK_INPUT(twiddle_factors_real);
190
- CHECK_INPUT(twiddle_factors_imag);
191
- CHECK_INPUT(out_gate);
192
-
193
- return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag, out_gate);
194
- }
195
-
196
- torch::Tensor butterfly_ifft_bf16(
197
- torch::Tensor x_real,
198
- torch::Tensor x_imag,
199
- torch::Tensor d_f_real,
200
- torch::Tensor d_f_imag,
201
- torch::Tensor twiddle_factors_real,
202
- torch::Tensor twiddle_factors_imag
203
- ){
204
- CHECK_INPUT(x_real);
205
- CHECK_INPUT(x_imag);
206
- CHECK_INPUT(d_f_real);
207
- CHECK_INPUT(d_f_imag);
208
- CHECK_INPUT(twiddle_factors_real);
209
- CHECK_INPUT(twiddle_factors_imag);
210
-
211
-
212
- return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag);
213
- }
214
-
215
-
216
- torch::Tensor butterfly_ifft_gated_bf16(
217
- torch::Tensor x_real,
218
- torch::Tensor x_imag,
219
- torch::Tensor d_f_real,
220
- torch::Tensor d_f_imag,
221
- torch::Tensor twiddle_factors_real,
222
- torch::Tensor twiddle_factors_imag,
223
- torch::Tensor out_gate
224
- ){
225
- CHECK_INPUT(x_real);
226
- CHECK_INPUT(x_imag);
227
- CHECK_INPUT(d_f_real);
228
- CHECK_INPUT(d_f_imag);
229
- CHECK_INPUT(twiddle_factors_real);
230
- CHECK_INPUT(twiddle_factors_imag);
231
- CHECK_INPUT(out_gate);
232
-
233
- return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, out_gate);
234
- }
235
-
236
- std::vector<torch::Tensor> butterfly_padded(
237
- torch::Tensor x,
238
- torch::Tensor d_f_T,
239
- torch::Tensor twiddle_factors_real,
240
- torch::Tensor twiddle_factors_imag,
241
- int M
242
- ){
243
- CHECK_INPUT(x);
244
- CHECK_INPUT(twiddle_factors_real);
245
- CHECK_INPUT(twiddle_factors_imag);
246
-
247
-
248
- return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M);
249
- }
250
-
251
- std::vector<torch::Tensor> butterfly_padded_bf16(
252
- torch::Tensor x,
253
- torch::Tensor d_f_T_real,
254
- torch::Tensor d_f_T_imag,
255
- torch::Tensor twiddle_factors_real,
256
- torch::Tensor twiddle_factors_imag,
257
- int M
258
- ){
259
- CHECK_INPUT(x);
260
- CHECK_INPUT(twiddle_factors_real);
261
- CHECK_INPUT(twiddle_factors_imag);
262
-
263
-
264
- return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M);
265
- }
266
-
267
-
268
- std::vector<torch::Tensor> butterfly_padded_gated(
269
- torch::Tensor x,
270
- torch::Tensor d_f_T,
271
- torch::Tensor twiddle_factors_real,
272
- torch::Tensor twiddle_factors_imag,
273
- int M,
274
- torch::Tensor x_gate
275
- ){
276
- CHECK_INPUT(x);
277
- CHECK_INPUT(twiddle_factors_real);
278
- CHECK_INPUT(twiddle_factors_imag);
279
-
280
-
281
- return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
282
- }
283
-
284
- std::vector<torch::Tensor> butterfly_padded_gated_bf16(
285
- torch::Tensor x,
286
- torch::Tensor d_f_T_real,
287
- torch::Tensor d_f_T_imag,
288
- torch::Tensor twiddle_factors_real,
289
- torch::Tensor twiddle_factors_imag,
290
- int M,
291
- torch::Tensor x_gate
292
- ){
293
- CHECK_INPUT(x);
294
- CHECK_INPUT(d_f_T_real);
295
- CHECK_INPUT(d_f_T_imag);
296
- CHECK_INPUT(twiddle_factors_real);
297
- CHECK_INPUT(twiddle_factors_imag);
298
-
299
-
300
- return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
301
- }
302
-
303
- torch::Tensor butterfly_ifft_padded(
304
- torch::Tensor x_real,
305
- torch::Tensor x_imag,
306
- torch::Tensor d_f,
307
- torch::Tensor twiddle_factors_real,
308
- torch::Tensor twiddle_factors_imag,
309
- int N
310
- ){
311
- CHECK_INPUT(x_real);
312
- CHECK_INPUT(x_imag);
313
- CHECK_INPUT(twiddle_factors_real);
314
- CHECK_INPUT(twiddle_factors_imag);
315
-
316
- return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N);
317
- }
318
-
319
- torch::Tensor butterfly_ifft_padded_gated(
320
- torch::Tensor x_real,
321
- torch::Tensor x_imag,
322
- torch::Tensor d_f,
323
- torch::Tensor twiddle_factors_real,
324
- torch::Tensor twiddle_factors_imag,
325
- int N,
326
- torch::Tensor out_gate
327
- ){
328
- CHECK_INPUT(x_real);
329
- CHECK_INPUT(x_imag);
330
- CHECK_INPUT(twiddle_factors_real);
331
- CHECK_INPUT(twiddle_factors_imag);
332
-
333
- return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
334
- }
335
-
336
-
337
- torch::Tensor butterfly_ifft_padded_bf16(
338
- torch::Tensor x_real,
339
- torch::Tensor x_imag,
340
- torch::Tensor d_f_real,
341
- torch::Tensor d_f_imag,
342
- torch::Tensor twiddle_factors_real,
343
- torch::Tensor twiddle_factors_imag,
344
- int N
345
- ){
346
- CHECK_INPUT(x_real);
347
- CHECK_INPUT(x_imag);
348
- CHECK_INPUT(d_f_real);
349
- CHECK_INPUT(d_f_imag);
350
- CHECK_INPUT(twiddle_factors_real);
351
- CHECK_INPUT(twiddle_factors_imag);
352
-
353
- return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N);
354
- }
355
-
356
- torch::Tensor butterfly_ifft_padded_gated_bf16(
357
- torch::Tensor x_real,
358
- torch::Tensor x_imag,
359
- torch::Tensor d_f_real,
360
- torch::Tensor d_f_imag,
361
- torch::Tensor twiddle_factors_real,
362
- torch::Tensor twiddle_factors_imag,
363
- int N,
364
- torch::Tensor out_gate
365
- ){
366
- CHECK_INPUT(x_real);
367
- CHECK_INPUT(x_imag);
368
- CHECK_INPUT(d_f_real);
369
- CHECK_INPUT(d_f_imag);
370
- CHECK_INPUT(twiddle_factors_real);
371
- CHECK_INPUT(twiddle_factors_imag);
372
-
373
- return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
374
  }
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_IS_HALF_OR_BFLOAT(x) TORCH_CHECK(x.dtype() == torch::kFloat16 || x.dtype() == torch::kBFloat16, #x " must be float16 or bfloat16")
10
+ #define CHECK_INPUT(x) \
11
+ CHECK_CUDA(x); \
12
+ CHECK_CONTIGUOUS(x); \
13
+ CHECK_IS_HALF_OR_BFLOAT(x)
14
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
15
+
16
+
17
+ std::vector<torch::Tensor> butterfly_cuda(
18
+ torch::Tensor x,
19
+ torch::Tensor d_f_T,
20
+ torch::Tensor twiddle_factors_real,
21
+ torch::Tensor twiddle_factors_imag,
22
+ std::optional<at::Tensor> x_gate = std::nullopt
23
+ );
24
+
25
+
26
+ std::vector<torch::Tensor> butterfly_bf16_cuda(
27
+ torch::Tensor x,
28
+ torch::Tensor d_f_T_real,
29
+ torch::Tensor d_f_T_imag,
30
+ torch::Tensor twiddle_factors_real,
31
+ torch::Tensor twiddle_factors_imag,
32
+ std::optional<at::Tensor> out_gate = std::nullopt
33
+ );
34
+
35
+
36
+ std::vector<torch::Tensor> butterfly_padded_cuda(
37
+ torch::Tensor x,
38
+ torch::Tensor d_f_T,
39
+ torch::Tensor twiddle_factors_real,
40
+ torch::Tensor twiddle_factors_imag,
41
+ int M,
42
+ std::optional<at::Tensor> x_gate = std::nullopt
43
+ );
44
+
45
+
46
+ std::vector<torch::Tensor> butterfly_padded_bf16_cuda(
47
+ torch::Tensor x,
48
+ torch::Tensor d_f_T_real,
49
+ torch::Tensor d_f_T_imag,
50
+ torch::Tensor twiddle_factors_real,
51
+ torch::Tensor twiddle_factors_imag,
52
+ int M,
53
+ std::optional<at::Tensor> x_gate = std::nullopt
54
+ );
55
+
56
+ torch::Tensor butterfly_ifft_cuda(
57
+ torch::Tensor x_real,
58
+ torch::Tensor x_imag,
59
+ torch::Tensor d_f_T,
60
+ torch::Tensor twiddle_factors_real,
61
+ torch::Tensor twiddle_factors_imag,
62
+ std::optional<at::Tensor> out_gate = std::nullopt
63
+ );
64
+
65
+ torch::Tensor butterfly_ifft_bf16_cuda(
66
+ torch::Tensor x_real,
67
+ torch::Tensor x_imag,
68
+ torch::Tensor d_f_real,
69
+ torch::Tensor d_f_imag,
70
+ torch::Tensor twiddle_factors_real,
71
+ torch::Tensor twiddle_factors_imag,
72
+ std::optional<at::Tensor> x_gate = std::nullopt
73
+ );
74
+
75
+ torch::Tensor butterfly_ifft_padded_cuda(
76
+ torch::Tensor x_real,
77
+ torch::Tensor x_imag,
78
+ torch::Tensor d_f,
79
+ torch::Tensor twiddle_factors_real,
80
+ torch::Tensor twiddle_factors_imag,
81
+ int N,
82
+ std::optional<at::Tensor> out_gate = std::nullopt
83
+ );
84
+
85
+
86
+ torch::Tensor butterfly_ifft_padded_bf16_cuda(
87
+ torch::Tensor x_real,
88
+ torch::Tensor x_imag,
89
+ torch::Tensor d_f_real,
90
+ torch::Tensor d_f_imag,
91
+ torch::Tensor twiddle_factors_real,
92
+ torch::Tensor twiddle_factors_imag,
93
+ int N,
94
+ std::optional<at::Tensor> out_gate = std::nullopt
95
+ );
96
+
97
+ std::vector<torch::Tensor> butterfly(
98
+ torch::Tensor x,
99
+ torch::Tensor d_f_T,
100
+ torch::Tensor twiddle_factors_real,
101
+ torch::Tensor twiddle_factors_imag
102
+ ){
103
+ CHECK_INPUT(x);
104
+ CHECK_INPUT(twiddle_factors_real);
105
+ CHECK_INPUT(twiddle_factors_imag);
106
+
107
+
108
+ return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag);
109
+ }
110
+
111
+ std::vector<torch::Tensor> butterfly_gated(
112
+ torch::Tensor x,
113
+ torch::Tensor d_f_T,
114
+ torch::Tensor twiddle_factors_real,
115
+ torch::Tensor twiddle_factors_imag,
116
+ torch::Tensor x_gate
117
+ ){
118
+ CHECK_INPUT(x);
119
+ CHECK_INPUT(twiddle_factors_real);
120
+ CHECK_INPUT(twiddle_factors_imag);
121
+
122
+ CHECK_INPUT(x_gate);
123
+
124
+ return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, x_gate);
125
+ }
126
+
127
+ std::vector<torch::Tensor> butterfly_bf16(
128
+ torch::Tensor x,
129
+ torch::Tensor d_f_T_real,
130
+ torch::Tensor d_f_T_imag,
131
+ torch::Tensor twiddle_factors_real,
132
+ torch::Tensor twiddle_factors_imag
133
+ ){
134
+ CHECK_INPUT(x);
135
+ CHECK_INPUT(twiddle_factors_real);
136
+ CHECK_INPUT(twiddle_factors_imag);
137
+ CHECK_INPUT(d_f_T_real);
138
+ CHECK_INPUT(d_f_T_imag);
139
+
140
+
141
+ return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag);
142
+ }
143
+
144
+ std::vector<torch::Tensor> butterfly_gated_bf16(
145
+ torch::Tensor x,
146
+ torch::Tensor d_f_T_real,
147
+ torch::Tensor d_f_T_imag,
148
+ torch::Tensor twiddle_factors_real,
149
+ torch::Tensor twiddle_factors_imag,
150
+ torch::Tensor x_gate
151
+ ){
152
+ CHECK_INPUT(x);
153
+ CHECK_INPUT(twiddle_factors_real);
154
+ CHECK_INPUT(twiddle_factors_imag);
155
+ CHECK_INPUT(d_f_T_real);
156
+ CHECK_INPUT(d_f_T_imag);
157
+ CHECK_INPUT(x_gate);
158
+
159
+
160
+ return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, x_gate);
161
+ }
162
+
163
+ torch::Tensor butterfly_ifft(
164
+ torch::Tensor x_real,
165
+ torch::Tensor x_imag,
166
+ torch::Tensor d_f_T,
167
+ torch::Tensor twiddle_factors_real,
168
+ torch::Tensor twiddle_factors_imag
169
+ ){
170
+ CHECK_INPUT(x_real);
171
+ CHECK_INPUT(x_imag);
172
+ CHECK_INPUT(twiddle_factors_real);
173
+ CHECK_INPUT(twiddle_factors_imag);
174
+
175
+ return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag);
176
+ }
177
+
178
+
179
+ torch::Tensor butterfly_ifft_gated(
180
+ torch::Tensor x_real,
181
+ torch::Tensor x_imag,
182
+ torch::Tensor d_f_T,
183
+ torch::Tensor twiddle_factors_real,
184
+ torch::Tensor twiddle_factors_imag,
185
+ torch::Tensor out_gate
186
+ ){
187
+ CHECK_INPUT(x_real);
188
+ CHECK_INPUT(x_imag);
189
+ CHECK_INPUT(twiddle_factors_real);
190
+ CHECK_INPUT(twiddle_factors_imag);
191
+ CHECK_INPUT(out_gate);
192
+
193
+ return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag, out_gate);
194
+ }
195
+
196
+ torch::Tensor butterfly_ifft_bf16(
197
+ torch::Tensor x_real,
198
+ torch::Tensor x_imag,
199
+ torch::Tensor d_f_real,
200
+ torch::Tensor d_f_imag,
201
+ torch::Tensor twiddle_factors_real,
202
+ torch::Tensor twiddle_factors_imag
203
+ ){
204
+ CHECK_INPUT(x_real);
205
+ CHECK_INPUT(x_imag);
206
+ CHECK_INPUT(d_f_real);
207
+ CHECK_INPUT(d_f_imag);
208
+ CHECK_INPUT(twiddle_factors_real);
209
+ CHECK_INPUT(twiddle_factors_imag);
210
+
211
+
212
+ return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag);
213
+ }
214
+
215
+
216
+ torch::Tensor butterfly_ifft_gated_bf16(
217
+ torch::Tensor x_real,
218
+ torch::Tensor x_imag,
219
+ torch::Tensor d_f_real,
220
+ torch::Tensor d_f_imag,
221
+ torch::Tensor twiddle_factors_real,
222
+ torch::Tensor twiddle_factors_imag,
223
+ torch::Tensor out_gate
224
+ ){
225
+ CHECK_INPUT(x_real);
226
+ CHECK_INPUT(x_imag);
227
+ CHECK_INPUT(d_f_real);
228
+ CHECK_INPUT(d_f_imag);
229
+ CHECK_INPUT(twiddle_factors_real);
230
+ CHECK_INPUT(twiddle_factors_imag);
231
+ CHECK_INPUT(out_gate);
232
+
233
+ return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, out_gate);
234
+ }
235
+
236
+ std::vector<torch::Tensor> butterfly_padded(
237
+ torch::Tensor x,
238
+ torch::Tensor d_f_T,
239
+ torch::Tensor twiddle_factors_real,
240
+ torch::Tensor twiddle_factors_imag,
241
+ int M
242
+ ){
243
+ CHECK_INPUT(x);
244
+ CHECK_INPUT(twiddle_factors_real);
245
+ CHECK_INPUT(twiddle_factors_imag);
246
+
247
+
248
+ return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M);
249
+ }
250
+
251
+ std::vector<torch::Tensor> butterfly_padded_bf16(
252
+ torch::Tensor x,
253
+ torch::Tensor d_f_T_real,
254
+ torch::Tensor d_f_T_imag,
255
+ torch::Tensor twiddle_factors_real,
256
+ torch::Tensor twiddle_factors_imag,
257
+ int M
258
+ ){
259
+ CHECK_INPUT(x);
260
+ CHECK_INPUT(twiddle_factors_real);
261
+ CHECK_INPUT(twiddle_factors_imag);
262
+
263
+
264
+ return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M);
265
+ }
266
+
267
+
268
+ std::vector<torch::Tensor> butterfly_padded_gated(
269
+ torch::Tensor x,
270
+ torch::Tensor d_f_T,
271
+ torch::Tensor twiddle_factors_real,
272
+ torch::Tensor twiddle_factors_imag,
273
+ int M,
274
+ torch::Tensor x_gate
275
+ ){
276
+ CHECK_INPUT(x);
277
+ CHECK_INPUT(twiddle_factors_real);
278
+ CHECK_INPUT(twiddle_factors_imag);
279
+
280
+
281
+ return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
282
+ }
283
+
284
+ std::vector<torch::Tensor> butterfly_padded_gated_bf16(
285
+ torch::Tensor x,
286
+ torch::Tensor d_f_T_real,
287
+ torch::Tensor d_f_T_imag,
288
+ torch::Tensor twiddle_factors_real,
289
+ torch::Tensor twiddle_factors_imag,
290
+ int M,
291
+ torch::Tensor x_gate
292
+ ){
293
+ CHECK_INPUT(x);
294
+ CHECK_INPUT(d_f_T_real);
295
+ CHECK_INPUT(d_f_T_imag);
296
+ CHECK_INPUT(twiddle_factors_real);
297
+ CHECK_INPUT(twiddle_factors_imag);
298
+
299
+
300
+ return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
301
+ }
302
+
303
+ torch::Tensor butterfly_ifft_padded(
304
+ torch::Tensor x_real,
305
+ torch::Tensor x_imag,
306
+ torch::Tensor d_f,
307
+ torch::Tensor twiddle_factors_real,
308
+ torch::Tensor twiddle_factors_imag,
309
+ int N
310
+ ){
311
+ CHECK_INPUT(x_real);
312
+ CHECK_INPUT(x_imag);
313
+ CHECK_INPUT(twiddle_factors_real);
314
+ CHECK_INPUT(twiddle_factors_imag);
315
+
316
+ return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N);
317
+ }
318
+
319
+ torch::Tensor butterfly_ifft_padded_gated(
320
+ torch::Tensor x_real,
321
+ torch::Tensor x_imag,
322
+ torch::Tensor d_f,
323
+ torch::Tensor twiddle_factors_real,
324
+ torch::Tensor twiddle_factors_imag,
325
+ int N,
326
+ torch::Tensor out_gate
327
+ ){
328
+ CHECK_INPUT(x_real);
329
+ CHECK_INPUT(x_imag);
330
+ CHECK_INPUT(twiddle_factors_real);
331
+ CHECK_INPUT(twiddle_factors_imag);
332
+
333
+ return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
334
+ }
335
+
336
+
337
+ torch::Tensor butterfly_ifft_padded_bf16(
338
+ torch::Tensor x_real,
339
+ torch::Tensor x_imag,
340
+ torch::Tensor d_f_real,
341
+ torch::Tensor d_f_imag,
342
+ torch::Tensor twiddle_factors_real,
343
+ torch::Tensor twiddle_factors_imag,
344
+ int N
345
+ ){
346
+ CHECK_INPUT(x_real);
347
+ CHECK_INPUT(x_imag);
348
+ CHECK_INPUT(d_f_real);
349
+ CHECK_INPUT(d_f_imag);
350
+ CHECK_INPUT(twiddle_factors_real);
351
+ CHECK_INPUT(twiddle_factors_imag);
352
+
353
+ return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N);
354
+ }
355
+
356
+ torch::Tensor butterfly_ifft_padded_gated_bf16(
357
+ torch::Tensor x_real,
358
+ torch::Tensor x_imag,
359
+ torch::Tensor d_f_real,
360
+ torch::Tensor d_f_imag,
361
+ torch::Tensor twiddle_factors_real,
362
+ torch::Tensor twiddle_factors_imag,
363
+ int N,
364
+ torch::Tensor out_gate
365
+ ){
366
+ CHECK_INPUT(x_real);
367
+ CHECK_INPUT(x_imag);
368
+ CHECK_INPUT(d_f_real);
369
+ CHECK_INPUT(d_f_imag);
370
+ CHECK_INPUT(twiddle_factors_real);
371
+ CHECK_INPUT(twiddle_factors_imag);
372
+
373
+ return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
374
  }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda.cu CHANGED
@@ -1,699 +1,699 @@
1
- // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
-
3
- #include <torch/extension.h>
4
-
5
- #include <vector>
6
- #include <stdio.h>
7
- #include <mma.h>
8
- #include <cuda_fp16.h>
9
- #include <cuda_bf16.h>
10
- #include "shared.h"
11
-
12
- using namespace nvcuda;
13
-
14
- __global__ void butterfly_cuda_kernel_64(
15
- const __half2 *__restrict__ x,
16
- const __half2 *__restrict__ x_gate,
17
- const complex_half_t *__restrict__ d_f,
18
- const __half2 *__restrict__ twiddle_factors_real,
19
- const __half2 *__restrict__ twiddle_factors_imag,
20
- __half2 *__restrict__ out_real,
21
- __half2 *__restrict__ out_imag,
22
- uint B,
23
- uint H,
24
- int N)
25
- {
26
- const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
27
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
28
- int idx;
29
- int shared_offset;
30
- const int B_Y = blockDim.y;
31
- const int n = N / B_Y;
32
-
33
-
34
- extern __shared__ half x_shared[];
35
- half *d_f_real = &x_shared[N * N];
36
- half *d_f_imag = &d_f_real[N * N];
37
- half *twiddles_real_shared = &d_f_imag[N * N];
38
- half *twiddles_imag_shared = &twiddles_real_shared[N * N];
39
- half *out_real_shared = &twiddles_imag_shared[N * N];
40
- half *out_imag_shared = &out_real_shared[N * N];
41
-
42
- // #pragma unroll
43
- for (int i = 0; i < n; i++)
44
- {
45
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
46
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
47
- reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
48
- reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
49
-
50
- // #pragma unroll
51
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x;
52
- d_f_real[shared_offset] = d_f[shared_offset].real();
53
- d_f_imag[shared_offset] = d_f[shared_offset].imag();
54
-
55
- d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
56
- d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
57
- }
58
-
59
- __half2 tmp_real, tmp_imag;
60
-
61
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4];
62
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
63
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
64
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4];
65
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[4][4];
66
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
67
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[4];
68
-
69
- __syncthreads();
70
-
71
- for (int i = 0; i < 4; i++)
72
- {
73
- wmma::load_matrix_sync(a_frag_real[i], d_f_real + i * N * 16 + threadIdx.y * 16, N);
74
- wmma::load_matrix_sync(a_frag_imag[i], d_f_imag + i * N * 16 + threadIdx.y * 16, N);
75
- wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
76
- wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
77
- }
78
-
79
- for (int t = 0; t < 16; t++)
80
- {
81
-
82
- for (int i = 0; i < n; i++)
83
- {
84
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
85
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
86
- if(x_gate != nullptr){
87
- reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
88
- }else{
89
- reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
90
- }
91
- }
92
-
93
- __syncthreads();
94
-
95
- for (int i = 0; i < 4; i++)
96
- {
97
- for (int j = 0; j < 4; j++)
98
- {
99
- wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
100
- }
101
- }
102
-
103
- #pragma unroll
104
- for (int j = 0; j < 4; j++)
105
- {
106
- wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
107
-
108
- for (int k = 0; k < 4; k++)
109
- {
110
- wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
111
- }
112
- }
113
-
114
- #pragma unroll
115
-
116
- for (int j = 0; j < 4; j++)
117
- {
118
- wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
119
-
120
- for (int k = 0; k < 4; k++)
121
- {
122
- wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
123
- }
124
- }
125
-
126
- #pragma unroll
127
- for (int j = 0; j < 4; j++)
128
- {
129
- for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
130
- {
131
- tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
132
- tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
133
- reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
134
- reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
135
- }
136
-
137
- wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
138
- wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
139
- }
140
-
141
- __syncthreads();
142
-
143
- #pragma unroll
144
- for (int i = 0; i < n; i++)
145
- {
146
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
147
- out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
148
- out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
149
- }
150
-
151
- __syncthreads();
152
- }
153
- }
154
-
155
- __global__ void butterfly_cuda_kernel_32(
156
- const __half2 *__restrict__ x,
157
- const __half2 *__restrict__ x_gate,
158
- const complex_half_t *__restrict__ d_f,
159
- const __half2 *__restrict__ twiddle_factors_real,
160
- const __half2 *__restrict__ twiddle_factors_imag,
161
- __half2 *__restrict__ out_real,
162
- __half2 *__restrict__ out_imag,
163
- uint B,
164
- uint H,
165
- int N)
166
- {
167
- const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
168
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
169
- int idx;
170
-
171
- int shared_offset;
172
- const int B_Y = blockDim.y;
173
- const int n = N / B_Y;
174
-
175
-
176
- __shared__ half x_shared[32 * 64];
177
- __shared__ half d_f_real[32 * 32];
178
- __shared__ half d_f_imag[32 * 32];
179
- __shared__ half twiddles_real_shared[32 * 64];
180
- __shared__ half twiddles_imag_shared[32 * 64];
181
- __shared__ half out_real_shared[32 * 64];
182
- __shared__ half out_imag_shared[32 * 64];
183
-
184
- // #pragma unroll
185
- for (int i = 0; i < n; i++)
186
- {
187
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
188
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
189
- if(x_gate == nullptr){
190
- reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
191
- }else{
192
- reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
193
- }
194
- reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
195
- reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
196
-
197
- // #pragma unroll
198
- d_f_real[shared_offset] = d_f[shared_offset].real();
199
- d_f_imag[shared_offset] = d_f[shared_offset].imag();
200
- }
201
-
202
- __syncthreads();
203
-
204
- if (threadIdx.y < N / 16)
205
- {
206
- __half2 tmp_real, tmp_imag;
207
-
208
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
209
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
210
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
211
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
212
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[2][2];
213
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
214
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[2][2];
215
-
216
- int t = threadIdx.y * 32;
217
-
218
- for (int i = 0; i < 2; i++)
219
- {
220
- for (int j = 0; j < 2; j++)
221
- {
222
- wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
223
- wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
224
- wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
225
- wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
226
- wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
227
- }
228
- }
229
-
230
- #pragma unroll
231
- for (int i = 0; i < 2; i++)
232
- {
233
- for (int j = 0; j < 2; j++)
234
- {
235
- wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
236
-
237
- for (int k = 0; k < 2; k++)
238
- {
239
- wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
240
- }
241
- }
242
- }
243
-
244
- #pragma unroll
245
- for (int i = 0; i < 2; i++)
246
- {
247
- for (int j = 0; j < 2; j++)
248
- {
249
- wmma::fill_fragment(acc_frag_imag[i][j], __float2half(0.0f));
250
-
251
- for (int k = 0; k < 2; k++)
252
- {
253
- wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
254
- }
255
- }
256
- }
257
-
258
- #pragma unroll
259
- for (int i = 0; i < 2; i++)
260
- {
261
- for (int j = 0; j < 2; j++)
262
- {
263
- for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
264
- {
265
- tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k];
266
- tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k];
267
- reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]));
268
- reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]));
269
- }
270
-
271
- wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
272
- wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
273
- }
274
- }
275
- }
276
-
277
- __syncthreads();
278
-
279
- #pragma unroll
280
- for (int i = 0; i < n; i++)
281
- {
282
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
283
- out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
284
- out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
285
- }
286
- }
287
-
288
- __global__ void butterfly_cuda_kernel_128(
289
- const __half2 *__restrict__ x,
290
- const __half2 *__restrict__ x_gate,
291
- const complex_half_t *__restrict__ d_f,
292
- const __half2 *__restrict__ twiddle_factors_real,
293
- const __half2 *__restrict__ twiddle_factors_imag,
294
- __half2 *__restrict__ out_real,
295
- __half2 *__restrict__ out_imag,
296
- uint B,
297
- uint H,
298
- int N)
299
- {
300
- const int offset = blockIdx.y * H * 128 * 32 * gridDim.x * 2 + blockIdx.z * 16 * 128 * 32 * gridDim.x * 2 + blockIdx.x * 64 + threadIdx.x;
301
- const int tw_offset = blockIdx.x * 64 + threadIdx.x;
302
- int idx;
303
-
304
- int shared_offset;
305
- const int B_Y = blockDim.y;
306
- const int n = N / B_Y;
307
-
308
-
309
- extern __shared__ half shared_real[];
310
- half *shared_imag = &shared_real[128 * 128];
311
-
312
-
313
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[8];
314
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
315
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
316
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[8];
317
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[8][8];
318
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
319
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[8];
320
-
321
- for (int i = 0; i < n; i++)
322
- {
323
- for(int j=0; j< 4; j++){
324
- shared_offset = (threadIdx.y + i * B_Y) * 128 + threadIdx.x + j * blockDim.x;
325
- shared_real[shared_offset] = d_f[shared_offset].real();
326
- shared_imag[shared_offset] = d_f[shared_offset].imag();
327
- }
328
- }
329
-
330
- __syncthreads();
331
-
332
-
333
- for (int i = 0; i < 8; i++){
334
- wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
335
- wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
336
- }
337
-
338
-
339
- __syncthreads();
340
-
341
-
342
-
343
- for (int i = 0; i < n; i++)
344
- {
345
- for(int j=0; j< 2; j++){
346
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
347
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
348
- reinterpret_cast<__half2*>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
349
- reinterpret_cast<__half2*>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
350
- }
351
- }
352
-
353
- __syncthreads();
354
-
355
-
356
- for (int i = 0; i < 8; i++){
357
- wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
358
- wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
359
- }
360
-
361
- __syncthreads();
362
-
363
-
364
- for(int t=0; t< 16; t++){
365
- for (int i = 0; i < n; i++)
366
- {
367
- for(int j=0; j< 2; j++){
368
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
369
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
370
- if(x_gate != nullptr){
371
- reinterpret_cast<__half2*>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
372
- }else{
373
- reinterpret_cast<__half2*>(shared_real)[shared_offset] = x[offset + idx];
374
- }
375
-
376
- }
377
- }
378
-
379
-
380
- __syncthreads();
381
-
382
-
383
- for (int i = 0; i < 8; i++)
384
- {
385
- for (int j = 0; j < 8; j++)
386
- {
387
- wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
388
- }
389
- }
390
-
391
- __syncthreads();
392
-
393
- #pragma unroll
394
- for (int j = 0; j < 8; j++)
395
- {
396
- wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
397
-
398
- for (int k = 0; k < 8; k++)
399
- {
400
- wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
401
- }
402
- }
403
-
404
- #pragma unroll
405
-
406
- for (int j = 0; j < 8; j++)
407
- {
408
- wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
409
-
410
- for (int k = 0; k < 8; k++)
411
- {
412
- wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
413
- }
414
- }
415
-
416
- __half2 tmp_real, tmp_imag;
417
- #pragma unroll
418
- for (int j = 0; j < 8; j++)
419
- {
420
- for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
421
- {
422
- tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
423
- tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
424
- reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
425
- reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
426
- }
427
-
428
- wmma::store_matrix_sync(shared_real + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
429
- wmma::store_matrix_sync(shared_imag + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
430
- }
431
-
432
- __syncthreads();
433
-
434
- #pragma unroll
435
- for (int i = 0; i < n; i++)
436
- {
437
- for(int j=0; j< 2; j++){
438
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
439
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
440
- out_real[offset + idx] = reinterpret_cast<__half2*>(shared_real)[shared_offset];
441
- out_imag[offset + idx] = reinterpret_cast<__half2*>(shared_imag)[shared_offset];
442
- }
443
- }
444
-
445
- __syncthreads();
446
- }
447
- }
448
-
449
-
450
- __global__ void butterfly_cuda_kernel_16(
451
- const __half2 *__restrict__ x,
452
- const __half2 *__restrict__ x_gate,
453
- const complex_half_t *__restrict__ d_f,
454
- const __half2 *__restrict__ twiddle_factors_real,
455
- const __half2 *__restrict__ twiddle_factors_imag,
456
- __half2 *__restrict__ out_real,
457
- __half2 *__restrict__ out_imag,
458
- uint B,
459
- uint H,
460
- int N)
461
- {
462
- const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
463
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
464
- int idx;
465
-
466
- int shared_offset;
467
- const int B_Y = blockDim.y;
468
- const int n = N / B_Y;
469
-
470
-
471
- __shared__ half x_shared[16 * 64];
472
- __shared__ half d_f_real[16 * 16];
473
- __shared__ half d_f_imag[16 * 16];
474
- __shared__ half twiddles_real_shared[16 * 64];
475
- __shared__ half twiddles_imag_shared[16 * 64];
476
- __shared__ half out_real_shared[16 * 64];
477
- __shared__ half out_imag_shared[16 * 64];
478
-
479
- // #pragma unroll
480
- for (int i = 0; i < n; i++)
481
- {
482
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
483
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
484
-
485
- if(x_gate != NULL)
486
- reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
487
- else
488
- reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
489
- reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
490
- reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
491
-
492
- // #pragma unroll
493
-
494
- if(threadIdx.x < 16 ){
495
- shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
496
- d_f_real[shared_offset] = d_f[shared_offset].real();
497
- d_f_imag[shared_offset] = d_f[shared_offset].imag();
498
- }
499
- }
500
-
501
- __syncthreads();
502
-
503
- if (threadIdx.y < 4)
504
- {
505
- __half2 tmp_real, tmp_imag;
506
-
507
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
508
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real;
509
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
510
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
511
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
512
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
513
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag;
514
-
515
- wmma::load_matrix_sync(a_frag_real, d_f_real, N);
516
- wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
517
- wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
518
- wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
519
- wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
520
-
521
-
522
- wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
523
-
524
-
525
- wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
526
-
527
-
528
- wmma::fill_fragment(acc_frag_imag, __float2half(0.0f));
529
-
530
-
531
- wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
532
-
533
-
534
-
535
- for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
536
- {
537
- tmp_real = reinterpret_cast<__half2 *>(acc_frag_real.x)[k];
538
- tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag.x)[k];
539
- reinterpret_cast<__half2 *>(acc_frag_real.x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]));
540
- reinterpret_cast<__half2 *>(acc_frag_imag.x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]));
541
- }
542
-
543
- wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
544
- wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
545
- }
546
-
547
- __syncthreads();
548
-
549
- #pragma unroll
550
- for (int i = 0; i < n; i++)
551
- {
552
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
553
- out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
554
- out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
555
- }
556
- }
557
-
558
-
559
- std::vector<torch::Tensor> butterfly_cuda(
560
- torch::Tensor x,
561
- torch::Tensor d_f,
562
- torch::Tensor twiddle_factors_real,
563
- torch::Tensor twiddle_factors_imag,
564
- std::optional<at::Tensor> x_gate = std::nullopt)
565
- {
566
-
567
- uint B = x.size(0);
568
- uint H = x.size(1);
569
- // uint m = x.size(1);
570
-
571
- // const int TILE_SIZE = 16;
572
- uint N = x.size(2);
573
- uint M = x.size(3);
574
- dim3 gridDim;
575
- dim3 blockDim;
576
-
577
- gridDim.y = B;
578
- gridDim.z = H;
579
-
580
- torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
581
- torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());
582
-
583
- //set blockDims
584
- switch(N){
585
- case 128:
586
- blockDim.x = 32;
587
- blockDim.y = 8;
588
- break;
589
- default:
590
- blockDim.x = 32;
591
- blockDim.y = 4;
592
- break;
593
- }
594
-
595
- //set gridDim.x
596
- switch(N){
597
- case 128:
598
- switch (M){
599
- case 16384:
600
- gridDim.x = 128;
601
- break;
602
- case 8192:
603
- gridDim.x = 64;
604
- break;
605
- case 4096:
606
- gridDim.x = 32;
607
- break;
608
- default:
609
- gridDim.x = 256;
610
- break;
611
- }
612
- break;
613
- default:
614
- switch (M){
615
- case 16384:
616
- gridDim.x = 256;
617
- break;
618
- case 8192:
619
- gridDim.x = 128;
620
- break;
621
- case 4096:
622
- gridDim.x = 64;
623
- break;
624
- default:
625
- gridDim.x = 512;
626
- break;
627
- }
628
- break;
629
- }
630
-
631
- switch (N)
632
- {
633
- case 16:
634
- butterfly_cuda_kernel_16<<<gridDim, blockDim>>>(
635
- static_cast<__half2 *>(x.data_ptr()),
636
- x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
637
- static_cast<complex_half_t *>(d_f.data_ptr()),
638
- static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
639
- static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
640
- static_cast<__half2 *>(out_real.data_ptr()),
641
- static_cast<__half2 *>(out_imag.data_ptr()),
642
- B,
643
- H,
644
- N);
645
- break;
646
- case 32:
647
- butterfly_cuda_kernel_32<<<gridDim, blockDim>>>(
648
- static_cast<__half2 *>(x.data_ptr()),
649
- x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
650
- static_cast<complex_half_t *>(d_f.data_ptr()),
651
- static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
652
- static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
653
- static_cast<__half2 *>(out_real.data_ptr()),
654
- static_cast<__half2 *>(out_imag.data_ptr()),
655
- B,
656
- H,
657
- N);
658
- break;
659
-
660
- case 64:
661
- gridDim.z = H / 16;
662
- cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
663
-
664
- butterfly_cuda_kernel_64<<<gridDim, blockDim, 57344>>>(
665
- static_cast<__half2 *>(x.data_ptr()),
666
- x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
667
- static_cast<complex_half_t *>(d_f.data_ptr()),
668
- static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
669
- static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
670
- static_cast<__half2 *>(out_real.data_ptr()),
671
- static_cast<__half2 *>(out_imag.data_ptr()),
672
- B,
673
- H,
674
- N);
675
- break;
676
- case 128:
677
- gridDim.z = H / 16;
678
- cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
679
-
680
- butterfly_cuda_kernel_128<<<gridDim, blockDim, 65536>>>(
681
- static_cast<__half2 *>(x.data_ptr()),
682
- x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
683
- static_cast<complex_half_t *>(d_f.data_ptr()),
684
- static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
685
- static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
686
- static_cast<__half2 *>(out_real.data_ptr()),
687
- static_cast<__half2 *>(out_imag.data_ptr()),
688
- B,
689
- H,
690
- N);
691
- break;
692
-
693
- default:
694
- printf("Not yet implemented \n");
695
- break;
696
- }
697
-
698
- return {out_real, out_imag};
699
  }
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cuda_bf16.h>
10
+ #include "shared.h"
11
+
12
+ using namespace nvcuda;
13
+
14
+ __global__ void butterfly_cuda_kernel_64(
15
+ const __half2 *__restrict__ x,
16
+ const __half2 *__restrict__ x_gate,
17
+ const complex_half_t *__restrict__ d_f,
18
+ const __half2 *__restrict__ twiddle_factors_real,
19
+ const __half2 *__restrict__ twiddle_factors_imag,
20
+ __half2 *__restrict__ out_real,
21
+ __half2 *__restrict__ out_imag,
22
+ uint B,
23
+ uint H,
24
+ int N)
25
+ {
26
+ const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
27
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
28
+ int idx;
29
+ int shared_offset;
30
+ const int B_Y = blockDim.y;
31
+ const int n = N / B_Y;
32
+
33
+
34
+ extern __shared__ half x_shared[];
35
+ half *d_f_real = &x_shared[N * N];
36
+ half *d_f_imag = &d_f_real[N * N];
37
+ half *twiddles_real_shared = &d_f_imag[N * N];
38
+ half *twiddles_imag_shared = &twiddles_real_shared[N * N];
39
+ half *out_real_shared = &twiddles_imag_shared[N * N];
40
+ half *out_imag_shared = &out_real_shared[N * N];
41
+
42
+ // #pragma unroll
43
+ for (int i = 0; i < n; i++)
44
+ {
45
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
46
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
47
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
48
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
49
+
50
+ // #pragma unroll
51
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x;
52
+ d_f_real[shared_offset] = d_f[shared_offset].real();
53
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
54
+
55
+ d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
56
+ d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
57
+ }
58
+
59
+ __half2 tmp_real, tmp_imag;
60
+
61
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4];
62
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
63
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
64
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4];
65
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[4][4];
66
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
67
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[4];
68
+
69
+ __syncthreads();
70
+
71
+ for (int i = 0; i < 4; i++)
72
+ {
73
+ wmma::load_matrix_sync(a_frag_real[i], d_f_real + i * N * 16 + threadIdx.y * 16, N);
74
+ wmma::load_matrix_sync(a_frag_imag[i], d_f_imag + i * N * 16 + threadIdx.y * 16, N);
75
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
76
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
77
+ }
78
+
79
+ for (int t = 0; t < 16; t++)
80
+ {
81
+
82
+ for (int i = 0; i < n; i++)
83
+ {
84
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
85
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
86
+ if(x_gate != nullptr){
87
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
88
+ }else{
89
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
90
+ }
91
+ }
92
+
93
+ __syncthreads();
94
+
95
+ for (int i = 0; i < 4; i++)
96
+ {
97
+ for (int j = 0; j < 4; j++)
98
+ {
99
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
100
+ }
101
+ }
102
+
103
+ #pragma unroll
104
+ for (int j = 0; j < 4; j++)
105
+ {
106
+ wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
107
+
108
+ for (int k = 0; k < 4; k++)
109
+ {
110
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
111
+ }
112
+ }
113
+
114
+ #pragma unroll
115
+
116
+ for (int j = 0; j < 4; j++)
117
+ {
118
+ wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
119
+
120
+ for (int k = 0; k < 4; k++)
121
+ {
122
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
123
+ }
124
+ }
125
+
126
+ #pragma unroll
127
+ for (int j = 0; j < 4; j++)
128
+ {
129
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
130
+ {
131
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
132
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
133
+ reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
134
+ reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
135
+ }
136
+
137
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
138
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
139
+ }
140
+
141
+ __syncthreads();
142
+
143
+ #pragma unroll
144
+ for (int i = 0; i < n; i++)
145
+ {
146
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
147
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
148
+ out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
149
+ }
150
+
151
+ __syncthreads();
152
+ }
153
+ }
154
+
155
+ __global__ void butterfly_cuda_kernel_32(
156
+ const __half2 *__restrict__ x,
157
+ const __half2 *__restrict__ x_gate,
158
+ const complex_half_t *__restrict__ d_f,
159
+ const __half2 *__restrict__ twiddle_factors_real,
160
+ const __half2 *__restrict__ twiddle_factors_imag,
161
+ __half2 *__restrict__ out_real,
162
+ __half2 *__restrict__ out_imag,
163
+ uint B,
164
+ uint H,
165
+ int N)
166
+ {
167
+ const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
168
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
169
+ int idx;
170
+
171
+ int shared_offset;
172
+ const int B_Y = blockDim.y;
173
+ const int n = N / B_Y;
174
+
175
+
176
+ __shared__ half x_shared[32 * 64];
177
+ __shared__ half d_f_real[32 * 32];
178
+ __shared__ half d_f_imag[32 * 32];
179
+ __shared__ half twiddles_real_shared[32 * 64];
180
+ __shared__ half twiddles_imag_shared[32 * 64];
181
+ __shared__ half out_real_shared[32 * 64];
182
+ __shared__ half out_imag_shared[32 * 64];
183
+
184
+ // #pragma unroll
185
+ for (int i = 0; i < n; i++)
186
+ {
187
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
188
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
189
+ if(x_gate == nullptr){
190
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
191
+ }else{
192
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
193
+ }
194
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
195
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
196
+
197
+ // #pragma unroll
198
+ d_f_real[shared_offset] = d_f[shared_offset].real();
199
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
200
+ }
201
+
202
+ __syncthreads();
203
+
204
+ if (threadIdx.y < N / 16)
205
+ {
206
+ __half2 tmp_real, tmp_imag;
207
+
208
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
209
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
210
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
211
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
212
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[2][2];
213
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
214
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[2][2];
215
+
216
+ int t = threadIdx.y * 32;
217
+
218
+ for (int i = 0; i < 2; i++)
219
+ {
220
+ for (int j = 0; j < 2; j++)
221
+ {
222
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
223
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
224
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
225
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
226
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
227
+ }
228
+ }
229
+
230
+ #pragma unroll
231
+ for (int i = 0; i < 2; i++)
232
+ {
233
+ for (int j = 0; j < 2; j++)
234
+ {
235
+ wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
236
+
237
+ for (int k = 0; k < 2; k++)
238
+ {
239
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
240
+ }
241
+ }
242
+ }
243
+
244
+ #pragma unroll
245
+ for (int i = 0; i < 2; i++)
246
+ {
247
+ for (int j = 0; j < 2; j++)
248
+ {
249
+ wmma::fill_fragment(acc_frag_imag[i][j], __float2half(0.0f));
250
+
251
+ for (int k = 0; k < 2; k++)
252
+ {
253
+ wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
254
+ }
255
+ }
256
+ }
257
+
258
+ #pragma unroll
259
+ for (int i = 0; i < 2; i++)
260
+ {
261
+ for (int j = 0; j < 2; j++)
262
+ {
263
+ for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
264
+ {
265
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k];
266
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k];
267
+ reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]));
268
+ reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]));
269
+ }
270
+
271
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
272
+ wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
273
+ }
274
+ }
275
+ }
276
+
277
+ __syncthreads();
278
+
279
+ #pragma unroll
280
+ for (int i = 0; i < n; i++)
281
+ {
282
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
283
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
284
+ out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
285
+ }
286
+ }
287
+
288
+ __global__ void butterfly_cuda_kernel_128(
289
+ const __half2 *__restrict__ x,
290
+ const __half2 *__restrict__ x_gate,
291
+ const complex_half_t *__restrict__ d_f,
292
+ const __half2 *__restrict__ twiddle_factors_real,
293
+ const __half2 *__restrict__ twiddle_factors_imag,
294
+ __half2 *__restrict__ out_real,
295
+ __half2 *__restrict__ out_imag,
296
+ uint B,
297
+ uint H,
298
+ int N)
299
+ {
300
+ const int offset = blockIdx.y * H * 128 * 32 * gridDim.x * 2 + blockIdx.z * 16 * 128 * 32 * gridDim.x * 2 + blockIdx.x * 64 + threadIdx.x;
301
+ const int tw_offset = blockIdx.x * 64 + threadIdx.x;
302
+ int idx;
303
+
304
+ int shared_offset;
305
+ const int B_Y = blockDim.y;
306
+ const int n = N / B_Y;
307
+
308
+
309
+ extern __shared__ half shared_real[];
310
+ half *shared_imag = &shared_real[128 * 128];
311
+
312
+
313
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[8];
314
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
315
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
316
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[8];
317
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[8][8];
318
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
319
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[8];
320
+
321
+ for (int i = 0; i < n; i++)
322
+ {
323
+ for(int j=0; j< 4; j++){
324
+ shared_offset = (threadIdx.y + i * B_Y) * 128 + threadIdx.x + j * blockDim.x;
325
+ shared_real[shared_offset] = d_f[shared_offset].real();
326
+ shared_imag[shared_offset] = d_f[shared_offset].imag();
327
+ }
328
+ }
329
+
330
+ __syncthreads();
331
+
332
+
333
+ for (int i = 0; i < 8; i++){
334
+ wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
335
+ wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
336
+ }
337
+
338
+
339
+ __syncthreads();
340
+
341
+
342
+
343
+ for (int i = 0; i < n; i++)
344
+ {
345
+ for(int j=0; j< 2; j++){
346
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
347
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
348
+ reinterpret_cast<__half2*>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
349
+ reinterpret_cast<__half2*>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
350
+ }
351
+ }
352
+
353
+ __syncthreads();
354
+
355
+
356
+ for (int i = 0; i < 8; i++){
357
+ wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
358
+ wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
359
+ }
360
+
361
+ __syncthreads();
362
+
363
+
364
+ for(int t=0; t< 16; t++){
365
+ for (int i = 0; i < n; i++)
366
+ {
367
+ for(int j=0; j< 2; j++){
368
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
369
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
370
+ if(x_gate != nullptr){
371
+ reinterpret_cast<__half2*>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
372
+ }else{
373
+ reinterpret_cast<__half2*>(shared_real)[shared_offset] = x[offset + idx];
374
+ }
375
+
376
+ }
377
+ }
378
+
379
+
380
+ __syncthreads();
381
+
382
+
383
+ for (int i = 0; i < 8; i++)
384
+ {
385
+ for (int j = 0; j < 8; j++)
386
+ {
387
+ wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
388
+ }
389
+ }
390
+
391
+ __syncthreads();
392
+
393
+ #pragma unroll
394
+ for (int j = 0; j < 8; j++)
395
+ {
396
+ wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
397
+
398
+ for (int k = 0; k < 8; k++)
399
+ {
400
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
401
+ }
402
+ }
403
+
404
+ #pragma unroll
405
+
406
+ for (int j = 0; j < 8; j++)
407
+ {
408
+ wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
409
+
410
+ for (int k = 0; k < 8; k++)
411
+ {
412
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
413
+ }
414
+ }
415
+
416
+ __half2 tmp_real, tmp_imag;
417
+ #pragma unroll
418
+ for (int j = 0; j < 8; j++)
419
+ {
420
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
421
+ {
422
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
423
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
424
+ reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
425
+ reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
426
+ }
427
+
428
+ wmma::store_matrix_sync(shared_real + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
429
+ wmma::store_matrix_sync(shared_imag + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
430
+ }
431
+
432
+ __syncthreads();
433
+
434
+ #pragma unroll
435
+ for (int i = 0; i < n; i++)
436
+ {
437
+ for(int j=0; j< 2; j++){
438
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
439
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
440
+ out_real[offset + idx] = reinterpret_cast<__half2*>(shared_real)[shared_offset];
441
+ out_imag[offset + idx] = reinterpret_cast<__half2*>(shared_imag)[shared_offset];
442
+ }
443
+ }
444
+
445
+ __syncthreads();
446
+ }
447
+ }
448
+
449
+
450
+ __global__ void butterfly_cuda_kernel_16(
451
+ const __half2 *__restrict__ x,
452
+ const __half2 *__restrict__ x_gate,
453
+ const complex_half_t *__restrict__ d_f,
454
+ const __half2 *__restrict__ twiddle_factors_real,
455
+ const __half2 *__restrict__ twiddle_factors_imag,
456
+ __half2 *__restrict__ out_real,
457
+ __half2 *__restrict__ out_imag,
458
+ uint B,
459
+ uint H,
460
+ int N)
461
+ {
462
+ const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
463
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
464
+ int idx;
465
+
466
+ int shared_offset;
467
+ const int B_Y = blockDim.y;
468
+ const int n = N / B_Y;
469
+
470
+
471
+ __shared__ half x_shared[16 * 64];
472
+ __shared__ half d_f_real[16 * 16];
473
+ __shared__ half d_f_imag[16 * 16];
474
+ __shared__ half twiddles_real_shared[16 * 64];
475
+ __shared__ half twiddles_imag_shared[16 * 64];
476
+ __shared__ half out_real_shared[16 * 64];
477
+ __shared__ half out_imag_shared[16 * 64];
478
+
479
+ // #pragma unroll
480
+ for (int i = 0; i < n; i++)
481
+ {
482
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
483
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
484
+
485
+ if(x_gate != NULL)
486
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
487
+ else
488
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
489
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
490
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
491
+
492
+ // #pragma unroll
493
+
494
+ if(threadIdx.x < 16 ){
495
+ shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
496
+ d_f_real[shared_offset] = d_f[shared_offset].real();
497
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
498
+ }
499
+ }
500
+
501
+ __syncthreads();
502
+
503
+ if (threadIdx.y < 4)
504
+ {
505
+ __half2 tmp_real, tmp_imag;
506
+
507
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
508
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real;
509
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
510
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
511
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
512
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
513
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag;
514
+
515
+ wmma::load_matrix_sync(a_frag_real, d_f_real, N);
516
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
517
+ wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
518
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
519
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
520
+
521
+
522
+ wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
523
+
524
+
525
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
526
+
527
+
528
+ wmma::fill_fragment(acc_frag_imag, __float2half(0.0f));
529
+
530
+
531
+ wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
532
+
533
+
534
+
535
+ for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
536
+ {
537
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real.x)[k];
538
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag.x)[k];
539
+ reinterpret_cast<__half2 *>(acc_frag_real.x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]));
540
+ reinterpret_cast<__half2 *>(acc_frag_imag.x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]));
541
+ }
542
+
543
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
544
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
545
+ }
546
+
547
+ __syncthreads();
548
+
549
+ #pragma unroll
550
+ for (int i = 0; i < n; i++)
551
+ {
552
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
553
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
554
+ out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
555
+ }
556
+ }
557
+
558
+
559
+ std::vector<torch::Tensor> butterfly_cuda(
560
+ torch::Tensor x,
561
+ torch::Tensor d_f,
562
+ torch::Tensor twiddle_factors_real,
563
+ torch::Tensor twiddle_factors_imag,
564
+ std::optional<at::Tensor> x_gate = std::nullopt)
565
+ {
566
+
567
+ uint B = x.size(0);
568
+ uint H = x.size(1);
569
+ // uint m = x.size(1);
570
+
571
+ // const int TILE_SIZE = 16;
572
+ uint N = x.size(2);
573
+ uint M = x.size(3);
574
+ dim3 gridDim;
575
+ dim3 blockDim;
576
+
577
+ gridDim.y = B;
578
+ gridDim.z = H;
579
+
580
+ torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
581
+ torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());
582
+
583
+ //set blockDims
584
+ switch(N){
585
+ case 128:
586
+ blockDim.x = 32;
587
+ blockDim.y = 8;
588
+ break;
589
+ default:
590
+ blockDim.x = 32;
591
+ blockDim.y = 4;
592
+ break;
593
+ }
594
+
595
+ //set gridDim.x
596
+ switch(N){
597
+ case 128:
598
+ switch (M){
599
+ case 16384:
600
+ gridDim.x = 128;
601
+ break;
602
+ case 8192:
603
+ gridDim.x = 64;
604
+ break;
605
+ case 4096:
606
+ gridDim.x = 32;
607
+ break;
608
+ default:
609
+ gridDim.x = 256;
610
+ break;
611
+ }
612
+ break;
613
+ default:
614
+ switch (M){
615
+ case 16384:
616
+ gridDim.x = 256;
617
+ break;
618
+ case 8192:
619
+ gridDim.x = 128;
620
+ break;
621
+ case 4096:
622
+ gridDim.x = 64;
623
+ break;
624
+ default:
625
+ gridDim.x = 512;
626
+ break;
627
+ }
628
+ break;
629
+ }
630
+
631
+ switch (N)
632
+ {
633
+ case 16:
634
+ butterfly_cuda_kernel_16<<<gridDim, blockDim>>>(
635
+ static_cast<__half2 *>(x.data_ptr()),
636
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
637
+ static_cast<complex_half_t *>(d_f.data_ptr()),
638
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
639
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
640
+ static_cast<__half2 *>(out_real.data_ptr()),
641
+ static_cast<__half2 *>(out_imag.data_ptr()),
642
+ B,
643
+ H,
644
+ N);
645
+ break;
646
+ case 32:
647
+ butterfly_cuda_kernel_32<<<gridDim, blockDim>>>(
648
+ static_cast<__half2 *>(x.data_ptr()),
649
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
650
+ static_cast<complex_half_t *>(d_f.data_ptr()),
651
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
652
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
653
+ static_cast<__half2 *>(out_real.data_ptr()),
654
+ static_cast<__half2 *>(out_imag.data_ptr()),
655
+ B,
656
+ H,
657
+ N);
658
+ break;
659
+
660
+ case 64:
661
+ gridDim.z = H / 16;
662
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
663
+
664
+ butterfly_cuda_kernel_64<<<gridDim, blockDim, 57344>>>(
665
+ static_cast<__half2 *>(x.data_ptr()),
666
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
667
+ static_cast<complex_half_t *>(d_f.data_ptr()),
668
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
669
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
670
+ static_cast<__half2 *>(out_real.data_ptr()),
671
+ static_cast<__half2 *>(out_imag.data_ptr()),
672
+ B,
673
+ H,
674
+ N);
675
+ break;
676
+ case 128:
677
+ gridDim.z = H / 16;
678
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
679
+
680
+ butterfly_cuda_kernel_128<<<gridDim, blockDim, 65536>>>(
681
+ static_cast<__half2 *>(x.data_ptr()),
682
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
683
+ static_cast<complex_half_t *>(d_f.data_ptr()),
684
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
685
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
686
+ static_cast<__half2 *>(out_real.data_ptr()),
687
+ static_cast<__half2 *>(out_imag.data_ptr()),
688
+ B,
689
+ H,
690
+ N);
691
+ break;
692
+
693
+ default:
694
+ printf("Not yet implemented \n");
695
+ break;
696
+ }
697
+
698
+ return {out_real, out_imag};
699
  }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda_bf16.cu CHANGED
@@ -1,725 +1,725 @@
1
- // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
-
3
- #include <torch/extension.h>
4
-
5
- #include <vector>
6
- #include <stdio.h>
7
- #include <mma.h>
8
- #include <cuda_runtime.h>
9
- #include <cuda_fp16.h>
10
- #include <cuda_bf16.h>
11
- #include "shared.h"
12
-
13
- using namespace nvcuda;
14
-
15
- __global__ void butterfly_cuda_kernel_64(
16
- const __nv_bfloat162 *__restrict__ x,
17
- const __nv_bfloat162 *__restrict__ x_gate,
18
- const __nv_bfloat162 *__restrict__ d_f_real,
19
- const __nv_bfloat162 *__restrict__ d_f_imag,
20
- const __nv_bfloat162 *__restrict__ twiddle_factors_real,
21
- const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
22
- __nv_bfloat162 *__restrict__ out_real,
23
- __nv_bfloat162 *__restrict__ out_imag,
24
- uint B,
25
- uint H,
26
- int N)
27
- {
28
- const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
29
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
30
- int idx;
31
- int shared_offset;
32
- const int B_Y = blockDim.y;
33
- const int n = N / B_Y;
34
-
35
-
36
- extern __shared__ __nv_bfloat16 x_shared[];
37
- __nv_bfloat16 *d_f_real_shared = &x_shared[N * N];
38
- __nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
39
- __nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
40
- __nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
41
- float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
42
- float *out_imag_shared = &out_real_shared[N * N];
43
-
44
- // #pragma unroll
45
- for (int i = 0; i < n; i++)
46
- {
47
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
48
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
49
- reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
50
- reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
51
-
52
- // #pragma unroll
53
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
54
- reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
55
- reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
56
- }
57
-
58
- float2 tmp_real, tmp_imag;
59
-
60
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4];
61
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
62
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
63
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4];
64
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[4][4];
65
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];
66
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[4];
67
-
68
- __syncthreads();
69
-
70
- for (int i = 0; i < 4; i++)
71
- {
72
- wmma::load_matrix_sync(a_frag_real[i], d_f_real_shared + i * N * 16 + threadIdx.y * 16, N);
73
- wmma::load_matrix_sync(a_frag_imag[i], d_f_imag_shared + i * N * 16 + threadIdx.y * 16, N);
74
- wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
75
- wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
76
- }
77
-
78
- for (int t = 0; t < 16; t++)
79
- {
80
-
81
- for (int i = 0; i < n; i++)
82
- {
83
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
84
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
85
- if(x_gate != nullptr){
86
- reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
87
- }else{
88
- reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
89
- }
90
- }
91
-
92
- __syncthreads();
93
-
94
- for (int i = 0; i < 4; i++)
95
- {
96
- for (int j = 0; j < 4; j++)
97
- {
98
- wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
99
- }
100
- }
101
-
102
- #pragma unroll
103
- for (int j = 0; j < 4; j++)
104
- {
105
- wmma::fill_fragment(acc_frag_real[j], 0.0f);
106
-
107
- for (int k = 0; k < 4; k++)
108
- {
109
- wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
110
- }
111
- }
112
-
113
- #pragma unroll
114
-
115
- for (int j = 0; j < 4; j++)
116
- {
117
- wmma::fill_fragment(acc_frag_imag[j], 0.0f);
118
-
119
- for (int k = 0; k < 4; k++)
120
- {
121
- wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
122
- }
123
- }
124
-
125
- #pragma unroll
126
- for (int j = 0; j < 4; j++)
127
- {
128
- for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
129
- {
130
- tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
131
- tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
132
-
133
- reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
134
- reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
135
- }
136
-
137
- wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
138
- wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
139
- }
140
-
141
- __syncthreads();
142
-
143
- #pragma unroll
144
- for (int i = 0; i < n; i++)
145
- {
146
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
147
- out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
148
- out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
149
- }
150
-
151
- __syncthreads();
152
- }
153
- }
154
-
155
- __global__ void butterfly_cuda_kernel_32(
156
- const __nv_bfloat162 *__restrict__ x,
157
- const __nv_bfloat162 *__restrict__ x_gate,
158
- const __nv_bfloat16 *__restrict__ d_f_real,
159
- const __nv_bfloat16 *__restrict__ d_f_imag,
160
- const __nv_bfloat162 *__restrict__ twiddle_factors_real,
161
- const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
162
- __nv_bfloat162 *__restrict__ out_real,
163
- __nv_bfloat162 *__restrict__ out_imag,
164
- uint B,
165
- uint H,
166
- int N)
167
- {
168
- const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
169
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
170
- int idx;
171
-
172
- int shared_offset;
173
- const int B_Y = blockDim.y;
174
- const int n = N / B_Y;
175
-
176
-
177
- __shared__ __nv_bfloat16 x_shared[32 * 64];
178
- __shared__ __nv_bfloat16 d_f_real_shared[32 * 32];
179
- __shared__ __nv_bfloat16 d_f_imag_shared[32 * 32];
180
- __shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
181
- __shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
182
- __shared__ float out_real_shared[32 * 64];
183
- __shared__ float out_imag_shared[32 * 64];
184
-
185
- // #pragma unroll
186
- for (int i = 0; i < n; i++)
187
- {
188
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
189
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
190
- if(x_gate != nullptr){
191
- reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
192
- }else{
193
- reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
194
- }
195
- reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
196
- reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
197
-
198
- // #pragma unroll
199
- d_f_real_shared[shared_offset] = d_f_real[shared_offset];
200
- d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
201
- }
202
-
203
- __syncthreads();
204
-
205
- if (threadIdx.y < N / 16)
206
- {
207
- float2 tmp_real, tmp_imag;
208
-
209
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
210
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
211
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
212
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
213
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[2][2];
214
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
215
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[2][2];
216
-
217
- int t = threadIdx.y * 32;
218
-
219
- for (int i = 0; i < 2; i++)
220
- {
221
- for (int j = 0; j < 2; j++)
222
- {
223
- wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
224
- wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
225
- wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
226
- wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
227
- wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
228
- }
229
- }
230
-
231
- #pragma unroll
232
- for (int i = 0; i < 2; i++)
233
- {
234
- for (int j = 0; j < 2; j++)
235
- {
236
- wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
237
-
238
- for (int k = 0; k < 2; k++)
239
- {
240
- wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
241
- }
242
- }
243
- }
244
-
245
- #pragma unroll
246
- for (int i = 0; i < 2; i++)
247
- {
248
- for (int j = 0; j < 2; j++)
249
- {
250
- wmma::fill_fragment(acc_frag_imag[i][j], 0.0f);
251
-
252
- for (int k = 0; k < 2; k++)
253
- {
254
- wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
255
- }
256
- }
257
- }
258
-
259
- #pragma unroll
260
- for (int i = 0; i < 2; i++)
261
- {
262
- for (int j = 0; j < 2; j++)
263
- {
264
- for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
265
- {
266
- tmp_real = reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k];
267
- tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k];
268
- reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]);
269
- reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]);
270
- }
271
- wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
272
- wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
273
- }
274
- }
275
- }
276
-
277
- __syncthreads();
278
-
279
- #pragma unroll
280
- for (int i = 0; i < n; i++)
281
- {
282
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
283
- out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
284
- out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
285
- }
286
- }
287
-
288
- __global__ void butterfly_cuda_kernel_128(
289
- const __nv_bfloat162 *__restrict__ x,
290
- const __nv_bfloat162 *__restrict__ x_gate,
291
- const __nv_bfloat162 *__restrict__ d_f_real,
292
- const __nv_bfloat162 *__restrict__ d_f_imag,
293
- const __nv_bfloat162 *__restrict__ twiddle_factors_real,
294
- const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
295
- __nv_bfloat162 *__restrict__ out_real,
296
- __nv_bfloat162 *__restrict__ out_imag,
297
- uint B,
298
- uint H,
299
- int N)
300
- {
301
- const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
302
- const int tw_offset = blockIdx.x * 64 + threadIdx.x;
303
- int idx;
304
-
305
- int shared_offset;
306
- const int B_Y = blockDim.y;
307
- const int n = N / B_Y;
308
-
309
-
310
- extern __shared__ __nv_bfloat16 shared_real[];
311
- __nv_bfloat16 *shared_imag = &shared_real[128 * 128];
312
-
313
-
314
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[8];
315
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
316
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
317
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[8];
318
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[8][8];
319
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
320
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[8];
321
-
322
- for (int i = 0; i < n; i++)
323
- {
324
- for(int j=0; j< 2; j++){
325
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
326
- reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = d_f_real[shared_offset];
327
- reinterpret_cast<__nv_bfloat162 *>(shared_imag)[shared_offset] = d_f_imag[shared_offset];
328
- }
329
- }
330
-
331
- __syncthreads();
332
-
333
-
334
- for (int i = 0; i < 8; i++){
335
- wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
336
- wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
337
- }
338
-
339
-
340
- __syncthreads();
341
-
342
-
343
-
344
- for (int i = 0; i < n; i++)
345
- {
346
- for(int j=0; j< 2; j++){
347
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
348
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
349
- reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
350
- reinterpret_cast<__nv_bfloat162*>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
351
- }
352
- }
353
-
354
- __syncthreads();
355
-
356
-
357
- for (int i = 0; i < 8; i++){
358
- wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
359
- wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
360
- }
361
-
362
- __syncthreads();
363
-
364
-
365
- for(int t=0; t< 16; t++){
366
- for (int i = 0; i < n; i++)
367
- {
368
- for(int j=0; j< 2; j++){
369
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
370
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
371
- if(x_gate != nullptr){
372
- reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
373
- }else{
374
- reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = x[offset + idx];
375
- }
376
- }
377
- }
378
-
379
-
380
- __syncthreads();
381
-
382
-
383
- for (int i = 0; i < 8; i++)
384
- {
385
- for (int j = 0; j < 8; j++)
386
- {
387
- wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
388
- }
389
- }
390
-
391
- __syncthreads();
392
-
393
- #pragma unroll
394
- for (int j = 0; j < 8; j++)
395
- {
396
- wmma::fill_fragment(acc_frag_real[j], 0.0f);
397
-
398
- for (int k = 0; k < 8; k++)
399
- {
400
- wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
401
- }
402
- }
403
-
404
- #pragma unroll
405
-
406
- for (int j = 0; j < 8; j++)
407
- {
408
- wmma::fill_fragment(acc_frag_imag[j], 0.0f);
409
-
410
- for (int k = 0; k < 8; k++)
411
- {
412
- wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
413
- }
414
- }
415
-
416
- float2 tmp_real, tmp_imag;
417
- #pragma unroll
418
- for (int j = 0; j < 8; j++)
419
- {
420
- for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
421
- {
422
- tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
423
- tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
424
-
425
- reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
426
- reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
427
- }
428
- }
429
-
430
- for (int j = 0; j < 8; j++)
431
- {
432
- wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
433
- }
434
-
435
- __syncthreads();
436
-
437
- #pragma unroll
438
- for (int i = 0; i < n; i++)
439
- {
440
- for(int j=0; j< 2; j++){
441
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
442
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
443
- out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
444
- }
445
- }
446
-
447
- __syncthreads();
448
-
449
-
450
- for (int j = 0; j < 8; j++)
451
- {
452
- wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
453
- }
454
-
455
- __syncthreads();
456
-
457
- #pragma unroll
458
- for (int i = 0; i < n; i++)
459
- {
460
- for(int j=0; j< 2; j++){
461
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
462
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
463
- out_imag[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
464
- }
465
- }
466
- }
467
- }
468
-
469
-
470
- __global__ void butterfly_cuda_kernel_16(
471
- const __nv_bfloat162 *__restrict__ x,
472
- const __nv_bfloat162 *__restrict__ x_gate,
473
- const __nv_bfloat16 *__restrict__ d_f_real,
474
- const __nv_bfloat16 *__restrict__ d_f_imag,
475
- const __nv_bfloat162 *__restrict__ twiddle_factors_real,
476
- const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
477
- __nv_bfloat162 *__restrict__ out_real,
478
- __nv_bfloat162 *__restrict__ out_imag,
479
- uint B,
480
- uint H,
481
- int N)
482
- {
483
- const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
484
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
485
- int idx;
486
-
487
- int shared_offset;
488
- const int B_Y = blockDim.y;
489
- const int n = N / B_Y;
490
-
491
-
492
- __shared__ __nv_bfloat16 x_shared[16 * 64];
493
- __shared__ __nv_bfloat16 d_f_real_shared[16 * 16];
494
- __shared__ __nv_bfloat16 d_f_imag_shared[16 * 16];
495
- __shared__ __nv_bfloat16 twiddles_real_shared[16 * 64];
496
- __shared__ __nv_bfloat16 twiddles_imag_shared[16 * 64];
497
- __shared__ float out_real_shared[16 * 64];
498
- __shared__ float out_imag_shared[16 * 64];
499
-
500
- // #pragma unroll
501
- for (int i = 0; i < n; i++)
502
- {
503
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
504
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
505
- if(x_gate != nullptr){
506
- reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
507
- }else{
508
- reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
509
- }
510
- reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
511
- reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
512
-
513
- // #pragma unroll
514
- if(threadIdx.x < 16 ){
515
- shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
516
- d_f_real_shared[shared_offset] = d_f_real[shared_offset];
517
- d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
518
- }
519
- }
520
-
521
- __syncthreads();
522
-
523
- if (threadIdx.y < 4)
524
- {
525
- float2 tmp_real, tmp_imag;
526
-
527
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
528
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
529
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
530
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
531
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag;
532
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
533
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag;
534
-
535
- wmma::load_matrix_sync(a_frag_real, d_f_real_shared, N);
536
- wmma::load_matrix_sync(a_frag_imag, d_f_imag_shared, N);
537
- wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
538
- wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
539
- wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
540
-
541
-
542
-
543
- wmma::fill_fragment(acc_frag_real, 0.0f);
544
-
545
-
546
- wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
547
-
548
-
549
-
550
- wmma::fill_fragment(acc_frag_imag, 0.0f);
551
-
552
-
553
- wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
554
-
555
-
556
- #pragma unroll
557
- for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
558
- {
559
- tmp_real = reinterpret_cast<float2 *>(acc_frag_real.x)[k];
560
- tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag.x)[k];
561
- reinterpret_cast<float2 *>(acc_frag_real.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]);
562
- reinterpret_cast<float2 *>(acc_frag_imag.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]);
563
- }
564
- wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
565
- wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
566
-
567
- }
568
- __syncthreads();
569
-
570
- #pragma unroll
571
- for (int i = 0; i < n; i++)
572
- {
573
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
574
- out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
575
- out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
576
- }
577
- }
578
-
579
- std::vector<torch::Tensor> butterfly_bf16_cuda(
580
- torch::Tensor x,
581
- torch::Tensor d_f_real,
582
- torch::Tensor d_f_imag,
583
- torch::Tensor twiddle_factors_real,
584
- torch::Tensor twiddle_factors_imag,
585
- std::optional<at::Tensor> x_gate = std::nullopt
586
- )
587
- {
588
-
589
- uint B = x.size(0);
590
- uint H = x.size(1);
591
- // uint m = x.size(1);
592
-
593
- // const int TILE_SIZE = 16;
594
- uint N = x.size(2);
595
- uint M = x.size(3);
596
- dim3 gridDim;
597
- dim3 blockDim;
598
-
599
- gridDim.y = B;
600
- gridDim.z = H;
601
-
602
- torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
603
- torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());
604
-
605
- //set blockDims
606
- switch(N){
607
- case 128:
608
- blockDim.x = 32;
609
- blockDim.y = 8;
610
- break;
611
- default:
612
- blockDim.x = 32;
613
- blockDim.y = 4;
614
- break;
615
- }
616
-
617
- //set gridDim.x
618
- switch(N){
619
- case 128:
620
- switch (M){
621
- case 16384:
622
- gridDim.x = 128;
623
- break;
624
- case 8192:
625
- gridDim.x = 64;
626
- break;
627
- case 4096:
628
- gridDim.x = 32;
629
- break;
630
- default:
631
- gridDim.x = 256;
632
- break;
633
- }
634
- break;
635
- default:
636
- switch (M){
637
- case 16384:
638
- gridDim.x = 256;
639
- break;
640
- case 8192:
641
- gridDim.x = 128;
642
- break;
643
- case 4096:
644
- gridDim.x = 64;
645
- break;
646
- default:
647
- gridDim.x = 512;
648
- break;
649
- }
650
- break;
651
- }
652
-
653
- switch (N)
654
- {
655
- case 16:
656
- butterfly_cuda_kernel_16<<<gridDim, blockDim>>>(
657
- static_cast<__nv_bfloat162 *>(x.data_ptr()),
658
- x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
659
- static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
660
- static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
661
- static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
662
- static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
663
- static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
664
- static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
665
- B,
666
- H,
667
- N);
668
- break;
669
- case 32:
670
- butterfly_cuda_kernel_32<<<gridDim, blockDim>>>(
671
- static_cast<__nv_bfloat162 *>(x.data_ptr()),
672
- x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
673
- static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
674
- static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
675
- static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
676
- static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
677
- static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
678
- static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
679
- B,
680
- H,
681
- N);
682
- break;
683
-
684
- case 64:
685
- gridDim.z = H / 16;
686
- cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
687
-
688
- butterfly_cuda_kernel_64<<<gridDim, blockDim, 78000>>>(
689
- static_cast<__nv_bfloat162 *>(x.data_ptr()),
690
- x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
691
- static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
692
- static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
693
- static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
694
- static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
695
- static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
696
- static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
697
- B,
698
- H,
699
- N);
700
- break;
701
- case 128:
702
- gridDim.z = H / 16;
703
- cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
704
-
705
- butterfly_cuda_kernel_128<<<gridDim, blockDim, 65536>>>(
706
- static_cast<__nv_bfloat162 *>(x.data_ptr()),
707
- x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
708
- static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
709
- static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
710
- static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
711
- static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
712
- static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
713
- static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
714
- B,
715
- H,
716
- N);
717
- break;
718
-
719
- default:
720
- printf("Not yet implemented \n");
721
- break;
722
- }
723
-
724
- return {out_real, out_imag};
725
  }
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_runtime.h>
9
+ #include <cuda_fp16.h>
10
+ #include <cuda_bf16.h>
11
+ #include "shared.h"
12
+
13
+ using namespace nvcuda;
14
+
15
+ __global__ void butterfly_cuda_kernel_64(
16
+ const __nv_bfloat162 *__restrict__ x,
17
+ const __nv_bfloat162 *__restrict__ x_gate,
18
+ const __nv_bfloat162 *__restrict__ d_f_real,
19
+ const __nv_bfloat162 *__restrict__ d_f_imag,
20
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
21
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
22
+ __nv_bfloat162 *__restrict__ out_real,
23
+ __nv_bfloat162 *__restrict__ out_imag,
24
+ uint B,
25
+ uint H,
26
+ int N)
27
+ {
28
+ const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
29
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
30
+ int idx;
31
+ int shared_offset;
32
+ const int B_Y = blockDim.y;
33
+ const int n = N / B_Y;
34
+
35
+
36
+ extern __shared__ __nv_bfloat16 x_shared[];
37
+ __nv_bfloat16 *d_f_real_shared = &x_shared[N * N];
38
+ __nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
39
+ __nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
40
+ __nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
41
+ float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
42
+ float *out_imag_shared = &out_real_shared[N * N];
43
+
44
+ // #pragma unroll
45
+ for (int i = 0; i < n; i++)
46
+ {
47
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
48
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
49
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
50
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
51
+
52
+ // #pragma unroll
53
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
54
+ reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
55
+ reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
56
+ }
57
+
58
+ float2 tmp_real, tmp_imag;
59
+
60
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4];
61
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
62
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
63
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4];
64
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[4][4];
65
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];
66
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[4];
67
+
68
+ __syncthreads();
69
+
70
+ for (int i = 0; i < 4; i++)
71
+ {
72
+ wmma::load_matrix_sync(a_frag_real[i], d_f_real_shared + i * N * 16 + threadIdx.y * 16, N);
73
+ wmma::load_matrix_sync(a_frag_imag[i], d_f_imag_shared + i * N * 16 + threadIdx.y * 16, N);
74
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
75
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
76
+ }
77
+
78
+ for (int t = 0; t < 16; t++)
79
+ {
80
+
81
+ for (int i = 0; i < n; i++)
82
+ {
83
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
84
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
85
+ if(x_gate != nullptr){
86
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
87
+ }else{
88
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
89
+ }
90
+ }
91
+
92
+ __syncthreads();
93
+
94
+ for (int i = 0; i < 4; i++)
95
+ {
96
+ for (int j = 0; j < 4; j++)
97
+ {
98
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
99
+ }
100
+ }
101
+
102
+ #pragma unroll
103
+ for (int j = 0; j < 4; j++)
104
+ {
105
+ wmma::fill_fragment(acc_frag_real[j], 0.0f);
106
+
107
+ for (int k = 0; k < 4; k++)
108
+ {
109
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
110
+ }
111
+ }
112
+
113
+ #pragma unroll
114
+
115
+ for (int j = 0; j < 4; j++)
116
+ {
117
+ wmma::fill_fragment(acc_frag_imag[j], 0.0f);
118
+
119
+ for (int k = 0; k < 4; k++)
120
+ {
121
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
122
+ }
123
+ }
124
+
125
+ #pragma unroll
126
+ for (int j = 0; j < 4; j++)
127
+ {
128
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
129
+ {
130
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
131
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
132
+
133
+ reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
134
+ reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
135
+ }
136
+
137
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
138
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
139
+ }
140
+
141
+ __syncthreads();
142
+
143
+ #pragma unroll
144
+ for (int i = 0; i < n; i++)
145
+ {
146
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
147
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
148
+ out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
149
+ }
150
+
151
+ __syncthreads();
152
+ }
153
+ }
154
+
155
+ __global__ void butterfly_cuda_kernel_32(
156
+ const __nv_bfloat162 *__restrict__ x,
157
+ const __nv_bfloat162 *__restrict__ x_gate,
158
+ const __nv_bfloat16 *__restrict__ d_f_real,
159
+ const __nv_bfloat16 *__restrict__ d_f_imag,
160
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
161
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
162
+ __nv_bfloat162 *__restrict__ out_real,
163
+ __nv_bfloat162 *__restrict__ out_imag,
164
+ uint B,
165
+ uint H,
166
+ int N)
167
+ {
168
+ const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
169
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
170
+ int idx;
171
+
172
+ int shared_offset;
173
+ const int B_Y = blockDim.y;
174
+ const int n = N / B_Y;
175
+
176
+
177
+ __shared__ __nv_bfloat16 x_shared[32 * 64];
178
+ __shared__ __nv_bfloat16 d_f_real_shared[32 * 32];
179
+ __shared__ __nv_bfloat16 d_f_imag_shared[32 * 32];
180
+ __shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
181
+ __shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
182
+ __shared__ float out_real_shared[32 * 64];
183
+ __shared__ float out_imag_shared[32 * 64];
184
+
185
+ // #pragma unroll
186
+ for (int i = 0; i < n; i++)
187
+ {
188
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
189
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
190
+ if(x_gate != nullptr){
191
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
192
+ }else{
193
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
194
+ }
195
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
196
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
197
+
198
+ // #pragma unroll
199
+ d_f_real_shared[shared_offset] = d_f_real[shared_offset];
200
+ d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
201
+ }
202
+
203
+ __syncthreads();
204
+
205
+ if (threadIdx.y < N / 16)
206
+ {
207
+ float2 tmp_real, tmp_imag;
208
+
209
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
210
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
211
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
212
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
213
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[2][2];
214
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
215
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[2][2];
216
+
217
+ int t = threadIdx.y * 32;
218
+
219
+ for (int i = 0; i < 2; i++)
220
+ {
221
+ for (int j = 0; j < 2; j++)
222
+ {
223
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
224
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
225
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
226
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
227
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
228
+ }
229
+ }
230
+
231
+ #pragma unroll
232
+ for (int i = 0; i < 2; i++)
233
+ {
234
+ for (int j = 0; j < 2; j++)
235
+ {
236
+ wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
237
+
238
+ for (int k = 0; k < 2; k++)
239
+ {
240
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
241
+ }
242
+ }
243
+ }
244
+
245
+ #pragma unroll
246
+ for (int i = 0; i < 2; i++)
247
+ {
248
+ for (int j = 0; j < 2; j++)
249
+ {
250
+ wmma::fill_fragment(acc_frag_imag[i][j], 0.0f);
251
+
252
+ for (int k = 0; k < 2; k++)
253
+ {
254
+ wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
255
+ }
256
+ }
257
+ }
258
+
259
+ #pragma unroll
260
+ for (int i = 0; i < 2; i++)
261
+ {
262
+ for (int j = 0; j < 2; j++)
263
+ {
264
+ for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
265
+ {
266
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k];
267
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k];
268
+ reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]);
269
+ reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]);
270
+ }
271
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
272
+ wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
273
+ }
274
+ }
275
+ }
276
+
277
+ __syncthreads();
278
+
279
+ #pragma unroll
280
+ for (int i = 0; i < n; i++)
281
+ {
282
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
283
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
284
+ out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
285
+ }
286
+ }
287
+
288
+ __global__ void butterfly_cuda_kernel_128(
289
+ const __nv_bfloat162 *__restrict__ x,
290
+ const __nv_bfloat162 *__restrict__ x_gate,
291
+ const __nv_bfloat162 *__restrict__ d_f_real,
292
+ const __nv_bfloat162 *__restrict__ d_f_imag,
293
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
294
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
295
+ __nv_bfloat162 *__restrict__ out_real,
296
+ __nv_bfloat162 *__restrict__ out_imag,
297
+ uint B,
298
+ uint H,
299
+ int N)
300
+ {
301
+ const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
302
+ const int tw_offset = blockIdx.x * 64 + threadIdx.x;
303
+ int idx;
304
+
305
+ int shared_offset;
306
+ const int B_Y = blockDim.y;
307
+ const int n = N / B_Y;
308
+
309
+
310
+ extern __shared__ __nv_bfloat16 shared_real[];
311
+ __nv_bfloat16 *shared_imag = &shared_real[128 * 128];
312
+
313
+
314
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[8];
315
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
316
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
317
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[8];
318
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[8][8];
319
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
320
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[8];
321
+
322
+ for (int i = 0; i < n; i++)
323
+ {
324
+ for(int j=0; j< 2; j++){
325
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
326
+ reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = d_f_real[shared_offset];
327
+ reinterpret_cast<__nv_bfloat162 *>(shared_imag)[shared_offset] = d_f_imag[shared_offset];
328
+ }
329
+ }
330
+
331
+ __syncthreads();
332
+
333
+
334
+ for (int i = 0; i < 8; i++){
335
+ wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
336
+ wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
337
+ }
338
+
339
+
340
+ __syncthreads();
341
+
342
+
343
+
344
+ for (int i = 0; i < n; i++)
345
+ {
346
+ for(int j=0; j< 2; j++){
347
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
348
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
349
+ reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
350
+ reinterpret_cast<__nv_bfloat162*>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
351
+ }
352
+ }
353
+
354
+ __syncthreads();
355
+
356
+
357
+ for (int i = 0; i < 8; i++){
358
+ wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
359
+ wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
360
+ }
361
+
362
+ __syncthreads();
363
+
364
+
365
+ for(int t=0; t< 16; t++){
366
+ for (int i = 0; i < n; i++)
367
+ {
368
+ for(int j=0; j< 2; j++){
369
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
370
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
371
+ if(x_gate != nullptr){
372
+ reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
373
+ }else{
374
+ reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = x[offset + idx];
375
+ }
376
+ }
377
+ }
378
+
379
+
380
+ __syncthreads();
381
+
382
+
383
+ for (int i = 0; i < 8; i++)
384
+ {
385
+ for (int j = 0; j < 8; j++)
386
+ {
387
+ wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
388
+ }
389
+ }
390
+
391
+ __syncthreads();
392
+
393
+ #pragma unroll
394
+ for (int j = 0; j < 8; j++)
395
+ {
396
+ wmma::fill_fragment(acc_frag_real[j], 0.0f);
397
+
398
+ for (int k = 0; k < 8; k++)
399
+ {
400
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
401
+ }
402
+ }
403
+
404
+ #pragma unroll
405
+
406
+ for (int j = 0; j < 8; j++)
407
+ {
408
+ wmma::fill_fragment(acc_frag_imag[j], 0.0f);
409
+
410
+ for (int k = 0; k < 8; k++)
411
+ {
412
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
413
+ }
414
+ }
415
+
416
+ float2 tmp_real, tmp_imag;
417
+ #pragma unroll
418
+ for (int j = 0; j < 8; j++)
419
+ {
420
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
421
+ {
422
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
423
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
424
+
425
+ reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
426
+ reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
427
+ }
428
+ }
429
+
430
+ for (int j = 0; j < 8; j++)
431
+ {
432
+ wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
433
+ }
434
+
435
+ __syncthreads();
436
+
437
+ #pragma unroll
438
+ for (int i = 0; i < n; i++)
439
+ {
440
+ for(int j=0; j< 2; j++){
441
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
442
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
443
+ out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
444
+ }
445
+ }
446
+
447
+ __syncthreads();
448
+
449
+
450
+ for (int j = 0; j < 8; j++)
451
+ {
452
+ wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
453
+ }
454
+
455
+ __syncthreads();
456
+
457
+ #pragma unroll
458
+ for (int i = 0; i < n; i++)
459
+ {
460
+ for(int j=0; j< 2; j++){
461
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
462
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
463
+ out_imag[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
464
+ }
465
+ }
466
+ }
467
+ }
468
+
469
+
470
+ __global__ void butterfly_cuda_kernel_16(
471
+ const __nv_bfloat162 *__restrict__ x,
472
+ const __nv_bfloat162 *__restrict__ x_gate,
473
+ const __nv_bfloat16 *__restrict__ d_f_real,
474
+ const __nv_bfloat16 *__restrict__ d_f_imag,
475
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
476
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
477
+ __nv_bfloat162 *__restrict__ out_real,
478
+ __nv_bfloat162 *__restrict__ out_imag,
479
+ uint B,
480
+ uint H,
481
+ int N)
482
+ {
483
+ const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
484
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
485
+ int idx;
486
+
487
+ int shared_offset;
488
+ const int B_Y = blockDim.y;
489
+ const int n = N / B_Y;
490
+
491
+
492
+ __shared__ __nv_bfloat16 x_shared[16 * 64];
493
+ __shared__ __nv_bfloat16 d_f_real_shared[16 * 16];
494
+ __shared__ __nv_bfloat16 d_f_imag_shared[16 * 16];
495
+ __shared__ __nv_bfloat16 twiddles_real_shared[16 * 64];
496
+ __shared__ __nv_bfloat16 twiddles_imag_shared[16 * 64];
497
+ __shared__ float out_real_shared[16 * 64];
498
+ __shared__ float out_imag_shared[16 * 64];
499
+
500
+ // #pragma unroll
501
+ for (int i = 0; i < n; i++)
502
+ {
503
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
504
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
505
+ if(x_gate != nullptr){
506
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
507
+ }else{
508
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
509
+ }
510
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
511
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
512
+
513
+ // #pragma unroll
514
+ if(threadIdx.x < 16 ){
515
+ shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
516
+ d_f_real_shared[shared_offset] = d_f_real[shared_offset];
517
+ d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
518
+ }
519
+ }
520
+
521
+ __syncthreads();
522
+
523
+ if (threadIdx.y < 4)
524
+ {
525
+ float2 tmp_real, tmp_imag;
526
+
527
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
528
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
529
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
530
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
531
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag;
532
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
533
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag;
534
+
535
+ wmma::load_matrix_sync(a_frag_real, d_f_real_shared, N);
536
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag_shared, N);
537
+ wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
538
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
539
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
540
+
541
+
542
+
543
+ wmma::fill_fragment(acc_frag_real, 0.0f);
544
+
545
+
546
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
547
+
548
+
549
+
550
+ wmma::fill_fragment(acc_frag_imag, 0.0f);
551
+
552
+
553
+ wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
554
+
555
+
556
+ #pragma unroll
557
+ for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
558
+ {
559
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real.x)[k];
560
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag.x)[k];
561
+ reinterpret_cast<float2 *>(acc_frag_real.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]);
562
+ reinterpret_cast<float2 *>(acc_frag_imag.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]);
563
+ }
564
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
565
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
566
+
567
+ }
568
+ __syncthreads();
569
+
570
+ #pragma unroll
571
+ for (int i = 0; i < n; i++)
572
+ {
573
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
574
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
575
+ out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
576
+ }
577
+ }
578
+
579
+ std::vector<torch::Tensor> butterfly_bf16_cuda(
580
+ torch::Tensor x,
581
+ torch::Tensor d_f_real,
582
+ torch::Tensor d_f_imag,
583
+ torch::Tensor twiddle_factors_real,
584
+ torch::Tensor twiddle_factors_imag,
585
+ std::optional<at::Tensor> x_gate = std::nullopt
586
+ )
587
+ {
588
+
589
+ uint B = x.size(0);
590
+ uint H = x.size(1);
591
+ // uint m = x.size(1);
592
+
593
+ // const int TILE_SIZE = 16;
594
+ uint N = x.size(2);
595
+ uint M = x.size(3);
596
+ dim3 gridDim;
597
+ dim3 blockDim;
598
+
599
+ gridDim.y = B;
600
+ gridDim.z = H;
601
+
602
+ torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
603
+ torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());
604
+
605
+ //set blockDims
606
+ switch(N){
607
+ case 128:
608
+ blockDim.x = 32;
609
+ blockDim.y = 8;
610
+ break;
611
+ default:
612
+ blockDim.x = 32;
613
+ blockDim.y = 4;
614
+ break;
615
+ }
616
+
617
+ //set gridDim.x
618
+ switch(N){
619
+ case 128:
620
+ switch (M){
621
+ case 16384:
622
+ gridDim.x = 128;
623
+ break;
624
+ case 8192:
625
+ gridDim.x = 64;
626
+ break;
627
+ case 4096:
628
+ gridDim.x = 32;
629
+ break;
630
+ default:
631
+ gridDim.x = 256;
632
+ break;
633
+ }
634
+ break;
635
+ default:
636
+ switch (M){
637
+ case 16384:
638
+ gridDim.x = 256;
639
+ break;
640
+ case 8192:
641
+ gridDim.x = 128;
642
+ break;
643
+ case 4096:
644
+ gridDim.x = 64;
645
+ break;
646
+ default:
647
+ gridDim.x = 512;
648
+ break;
649
+ }
650
+ break;
651
+ }
652
+
653
+ switch (N)
654
+ {
655
+ case 16:
656
+ butterfly_cuda_kernel_16<<<gridDim, blockDim>>>(
657
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
658
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
659
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
660
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
661
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
662
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
663
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
664
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
665
+ B,
666
+ H,
667
+ N);
668
+ break;
669
+ case 32:
670
+ butterfly_cuda_kernel_32<<<gridDim, blockDim>>>(
671
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
672
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
673
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
674
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
675
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
676
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
677
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
678
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
679
+ B,
680
+ H,
681
+ N);
682
+ break;
683
+
684
+ case 64:
685
+ gridDim.z = H / 16;
686
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
687
+
688
+ butterfly_cuda_kernel_64<<<gridDim, blockDim, 78000>>>(
689
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
690
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
691
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
692
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
693
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
694
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
695
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
696
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
697
+ B,
698
+ H,
699
+ N);
700
+ break;
701
+ case 128:
702
+ gridDim.z = H / 16;
703
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
704
+
705
+ butterfly_cuda_kernel_128<<<gridDim, blockDim, 65536>>>(
706
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
707
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
708
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
709
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
710
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
711
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
712
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
713
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
714
+ B,
715
+ H,
716
+ N);
717
+ break;
718
+
719
+ default:
720
+ printf("Not yet implemented \n");
721
+ break;
722
+ }
723
+
724
+ return {out_real, out_imag};
725
  }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda.cu CHANGED
@@ -1,723 +1,723 @@
1
- // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
-
3
- #include <torch/extension.h>
4
-
5
- #include <vector>
6
- #include <stdio.h>
7
- #include <mma.h>
8
- #include <cuda_fp16.h>
9
- #include <cuda_bf16.h>
10
- #include "shared.h"
11
-
12
- using namespace nvcuda;
13
-
14
- __global__ void butterfly_ifft_cuda_kernel_64(
15
- const __half2 *__restrict__ x_real,
16
- const __half2 *__restrict__ x_imag,
17
- const complex_half_t *__restrict__ d_f,
18
- const __half2 *__restrict__ twiddle_factors_real,
19
- const __half2 *__restrict__ twiddle_factors_imag,
20
- __half2 *__restrict__ out_real,
21
- __half2 *__restrict__ out_gate,
22
- uint B,
23
- uint H,
24
- int N)
25
- {
26
- const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
27
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
28
- int idx;
29
- int shared_offset;
30
- const int B_Y = blockDim.y;
31
- const int n = N / B_Y;
32
-
33
- extern __shared__ half x_real_shared[];
34
- half *x_imag_shared = &x_real_shared[N * N];
35
- half *d_f_real = &x_imag_shared[N * N];
36
- half *d_f_imag = &d_f_real[N * N];
37
- half *twiddles_real_shared = &d_f_imag[N * N];
38
- half *twiddles_imag_shared = &twiddles_real_shared[N * N];
39
- half *out_real_shared = &twiddles_imag_shared[N * N];
40
-
41
- half tmp_real, tmp_imag;
42
-
43
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4][4];
44
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4][4];
45
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
46
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
47
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[4];
48
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[4];
49
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
50
-
51
- // #pragma unroll
52
- for (int i = 0; i < n; i++)
53
- {
54
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
55
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
56
- reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
57
- reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
58
-
59
- // #pragma unroll
60
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x;
61
- d_f_real[shared_offset] = d_f[shared_offset].real();
62
- d_f_imag[shared_offset] = d_f[shared_offset].imag();
63
-
64
- d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
65
- d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
66
- }
67
-
68
- __syncthreads();
69
-
70
- for (int i = 0; i < 4; i++)
71
- {
72
- #pragma unroll
73
- for (int j = 0; j < 4; j++)
74
- {
75
- wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
76
- wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
77
- }
78
- wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
79
- wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
80
- }
81
-
82
- for (int t = 0; t < 16; t++)
83
- {
84
-
85
- for (int i = 0; i < n; i++)
86
- {
87
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
88
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
89
- reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
90
- reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
91
- }
92
-
93
- __syncthreads();
94
-
95
- for (int i = 0; i < 4; i++)
96
- {
97
- wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
98
- wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
99
- }
100
-
101
- for (int j = 0; j < 4; j++)
102
- {
103
- for (int k = 0; k < tw_frag_real[j].num_elements; k++)
104
- {
105
- tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
106
- tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
107
- b_frag_real[j].x[k] = tmp_real;
108
- b_frag_imag[j].x[k] = tmp_imag;
109
- }
110
- }
111
-
112
- for (int i = 0; i < 4; i++)
113
- {
114
- wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
115
-
116
- // bd
117
- #pragma unroll
118
- for (int k = 0; k < 4; k++)
119
- {
120
- wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
121
- }
122
-
123
- for (int k = 0; k < acc_frag_real[i].num_elements; k++)
124
- {
125
- acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
126
- }
127
- }
128
-
129
- for (int i = 0; i < 4; i++)
130
- {
131
- // ac - bd
132
- #pragma unroll
133
- for (int k = 0; k < 4; k++)
134
- {
135
- wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
136
- }
137
- }
138
-
139
- #pragma unroll
140
- for (int i = 0; i < 4; i++)
141
- {
142
- wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
143
- }
144
-
145
- __syncthreads();
146
-
147
- #pragma unroll
148
- for (int i = 0; i < n; i++)
149
- {
150
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
151
- if(out_gate != nullptr){
152
- out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
153
- }
154
- else{
155
- out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
156
- }
157
- }
158
-
159
- __syncthreads();
160
- }
161
- }
162
-
163
- __global__ void butterfly_ifft_cuda_kernel_32(
164
- const __half2 *__restrict__ x_real,
165
- const __half2 *__restrict__ x_imag,
166
- const complex_half_t *__restrict__ d_f,
167
- const __half2 *__restrict__ twiddle_factors_real,
168
- const __half2 *__restrict__ twiddle_factors_imag,
169
- __half2 *__restrict__ out_real,
170
- __half2 *__restrict__ out_gate,
171
- uint B,
172
- uint H,
173
- int N)
174
- {
175
- const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
176
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
177
- int idx;
178
- int shared_offset;
179
- const int B_Y = blockDim.y;
180
- const int n = N / B_Y;
181
-
182
- __shared__ half x_real_shared[32 * 64];
183
- __shared__ half x_imag_shared[32 * 64];
184
- __shared__ half d_f_real[32 * 32];
185
- __shared__ half d_f_imag[32 * 32];
186
- __shared__ half twiddles_real_shared[32 * 64];
187
- __shared__ half twiddles_imag_shared[32 * 64];
188
- __shared__ half out_real_shared[32 * 64];
189
-
190
- // #pragma unroll
191
- for (int i = 0; i < n; i++)
192
- {
193
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
194
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
195
- reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
196
- reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
197
- reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
198
- reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
199
-
200
- // #pragma unroll
201
- d_f_real[shared_offset] = d_f[shared_offset].real();
202
- d_f_imag[shared_offset] = d_f[shared_offset].imag();
203
- }
204
-
205
- __syncthreads();
206
-
207
- if (threadIdx.y < N / 16)
208
- {
209
- half tmp_real, tmp_imag;
210
-
211
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
212
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
213
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
214
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
215
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[2][2];
216
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[2][2];
217
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
218
-
219
- int t = threadIdx.y * 32;
220
-
221
- for (int i = 0; i < 2; i++)
222
- {
223
- for (int j = 0; j < 2; j++)
224
- {
225
- wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
226
- wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
227
- wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
228
- wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
229
- wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
230
- wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
231
- }
232
- }
233
-
234
- for (int i = 0; i < 2; i++)
235
- {
236
- for (int j = 0; j < 2; j++)
237
- {
238
- for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
239
- {
240
- tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
241
- tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
242
- b_frag_real[i][j].x[k] = tmp_real;
243
- b_frag_imag[i][j].x[k] = tmp_imag;
244
- }
245
- }
246
- }
247
-
248
- for (int i = 0; i < 2; i++)
249
- {
250
- for (int j = 0; j < 2; j++)
251
- {
252
- wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
253
-
254
- // bd
255
- for (int k = 0; k < 2; k++)
256
- {
257
- wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
258
- }
259
-
260
- for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
261
- {
262
- acc_frag_real[i][j].x[k] = __hneg(acc_frag_real[i][j].x[k]);
263
- }
264
- }
265
- }
266
-
267
- for (int i = 0; i < 2; i++)
268
- {
269
- for (int j = 0; j < 2; j++)
270
- {
271
- // ac - bd
272
- for (int k = 0; k < 2; k++)
273
- {
274
- wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
275
- }
276
- }
277
- }
278
-
279
- for (int i = 0; i < 2; i++)
280
- {
281
- for (int j = 0; j < 2; j++)
282
- {
283
- wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
284
- }
285
- }
286
- }
287
-
288
- __syncthreads();
289
-
290
- #pragma unroll
291
- for (int i = 0; i < n; i++)
292
- {
293
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
294
- if(out_gate != nullptr){
295
- out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
296
- }
297
- else{
298
- out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
299
- }
300
- }
301
- }
302
-
303
-
304
- __global__ void butterfly_ifft_cuda_kernel_128(
305
- const __half2 *__restrict__ x_real,
306
- const __half2 *__restrict__ x_imag,
307
- const complex_half_t *__restrict__ d_f,
308
- const __half2 *__restrict__ twiddle_factors_real,
309
- const __half2 *__restrict__ twiddle_factors_imag,
310
- __half2 *__restrict__ out_real,
311
- __half2 *__restrict__ out_gate,
312
- uint B,
313
- uint H,
314
- int N)
315
- {
316
- const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
317
- const int tw_offset = blockIdx.x * 64 + threadIdx.x;
318
- int idx;
319
- int shared_offset;
320
-
321
- const int B_Y = 8;
322
- const int n = 16;
323
-
324
- extern __shared__ half real_shared[];
325
- half *imag_shared = &real_shared[128 * 128];
326
- half *real_shared_2 = &imag_shared[128 * 128];
327
- half *imag_shared_2 = &real_shared_2[128 * 128];
328
-
329
- __half2 tmp_real, tmp_imag;
330
-
331
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag[8][8];
332
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
333
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
334
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[8];
335
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[8];
336
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
337
-
338
- for (int i = 0; i < n; i++)
339
- {
340
- for(int j=0; j< 4; j++){
341
- shared_offset = (threadIdx.y + i * B_Y) * 128 + threadIdx.x + j * blockDim.x;
342
- real_shared_2[shared_offset] = d_f[shared_offset].real();
343
- imag_shared_2[shared_offset] = d_f[shared_offset].imag();
344
- }
345
- }
346
-
347
-
348
- __syncthreads();
349
-
350
- for (int i = 0; i < n; i++)
351
- {
352
- for(int j=0; j< 2; j++){
353
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
354
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
355
- reinterpret_cast<__half2*>(real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
356
- reinterpret_cast<__half2*>(imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
357
- }
358
- }
359
-
360
- __syncthreads();
361
-
362
-
363
- for (int i = 0; i < 8; i++){
364
- wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
365
- wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
366
- }
367
-
368
- __syncthreads();
369
-
370
- for (int t = 0; t < 16; t++)
371
- {
372
-
373
- for (int i = 0; i < n; i++)
374
- {
375
- for(int j=0; j< 2; j++){
376
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
377
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
378
- reinterpret_cast<__half2*>(real_shared)[shared_offset] = x_real[offset + idx];
379
- reinterpret_cast<__half2*>(imag_shared)[shared_offset] = x_imag[offset + idx];
380
- }
381
- }
382
-
383
- __syncthreads();
384
-
385
- for (int i = 0; i < 8; i++)
386
- {
387
- wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
388
- wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
389
- }
390
-
391
-
392
- for (int j = 0; j < 8; j++)
393
- {
394
- for (int k = 0; k < tw_frag_real[j].num_elements/2; k++)
395
- {
396
- tmp_real = __hsub2(__hmul2(reinterpret_cast<__half2*>(tw_frag_real[j].x)[k], reinterpret_cast<__half2*>(b_frag_real[j].x)[k]),
397
- __hmul2(reinterpret_cast<__half2*>(tw_frag_imag[j].x)[k], reinterpret_cast<__half2*>(b_frag_imag[j].x)[k]));
398
- tmp_imag = __hadd2(__hmul2(reinterpret_cast<__half2*>(tw_frag_real[j].x)[k], reinterpret_cast<__half2*>(b_frag_imag[j].x)[k]),
399
- __hmul2(reinterpret_cast<__half2*>(tw_frag_imag[j].x)[k], reinterpret_cast<__half2*>(b_frag_real[j].x)[k]));
400
- reinterpret_cast<__half2*>(b_frag_real[j].x)[k] = tmp_real;
401
- reinterpret_cast<__half2*>(b_frag_imag[j].x)[k] = tmp_imag;
402
- }
403
- }
404
-
405
- for (int i = 0; i < 8; i++){
406
- for (int j = 0; j < 8; j++){
407
- wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
408
- }
409
- }
410
-
411
- __syncthreads();
412
-
413
- for (int i = 0; i < 8; i++)
414
- {
415
- wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
416
-
417
- // bd
418
- #pragma unroll
419
- for (int k = 0; k < 8; k++)
420
- {
421
- wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
422
- }
423
-
424
- for (int k = 0; k < acc_frag_real[i].num_elements; k++)
425
- {
426
- acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
427
- }
428
- }
429
-
430
-
431
- for (int i = 0; i < 8; i++){
432
- for (int j = 0; j < 8; j++){
433
- wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
434
- }
435
- }
436
-
437
- __syncthreads();
438
-
439
- for (int i = 0; i < 8; i++)
440
- {
441
- // ac - bd
442
- #pragma unroll
443
- for (int k = 0; k < 8; k++)
444
- {
445
- wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
446
- }
447
- }
448
-
449
- #pragma unroll
450
- for (int i = 0; i < 8; i++)
451
- {
452
- wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
453
- }
454
-
455
- __syncthreads();
456
-
457
- #pragma unroll
458
- for (int i = 0; i < n; i++)
459
- {
460
- for(int j=0; j< 2; j++){
461
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
462
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
463
- if(out_gate != nullptr){
464
- out_real[offset + idx] = __hmul2(reinterpret_cast<__half2*>(real_shared)[shared_offset], out_gate[offset + idx]);
465
- }
466
- else{
467
- out_real[offset + idx] = reinterpret_cast<__half2*>(real_shared)[shared_offset];
468
- }
469
- }
470
- }
471
-
472
- __syncthreads();
473
- }
474
- }
475
-
476
- __global__ void butterfly_ifft_cuda_kernel_16(
477
- const __half2 *__restrict__ x_real,
478
- const __half2 *__restrict__ x_imag,
479
- const complex_half_t *__restrict__ d_f,
480
- const __half2 *__restrict__ twiddle_factors_real,
481
- const __half2 *__restrict__ twiddle_factors_imag,
482
- __half2 *__restrict__ out_real,
483
- __half2 *__restrict__ out_gate,
484
- uint B,
485
- uint H,
486
- int N)
487
- {
488
- const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
489
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
490
- int idx;
491
- int shared_offset;
492
- const int B_Y = blockDim.y;
493
- const int n = N / B_Y;
494
-
495
- __shared__ half x_real_shared[16 * 64];
496
- __shared__ half x_imag_shared[16 * 64];
497
- __shared__ half d_f_real[16 * 16];
498
- __shared__ half d_f_imag[16 * 16];
499
- __shared__ half twiddles_real_shared[16 * 64];
500
- __shared__ half twiddles_imag_shared[16 * 64];
501
- __shared__ half out_real_shared[16 * 64];
502
-
503
- // #pragma unroll
504
- for (int i = 0; i < n; i++)
505
- {
506
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
507
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
508
- reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
509
- reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
510
- reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
511
- reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
512
-
513
- if(threadIdx.x < 16 ){
514
- shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
515
- d_f_real[shared_offset] = d_f[shared_offset].real();
516
- d_f_imag[shared_offset] = d_f[shared_offset].imag();
517
- }
518
- }
519
-
520
- __syncthreads();
521
-
522
- //check if it is better to have one warp do all the multiplication or split between warps
523
- if (threadIdx.y < 4)
524
- {
525
- half tmp_real, tmp_imag;
526
-
527
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
528
- wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
529
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real;
530
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
531
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real;
532
- wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag;
533
- wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
534
-
535
- wmma::load_matrix_sync(a_frag_real, d_f_real, N);
536
- wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
537
- wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
538
- wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
539
- wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
540
- wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
541
-
542
-
543
-
544
- for (int k = 0; k < tw_frag_real.num_elements; k++)
545
- {
546
- tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
547
- tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
548
- b_frag_real.x[k] = tmp_real;
549
- b_frag_imag.x[k] = tmp_imag;
550
- }
551
-
552
-
553
- wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
554
-
555
- wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
556
-
557
- for(int k=0; k< acc_frag_real.num_elements; k++){
558
- acc_frag_real.x[k] = __hneg(acc_frag_real.x[k]);
559
- }
560
-
561
-
562
- wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
563
-
564
- wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
565
-
566
- }
567
-
568
- __syncthreads();
569
-
570
- #pragma unroll
571
- for (int i = 0; i < n; i++)
572
- {
573
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
574
- if(out_gate != nullptr){
575
- out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
576
- }
577
- else{
578
- out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
579
- }
580
- }
581
- }
582
-
583
- torch::Tensor butterfly_ifft_cuda(
584
- torch::Tensor x_real,
585
- torch::Tensor x_imag,
586
- torch::Tensor d_f,
587
- torch::Tensor twiddle_factors_real,
588
- torch::Tensor twiddle_factors_imag,
589
- std::optional<at::Tensor> out_gate = std::nullopt)
590
- {
591
-
592
- uint B = x_real.size(0);
593
- uint H = x_real.size(1);
594
- // uint m = x.size(1);
595
-
596
- // const int TILE_SIZE = 16;
597
-
598
- dim3 gridDim;
599
- dim3 blockDim;
600
-
601
- uint N = x_real.size(2);
602
- uint M = x_real.size(3);
603
- gridDim.y = B;
604
-
605
- blockDim.x = 32;
606
- blockDim.y = 4;
607
-
608
- torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());
609
- gridDim.z = H;
610
-
611
- //set blockDims
612
- switch(N){
613
- case 128:
614
- blockDim.x = 32;
615
- blockDim.y = 8;
616
- break;
617
- default:
618
- blockDim.x = 32;
619
- blockDim.y = 4;
620
- break;
621
- }
622
-
623
- //set gridDim.x
624
- switch(N){
625
- case 128:
626
- switch (M){
627
- case 16384:
628
- gridDim.x = 128;
629
- break;
630
- case 8192:
631
- gridDim.x = 64;
632
- break;
633
- case 4096:
634
- gridDim.x = 32;
635
- break;
636
- default:
637
- gridDim.x = 256;
638
- break;
639
- }
640
- break;
641
- default:
642
- switch (M){
643
- case 16384:
644
- gridDim.x = 256;
645
- break;
646
- case 8192:
647
- gridDim.x = 128;
648
- break;
649
- case 4096:
650
- gridDim.x = 64;
651
- break;
652
- default:
653
- gridDim.x = 512;
654
- break;
655
- }
656
- break;
657
- }
658
-
659
- switch (N)
660
- {
661
- case 16:
662
- butterfly_ifft_cuda_kernel_16<<<gridDim, blockDim>>>(
663
- static_cast<__half2 *>(x_real.data_ptr()),
664
- static_cast<__half2 *>(x_imag.data_ptr()),
665
- static_cast<complex_half_t *>(d_f.data_ptr()),
666
- static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
667
- static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
668
- static_cast<__half2 *>(out.data_ptr()),
669
- out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
670
- B,
671
- H,
672
- N);
673
- break;
674
- case 32:
675
- butterfly_ifft_cuda_kernel_32<<<gridDim, blockDim>>>(
676
- static_cast<__half2 *>(x_real.data_ptr()),
677
- static_cast<__half2 *>(x_imag.data_ptr()),
678
- static_cast<complex_half_t *>(d_f.data_ptr()),
679
- static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
680
- static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
681
- static_cast<__half2 *>(out.data_ptr()),
682
- out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
683
- B,
684
- H,
685
- N);
686
- break;
687
- case 64:
688
- gridDim.z = H / 16;
689
- cudaFuncSetAttribute(&butterfly_ifft_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
690
- butterfly_ifft_cuda_kernel_64<<<gridDim, blockDim, 8 * N * N * sizeof(half)>>>(
691
- static_cast<__half2 *>(x_real.data_ptr()),
692
- static_cast<__half2 *>(x_imag.data_ptr()),
693
- static_cast<complex_half_t *>(d_f.data_ptr()),
694
- static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
695
- static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
696
- static_cast<__half2 *>(out.data_ptr()),
697
- out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
698
- B,
699
- H,
700
- N);
701
- break;
702
-
703
- case 128:
704
- gridDim.z = H / 16;
705
- cudaFuncSetAttribute(&butterfly_ifft_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536*2);
706
- butterfly_ifft_cuda_kernel_128<<<gridDim, blockDim, 65536*2>>>(
707
- static_cast<__half2 *>(x_real.data_ptr()),
708
- static_cast<__half2 *>(x_imag.data_ptr()),
709
- static_cast<complex_half_t *>(d_f.data_ptr()),
710
- static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
711
- static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
712
- static_cast<__half2 *>(out.data_ptr()),
713
- out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
714
- B,
715
- H,
716
- N);
717
- break;
718
- default:
719
- printf("Not implemented\n");
720
- }
721
-
722
- return out;
723
- }
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cuda_bf16.h>
10
+ #include "shared.h"
11
+
12
+ using namespace nvcuda;
13
+
14
+ __global__ void butterfly_ifft_cuda_kernel_64(
15
+ const __half2 *__restrict__ x_real,
16
+ const __half2 *__restrict__ x_imag,
17
+ const complex_half_t *__restrict__ d_f,
18
+ const __half2 *__restrict__ twiddle_factors_real,
19
+ const __half2 *__restrict__ twiddle_factors_imag,
20
+ __half2 *__restrict__ out_real,
21
+ __half2 *__restrict__ out_gate,
22
+ uint B,
23
+ uint H,
24
+ int N)
25
+ {
26
+ const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
27
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
28
+ int idx;
29
+ int shared_offset;
30
+ const int B_Y = blockDim.y;
31
+ const int n = N / B_Y;
32
+
33
+ extern __shared__ half x_real_shared[];
34
+ half *x_imag_shared = &x_real_shared[N * N];
35
+ half *d_f_real = &x_imag_shared[N * N];
36
+ half *d_f_imag = &d_f_real[N * N];
37
+ half *twiddles_real_shared = &d_f_imag[N * N];
38
+ half *twiddles_imag_shared = &twiddles_real_shared[N * N];
39
+ half *out_real_shared = &twiddles_imag_shared[N * N];
40
+
41
+ half tmp_real, tmp_imag;
42
+
43
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4][4];
44
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4][4];
45
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
46
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
47
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[4];
48
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[4];
49
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
50
+
51
+ // #pragma unroll
52
+ for (int i = 0; i < n; i++)
53
+ {
54
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
55
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
56
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
57
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
58
+
59
+ // #pragma unroll
60
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x;
61
+ d_f_real[shared_offset] = d_f[shared_offset].real();
62
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
63
+
64
+ d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
65
+ d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
66
+ }
67
+
68
+ __syncthreads();
69
+
70
+ for (int i = 0; i < 4; i++)
71
+ {
72
+ #pragma unroll
73
+ for (int j = 0; j < 4; j++)
74
+ {
75
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
76
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
77
+ }
78
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
79
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
80
+ }
81
+
82
+ for (int t = 0; t < 16; t++)
83
+ {
84
+
85
+ for (int i = 0; i < n; i++)
86
+ {
87
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
88
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
89
+ reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
90
+ reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
91
+ }
92
+
93
+ __syncthreads();
94
+
95
+ for (int i = 0; i < 4; i++)
96
+ {
97
+ wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
98
+ wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
99
+ }
100
+
101
+ for (int j = 0; j < 4; j++)
102
+ {
103
+ for (int k = 0; k < tw_frag_real[j].num_elements; k++)
104
+ {
105
+ tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
106
+ tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
107
+ b_frag_real[j].x[k] = tmp_real;
108
+ b_frag_imag[j].x[k] = tmp_imag;
109
+ }
110
+ }
111
+
112
+ for (int i = 0; i < 4; i++)
113
+ {
114
+ wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
115
+
116
+ // bd
117
+ #pragma unroll
118
+ for (int k = 0; k < 4; k++)
119
+ {
120
+ wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
121
+ }
122
+
123
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
124
+ {
125
+ acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
126
+ }
127
+ }
128
+
129
+ for (int i = 0; i < 4; i++)
130
+ {
131
+ // ac - bd
132
+ #pragma unroll
133
+ for (int k = 0; k < 4; k++)
134
+ {
135
+ wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
136
+ }
137
+ }
138
+
139
+ #pragma unroll
140
+ for (int i = 0; i < 4; i++)
141
+ {
142
+ wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
143
+ }
144
+
145
+ __syncthreads();
146
+
147
+ #pragma unroll
148
+ for (int i = 0; i < n; i++)
149
+ {
150
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
151
+ if(out_gate != nullptr){
152
+ out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
153
+ }
154
+ else{
155
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
156
+ }
157
+ }
158
+
159
+ __syncthreads();
160
+ }
161
+ }
162
+
163
+ __global__ void butterfly_ifft_cuda_kernel_32(
164
+ const __half2 *__restrict__ x_real,
165
+ const __half2 *__restrict__ x_imag,
166
+ const complex_half_t *__restrict__ d_f,
167
+ const __half2 *__restrict__ twiddle_factors_real,
168
+ const __half2 *__restrict__ twiddle_factors_imag,
169
+ __half2 *__restrict__ out_real,
170
+ __half2 *__restrict__ out_gate,
171
+ uint B,
172
+ uint H,
173
+ int N)
174
+ {
175
+ const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
176
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
177
+ int idx;
178
+ int shared_offset;
179
+ const int B_Y = blockDim.y;
180
+ const int n = N / B_Y;
181
+
182
+ __shared__ half x_real_shared[32 * 64];
183
+ __shared__ half x_imag_shared[32 * 64];
184
+ __shared__ half d_f_real[32 * 32];
185
+ __shared__ half d_f_imag[32 * 32];
186
+ __shared__ half twiddles_real_shared[32 * 64];
187
+ __shared__ half twiddles_imag_shared[32 * 64];
188
+ __shared__ half out_real_shared[32 * 64];
189
+
190
+ // #pragma unroll
191
+ for (int i = 0; i < n; i++)
192
+ {
193
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
194
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
195
+ reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
196
+ reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
197
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
198
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
199
+
200
+ // #pragma unroll
201
+ d_f_real[shared_offset] = d_f[shared_offset].real();
202
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
203
+ }
204
+
205
+ __syncthreads();
206
+
207
+ if (threadIdx.y < N / 16)
208
+ {
209
+ half tmp_real, tmp_imag;
210
+
211
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
212
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
213
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
214
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
215
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[2][2];
216
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[2][2];
217
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
218
+
219
+ int t = threadIdx.y * 32;
220
+
221
+ for (int i = 0; i < 2; i++)
222
+ {
223
+ for (int j = 0; j < 2; j++)
224
+ {
225
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
226
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
227
+ wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
228
+ wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
229
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
230
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
231
+ }
232
+ }
233
+
234
+ for (int i = 0; i < 2; i++)
235
+ {
236
+ for (int j = 0; j < 2; j++)
237
+ {
238
+ for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
239
+ {
240
+ tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
241
+ tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
242
+ b_frag_real[i][j].x[k] = tmp_real;
243
+ b_frag_imag[i][j].x[k] = tmp_imag;
244
+ }
245
+ }
246
+ }
247
+
248
+ for (int i = 0; i < 2; i++)
249
+ {
250
+ for (int j = 0; j < 2; j++)
251
+ {
252
+ wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
253
+
254
+ // bd
255
+ for (int k = 0; k < 2; k++)
256
+ {
257
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
258
+ }
259
+
260
+ for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
261
+ {
262
+ acc_frag_real[i][j].x[k] = __hneg(acc_frag_real[i][j].x[k]);
263
+ }
264
+ }
265
+ }
266
+
267
+ for (int i = 0; i < 2; i++)
268
+ {
269
+ for (int j = 0; j < 2; j++)
270
+ {
271
+ // ac - bd
272
+ for (int k = 0; k < 2; k++)
273
+ {
274
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
275
+ }
276
+ }
277
+ }
278
+
279
+ for (int i = 0; i < 2; i++)
280
+ {
281
+ for (int j = 0; j < 2; j++)
282
+ {
283
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
284
+ }
285
+ }
286
+ }
287
+
288
+ __syncthreads();
289
+
290
+ #pragma unroll
291
+ for (int i = 0; i < n; i++)
292
+ {
293
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
294
+ if(out_gate != nullptr){
295
+ out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
296
+ }
297
+ else{
298
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
299
+ }
300
+ }
301
+ }
302
+
303
+
304
+ __global__ void butterfly_ifft_cuda_kernel_128(
305
+ const __half2 *__restrict__ x_real,
306
+ const __half2 *__restrict__ x_imag,
307
+ const complex_half_t *__restrict__ d_f,
308
+ const __half2 *__restrict__ twiddle_factors_real,
309
+ const __half2 *__restrict__ twiddle_factors_imag,
310
+ __half2 *__restrict__ out_real,
311
+ __half2 *__restrict__ out_gate,
312
+ uint B,
313
+ uint H,
314
+ int N)
315
+ {
316
+ const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
317
+ const int tw_offset = blockIdx.x * 64 + threadIdx.x;
318
+ int idx;
319
+ int shared_offset;
320
+
321
+ const int B_Y = 8;
322
+ const int n = 16;
323
+
324
+ extern __shared__ half real_shared[];
325
+ half *imag_shared = &real_shared[128 * 128];
326
+ half *real_shared_2 = &imag_shared[128 * 128];
327
+ half *imag_shared_2 = &real_shared_2[128 * 128];
328
+
329
+ __half2 tmp_real, tmp_imag;
330
+
331
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag[8][8];
332
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
333
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
334
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[8];
335
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[8];
336
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
337
+
338
+ for (int i = 0; i < n; i++)
339
+ {
340
+ for(int j=0; j< 4; j++){
341
+ shared_offset = (threadIdx.y + i * B_Y) * 128 + threadIdx.x + j * blockDim.x;
342
+ real_shared_2[shared_offset] = d_f[shared_offset].real();
343
+ imag_shared_2[shared_offset] = d_f[shared_offset].imag();
344
+ }
345
+ }
346
+
347
+
348
+ __syncthreads();
349
+
350
+ for (int i = 0; i < n; i++)
351
+ {
352
+ for(int j=0; j< 2; j++){
353
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
354
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
355
+ reinterpret_cast<__half2*>(real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
356
+ reinterpret_cast<__half2*>(imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
357
+ }
358
+ }
359
+
360
+ __syncthreads();
361
+
362
+
363
+ for (int i = 0; i < 8; i++){
364
+ wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
365
+ wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
366
+ }
367
+
368
+ __syncthreads();
369
+
370
+ for (int t = 0; t < 16; t++)
371
+ {
372
+
373
+ for (int i = 0; i < n; i++)
374
+ {
375
+ for(int j=0; j< 2; j++){
376
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
377
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
378
+ reinterpret_cast<__half2*>(real_shared)[shared_offset] = x_real[offset + idx];
379
+ reinterpret_cast<__half2*>(imag_shared)[shared_offset] = x_imag[offset + idx];
380
+ }
381
+ }
382
+
383
+ __syncthreads();
384
+
385
+ for (int i = 0; i < 8; i++)
386
+ {
387
+ wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
388
+ wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
389
+ }
390
+
391
+
392
+ for (int j = 0; j < 8; j++)
393
+ {
394
+ for (int k = 0; k < tw_frag_real[j].num_elements/2; k++)
395
+ {
396
+ tmp_real = __hsub2(__hmul2(reinterpret_cast<__half2*>(tw_frag_real[j].x)[k], reinterpret_cast<__half2*>(b_frag_real[j].x)[k]),
397
+ __hmul2(reinterpret_cast<__half2*>(tw_frag_imag[j].x)[k], reinterpret_cast<__half2*>(b_frag_imag[j].x)[k]));
398
+ tmp_imag = __hadd2(__hmul2(reinterpret_cast<__half2*>(tw_frag_real[j].x)[k], reinterpret_cast<__half2*>(b_frag_imag[j].x)[k]),
399
+ __hmul2(reinterpret_cast<__half2*>(tw_frag_imag[j].x)[k], reinterpret_cast<__half2*>(b_frag_real[j].x)[k]));
400
+ reinterpret_cast<__half2*>(b_frag_real[j].x)[k] = tmp_real;
401
+ reinterpret_cast<__half2*>(b_frag_imag[j].x)[k] = tmp_imag;
402
+ }
403
+ }
404
+
405
+ for (int i = 0; i < 8; i++){
406
+ for (int j = 0; j < 8; j++){
407
+ wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
408
+ }
409
+ }
410
+
411
+ __syncthreads();
412
+
413
+ for (int i = 0; i < 8; i++)
414
+ {
415
+ wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
416
+
417
+ // bd
418
+ #pragma unroll
419
+ for (int k = 0; k < 8; k++)
420
+ {
421
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
422
+ }
423
+
424
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
425
+ {
426
+ acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
427
+ }
428
+ }
429
+
430
+
431
+ for (int i = 0; i < 8; i++){
432
+ for (int j = 0; j < 8; j++){
433
+ wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
434
+ }
435
+ }
436
+
437
+ __syncthreads();
438
+
439
+ for (int i = 0; i < 8; i++)
440
+ {
441
+ // ac - bd
442
+ #pragma unroll
443
+ for (int k = 0; k < 8; k++)
444
+ {
445
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
446
+ }
447
+ }
448
+
449
+ #pragma unroll
450
+ for (int i = 0; i < 8; i++)
451
+ {
452
+ wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
453
+ }
454
+
455
+ __syncthreads();
456
+
457
+ #pragma unroll
458
+ for (int i = 0; i < n; i++)
459
+ {
460
+ for(int j=0; j< 2; j++){
461
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
462
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
463
+ if(out_gate != nullptr){
464
+ out_real[offset + idx] = __hmul2(reinterpret_cast<__half2*>(real_shared)[shared_offset], out_gate[offset + idx]);
465
+ }
466
+ else{
467
+ out_real[offset + idx] = reinterpret_cast<__half2*>(real_shared)[shared_offset];
468
+ }
469
+ }
470
+ }
471
+
472
+ __syncthreads();
473
+ }
474
+ }
475
+
476
+ __global__ void butterfly_ifft_cuda_kernel_16(
477
+ const __half2 *__restrict__ x_real,
478
+ const __half2 *__restrict__ x_imag,
479
+ const complex_half_t *__restrict__ d_f,
480
+ const __half2 *__restrict__ twiddle_factors_real,
481
+ const __half2 *__restrict__ twiddle_factors_imag,
482
+ __half2 *__restrict__ out_real,
483
+ __half2 *__restrict__ out_gate,
484
+ uint B,
485
+ uint H,
486
+ int N)
487
+ {
488
+ const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
489
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
490
+ int idx;
491
+ int shared_offset;
492
+ const int B_Y = blockDim.y;
493
+ const int n = N / B_Y;
494
+
495
+ __shared__ half x_real_shared[16 * 64];
496
+ __shared__ half x_imag_shared[16 * 64];
497
+ __shared__ half d_f_real[16 * 16];
498
+ __shared__ half d_f_imag[16 * 16];
499
+ __shared__ half twiddles_real_shared[16 * 64];
500
+ __shared__ half twiddles_imag_shared[16 * 64];
501
+ __shared__ half out_real_shared[16 * 64];
502
+
503
+ // #pragma unroll
504
+ for (int i = 0; i < n; i++)
505
+ {
506
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
507
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
508
+ reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
509
+ reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
510
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
511
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
512
+
513
+ if(threadIdx.x < 16 ){
514
+ shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
515
+ d_f_real[shared_offset] = d_f[shared_offset].real();
516
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
517
+ }
518
+ }
519
+
520
+ __syncthreads();
521
+
522
+ //check if it is better to have one warp do all the multiplication or split between warps
523
+ if (threadIdx.y < 4)
524
+ {
525
+ half tmp_real, tmp_imag;
526
+
527
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
528
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
529
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real;
530
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
531
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real;
532
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag;
533
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
534
+
535
+ wmma::load_matrix_sync(a_frag_real, d_f_real, N);
536
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
537
+ wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
538
+ wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
539
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
540
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
541
+
542
+
543
+
544
+ for (int k = 0; k < tw_frag_real.num_elements; k++)
545
+ {
546
+ tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
547
+ tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
548
+ b_frag_real.x[k] = tmp_real;
549
+ b_frag_imag.x[k] = tmp_imag;
550
+ }
551
+
552
+
553
+ wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
554
+
555
+ wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
556
+
557
+ for(int k=0; k< acc_frag_real.num_elements; k++){
558
+ acc_frag_real.x[k] = __hneg(acc_frag_real.x[k]);
559
+ }
560
+
561
+
562
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
563
+
564
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
565
+
566
+ }
567
+
568
+ __syncthreads();
569
+
570
+ #pragma unroll
571
+ for (int i = 0; i < n; i++)
572
+ {
573
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
574
+ if(out_gate != nullptr){
575
+ out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
576
+ }
577
+ else{
578
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
579
+ }
580
+ }
581
+ }
582
+
583
+ torch::Tensor butterfly_ifft_cuda(
584
+ torch::Tensor x_real,
585
+ torch::Tensor x_imag,
586
+ torch::Tensor d_f,
587
+ torch::Tensor twiddle_factors_real,
588
+ torch::Tensor twiddle_factors_imag,
589
+ std::optional<at::Tensor> out_gate = std::nullopt)
590
+ {
591
+
592
+ uint B = x_real.size(0);
593
+ uint H = x_real.size(1);
594
+ // uint m = x.size(1);
595
+
596
+ // const int TILE_SIZE = 16;
597
+
598
+ dim3 gridDim;
599
+ dim3 blockDim;
600
+
601
+ uint N = x_real.size(2);
602
+ uint M = x_real.size(3);
603
+ gridDim.y = B;
604
+
605
+ blockDim.x = 32;
606
+ blockDim.y = 4;
607
+
608
+ torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());
609
+ gridDim.z = H;
610
+
611
+ //set blockDims
612
+ switch(N){
613
+ case 128:
614
+ blockDim.x = 32;
615
+ blockDim.y = 8;
616
+ break;
617
+ default:
618
+ blockDim.x = 32;
619
+ blockDim.y = 4;
620
+ break;
621
+ }
622
+
623
+ //set gridDim.x
624
+ switch(N){
625
+ case 128:
626
+ switch (M){
627
+ case 16384:
628
+ gridDim.x = 128;
629
+ break;
630
+ case 8192:
631
+ gridDim.x = 64;
632
+ break;
633
+ case 4096:
634
+ gridDim.x = 32;
635
+ break;
636
+ default:
637
+ gridDim.x = 256;
638
+ break;
639
+ }
640
+ break;
641
+ default:
642
+ switch (M){
643
+ case 16384:
644
+ gridDim.x = 256;
645
+ break;
646
+ case 8192:
647
+ gridDim.x = 128;
648
+ break;
649
+ case 4096:
650
+ gridDim.x = 64;
651
+ break;
652
+ default:
653
+ gridDim.x = 512;
654
+ break;
655
+ }
656
+ break;
657
+ }
658
+
659
+ switch (N)
660
+ {
661
+ case 16:
662
+ butterfly_ifft_cuda_kernel_16<<<gridDim, blockDim>>>(
663
+ static_cast<__half2 *>(x_real.data_ptr()),
664
+ static_cast<__half2 *>(x_imag.data_ptr()),
665
+ static_cast<complex_half_t *>(d_f.data_ptr()),
666
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
667
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
668
+ static_cast<__half2 *>(out.data_ptr()),
669
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
670
+ B,
671
+ H,
672
+ N);
673
+ break;
674
+ case 32:
675
+ butterfly_ifft_cuda_kernel_32<<<gridDim, blockDim>>>(
676
+ static_cast<__half2 *>(x_real.data_ptr()),
677
+ static_cast<__half2 *>(x_imag.data_ptr()),
678
+ static_cast<complex_half_t *>(d_f.data_ptr()),
679
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
680
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
681
+ static_cast<__half2 *>(out.data_ptr()),
682
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
683
+ B,
684
+ H,
685
+ N);
686
+ break;
687
+ case 64:
688
+ gridDim.z = H / 16;
689
+ cudaFuncSetAttribute(&butterfly_ifft_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
690
+ butterfly_ifft_cuda_kernel_64<<<gridDim, blockDim, 8 * N * N * sizeof(half)>>>(
691
+ static_cast<__half2 *>(x_real.data_ptr()),
692
+ static_cast<__half2 *>(x_imag.data_ptr()),
693
+ static_cast<complex_half_t *>(d_f.data_ptr()),
694
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
695
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
696
+ static_cast<__half2 *>(out.data_ptr()),
697
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
698
+ B,
699
+ H,
700
+ N);
701
+ break;
702
+
703
+ case 128:
704
+ gridDim.z = H / 16;
705
+ cudaFuncSetAttribute(&butterfly_ifft_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536*2);
706
+ butterfly_ifft_cuda_kernel_128<<<gridDim, blockDim, 65536*2>>>(
707
+ static_cast<__half2 *>(x_real.data_ptr()),
708
+ static_cast<__half2 *>(x_imag.data_ptr()),
709
+ static_cast<complex_half_t *>(d_f.data_ptr()),
710
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
711
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
712
+ static_cast<__half2 *>(out.data_ptr()),
713
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
714
+ B,
715
+ H,
716
+ N);
717
+ break;
718
+ default:
719
+ printf("Not implemented\n");
720
+ }
721
+
722
+ return out;
723
+ }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda_bf16.cu CHANGED
@@ -1,705 +1,705 @@
1
- // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
-
3
- #include <torch/extension.h>
4
-
5
- #include <vector>
6
- #include <stdio.h>
7
- #include <mma.h>
8
- #include <cuda_fp16.h>
9
- #include <cuda_bf16.h>
10
- #include <cuda_runtime.h>
11
- #include "shared.h"
12
-
13
- using namespace nvcuda;
14
-
15
- __global__ void butterfly_ifft_bf16_cuda_kernel_64(
16
- const __nv_bfloat162 *__restrict__ x_real,
17
- const __nv_bfloat162 *__restrict__ x_imag,
18
- const __nv_bfloat162 *__restrict__ d_f_real,
19
- const __nv_bfloat162 *__restrict__ d_f_imag,
20
- const __nv_bfloat162 *__restrict__ twiddle_factors_real,
21
- const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
22
- __nv_bfloat162 *__restrict__ out_real,
23
- __nv_bfloat162 *__restrict__ out_gate,
24
- uint B,
25
- uint H,
26
- int N)
27
- {
28
- const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
29
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
30
- int idx;
31
- int shared_offset;
32
- const int B_Y = blockDim.y;
33
- const int n = N / B_Y;
34
-
35
- extern __shared__ __nv_bfloat16 x_real_shared[];
36
- __nv_bfloat16 *x_imag_shared = &x_real_shared[N * N];
37
- __nv_bfloat16 *d_f_real_shared = &x_imag_shared[N * N];
38
- __nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
39
- __nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
40
- __nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
41
- float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
42
-
43
- __nv_bfloat16 tmp_real, tmp_imag;
44
-
45
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4][4];
46
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4][4];
47
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
48
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
49
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[4];
50
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[4];
51
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];
52
-
53
- // #pragma unroll
54
- for (int i = 0; i < n; i++)
55
- {
56
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
57
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
58
- reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
59
- reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
60
-
61
- // #pragma unroll
62
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
63
- reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
64
- reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
65
- }
66
-
67
- __syncthreads();
68
-
69
- for (int i = 0; i < 4; i++)
70
- {
71
- #pragma unroll
72
- for (int j = 0; j < 4; j++)
73
- {
74
- wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
75
- wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
76
- }
77
- wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
78
- wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
79
- }
80
-
81
- for (int t = 0; t < 16; t++)
82
- {
83
-
84
- for (int i = 0; i < n; i++)
85
- {
86
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
87
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
88
- reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
89
- reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
90
- }
91
-
92
- __syncthreads();
93
-
94
- for (int i = 0; i < 4; i++)
95
- {
96
- wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
97
- wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
98
- }
99
-
100
- for (int j = 0; j < 4; j++)
101
- {
102
- for (int k = 0; k < tw_frag_real[j].num_elements; k++)
103
- {
104
- tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
105
- tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
106
- b_frag_real[j].x[k] = tmp_real;
107
- b_frag_imag[j].x[k] = tmp_imag;
108
- }
109
- }
110
-
111
- for (int i = 0; i < 4; i++)
112
- {
113
- wmma::fill_fragment(acc_frag_real[i], 0.0f);
114
-
115
- // bd
116
- #pragma unroll
117
- for (int k = 0; k < 4; k++)
118
- {
119
- wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
120
- }
121
-
122
- for (int k = 0; k < acc_frag_real[i].num_elements; k++)
123
- {
124
- acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
125
- }
126
- }
127
-
128
- for (int i = 0; i < 4; i++)
129
- {
130
- // ac - bd
131
- #pragma unroll
132
- for (int k = 0; k < 4; k++)
133
- {
134
- wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
135
- }
136
- }
137
-
138
- #pragma unroll
139
- for (int i = 0; i < 4; i++)
140
- {
141
- wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
142
- }
143
-
144
- __syncthreads();
145
-
146
- #pragma unroll
147
- for (int i = 0; i < n; i++)
148
- {
149
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
150
- if(out_gate != nullptr){
151
- out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]); ;
152
- }else{
153
- out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
154
- }
155
- }
156
-
157
- __syncthreads();
158
- }
159
- }
160
-
161
- __global__ void butterfly_ifft_bf16_cuda_kernel_32(
162
- const __nv_bfloat162 *__restrict__ x_real,
163
- const __nv_bfloat162 *__restrict__ x_imag,
164
- const __nv_bfloat16 *__restrict__ d_f_real,
165
- const __nv_bfloat16 *__restrict__ d_f_imag,
166
- const __nv_bfloat162 *__restrict__ twiddle_factors_real,
167
- const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
168
- __nv_bfloat162 *__restrict__ out_real,
169
- __nv_bfloat162 *__restrict__ out_gate,
170
- uint B,
171
- uint H,
172
- int N)
173
- {
174
- const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
175
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
176
- int idx;
177
- int shared_offset;
178
- const int B_Y = blockDim.y;
179
- const int n = N / B_Y;
180
-
181
- __shared__ __nv_bfloat16 x_real_shared[32 * 64];
182
- __shared__ __nv_bfloat16 x_imag_shared[32 * 64];
183
- __shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
184
- __shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
185
- __shared__ float out_real_shared[32 * 64];
186
-
187
- // #pragma unroll
188
- for (int i = 0; i < n; i++)
189
- {
190
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
191
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
192
- reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
193
- reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
194
- reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
195
- reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
196
- }
197
-
198
- __syncthreads();
199
-
200
- if (threadIdx.y < N / 16)
201
- {
202
- __nv_bfloat16 tmp_real, tmp_imag;
203
-
204
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
205
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
206
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
207
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
208
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[2][2];
209
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[2][2];
210
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
211
-
212
- int t = threadIdx.y * 32;
213
-
214
- for (int i = 0; i < 2; i++)
215
- {
216
- for (int j = 0; j < 2; j++)
217
- {
218
- wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
219
- wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
220
- wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
221
- wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
222
- wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
223
- wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
224
- }
225
- }
226
-
227
- for (int i = 0; i < 2; i++)
228
- {
229
- for (int j = 0; j < 2; j++)
230
- {
231
- for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
232
- {
233
- tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
234
- tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
235
- b_frag_real[i][j].x[k] = tmp_real;
236
- b_frag_imag[i][j].x[k] = tmp_imag;
237
- }
238
- }
239
- }
240
-
241
- for (int i = 0; i < 2; i++)
242
- {
243
- for (int j = 0; j < 2; j++)
244
- {
245
- wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
246
-
247
- // bd
248
- for (int k = 0; k < 2; k++)
249
- {
250
- wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
251
- }
252
-
253
- for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
254
- {
255
- acc_frag_real[i][j].x[k] = - acc_frag_real[i][j].x[k];
256
- }
257
- }
258
- }
259
-
260
- for (int i = 0; i < 2; i++)
261
- {
262
- for (int j = 0; j < 2; j++)
263
- {
264
- // ac - bd
265
- for (int k = 0; k < 2; k++)
266
- {
267
- wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
268
- }
269
- }
270
- }
271
-
272
- for (int i = 0; i < 2; i++)
273
- {
274
- for (int j = 0; j < 2; j++)
275
- {
276
- wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
277
- }
278
- }
279
- }
280
-
281
- __syncthreads();
282
-
283
- #pragma unroll
284
- for (int i = 0; i < n; i++)
285
- {
286
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
287
- if(out_gate != nullptr){
288
- out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]);
289
- }else{
290
- out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
291
- }
292
- }
293
- }
294
-
295
-
296
- __global__ void butterfly_ifft_bf16_cuda_kernel_128(
297
- const __nv_bfloat162 *__restrict__ x_real,
298
- const __nv_bfloat162 *__restrict__ x_imag,
299
- const __nv_bfloat162 *__restrict__ d_f_real,
300
- const __nv_bfloat162 *__restrict__ d_f_imag,
301
- const __nv_bfloat162 *__restrict__ twiddle_factors_real,
302
- const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
303
- __nv_bfloat162 *__restrict__ out_real,
304
- __nv_bfloat162 *__restrict__ out_gate,
305
- uint B,
306
- uint H,
307
- int N)
308
- {
309
- const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
310
- const int tw_offset = blockIdx.x * 64 + threadIdx.x;
311
- int idx;
312
- int shared_offset;
313
- const int B_Y = blockDim.y;
314
- const int n = N / B_Y;
315
-
316
- extern __shared__ __nv_bfloat16 real_shared[];
317
- __nv_bfloat16 *imag_shared = &real_shared[128 * 128];
318
- __nv_bfloat16 *real_shared_2 = &imag_shared[128 * 128];
319
- __nv_bfloat16 *imag_shared_2 = &real_shared_2[128 * 128];
320
-
321
- __nv_bfloat16 tmp_real, tmp_imag;
322
-
323
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag[8][8];
324
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
325
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
326
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[8];
327
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[8];
328
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
329
-
330
- for (int i = 0; i < n; i++)
331
- {
332
- for(int j=0; j< 2; j++){
333
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
334
- reinterpret_cast<__nv_bfloat162*>(real_shared_2)[shared_offset] = d_f_real[shared_offset];
335
- reinterpret_cast<__nv_bfloat162*>(imag_shared_2)[shared_offset] = d_f_imag[shared_offset];
336
- }
337
- }
338
-
339
- for (int i = 0; i < n; i++)
340
- {
341
- for(int j=0; j< 2; j++){
342
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
343
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
344
- reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
345
- reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
346
- }
347
- }
348
-
349
- __syncthreads();
350
-
351
-
352
- for (int i = 0; i < 8; i++){
353
- wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
354
- wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
355
- }
356
-
357
- __syncthreads();
358
-
359
- for (int t = 0; t < 16; t++)
360
- {
361
- for (int i = 0; i < 8; i++){
362
- for (int j = 0; j < 8; j++){
363
- wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
364
- }
365
- }
366
-
367
- for (int i = 0; i < n; i++)
368
- {
369
- for(int j=0; j< 2; j++){
370
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
371
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
372
- reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = x_real[offset + idx];
373
- reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = x_imag[offset + idx];
374
- }
375
- }
376
-
377
- __syncthreads();
378
-
379
- for (int i = 0; i < 8; i++)
380
- {
381
- wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
382
- wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
383
- }
384
-
385
-
386
- for (int j = 0; j < 8; j++)
387
- {
388
- for (int k = 0; k < tw_frag_real[j].num_elements; k++)
389
- {
390
- tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
391
- tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
392
- b_frag_real[j].x[k] = tmp_real;
393
- b_frag_imag[j].x[k] = tmp_imag;
394
- }
395
- }
396
-
397
- for (int i = 0; i < 8; i++)
398
- {
399
- wmma::fill_fragment(acc_frag_real[i], 0.0f);
400
-
401
- // bd
402
- #pragma unroll
403
- for (int k = 0; k < 8; k++)
404
- {
405
- wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
406
- }
407
-
408
- for (int k = 0; k < acc_frag_real[i].num_elements; k++)
409
- {
410
- acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
411
- }
412
- }
413
-
414
- for (int i = 0; i < 8; i++){
415
- for (int j = 0; j < 8; j++){
416
- wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
417
- }
418
- }
419
-
420
- for (int i = 0; i < 8; i++)
421
- {
422
- // ac - bd
423
- #pragma unroll
424
- for (int k = 0; k < 8; k++)
425
- {
426
- wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
427
- }
428
- }
429
-
430
- #pragma unroll
431
- for (int i = 0; i < 8; i++)
432
- {
433
- //wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
434
- wmma::store_matrix_sync(reinterpret_cast<float*>(real_shared) + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
435
- }
436
-
437
- __syncthreads();
438
-
439
- #pragma unroll
440
- for (int i = 0; i < n; i++)
441
- {
442
- for(int j=0; j< 2; j++){
443
- idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
444
- shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
445
- if(out_gate != nullptr){
446
- out_real[offset + idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]), out_gate[offset + idx]);
447
- }else{
448
- out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]);
449
- }
450
- }
451
- }
452
-
453
- __syncthreads();
454
- }
455
- }
456
-
457
- __global__ void butterfly_ifft_bf16_cuda_kernel_16(
458
- const __nv_bfloat162 *__restrict__ x_real,
459
- const __nv_bfloat162 *__restrict__ x_imag,
460
- const __nv_bfloat16 *__restrict__ d_f_real,
461
- const __nv_bfloat16 *__restrict__ d_f_imag,
462
- const __nv_bfloat162 *__restrict__ twiddle_factors_real,
463
- const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
464
- __nv_bfloat162 *__restrict__ out_real,
465
- __nv_bfloat162 *__restrict__ out_gate,
466
- uint B,
467
- uint H,
468
- int N)
469
- {
470
- const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
471
- const int tw_offset = blockIdx.x * 32 + threadIdx.x;
472
- int idx;
473
- int shared_offset;
474
- const int B_Y = blockDim.y;
475
- const int n = N / B_Y;
476
-
477
- __shared__ __nv_bfloat16 x_real_shared[16 * 64];
478
- __shared__ __nv_bfloat16 x_imag_shared[16 * 64];
479
- __shared__ __nv_bfloat16 twiddles_real_shared[16 * 64];
480
- __shared__ __nv_bfloat16 twiddles_imag_shared[16 * 64];
481
- __shared__ float out_real_shared[16 * 64];
482
-
483
- // #pragma unroll
484
- for (int i = 0; i < n; i++)
485
- {
486
- idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
487
- shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
488
- reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
489
- reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
490
- reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
491
- reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
492
- }
493
-
494
- __syncthreads();
495
-
496
- if (threadIdx.y < 4)
497
- {
498
- __nv_bfloat16 tmp_real, tmp_imag;
499
-
500
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
501
- wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
502
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
503
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
504
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real;
505
- wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag;
506
- wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
507
-
508
- wmma::load_matrix_sync(a_frag_real, d_f_real, N);
509
- wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
510
- wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
511
- wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
512
- wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
513
- wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
514
-
515
-
516
- for (int k = 0; k < tw_frag_real.num_elements; k++)
517
- {
518
- tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
519
- tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
520
- b_frag_real.x[k] = tmp_real;
521
- b_frag_imag.x[k] = tmp_imag;
522
- }
523
-
524
-
525
-
526
- wmma::fill_fragment(acc_frag_real, 0.0f);
527
-
528
- wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
529
-
530
- for(int k=0; k< acc_frag_real.num_elements; k++){
531
- acc_frag_real.x[k] = - acc_frag_real.x[k];
532
- }
533
-
534
- wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
535
-
536
- wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
537
-
538
- }
539
-
540
- __syncthreads();
541
-
542
- #pragma unroll
543
- for (int i = 0; i < n; i++)
544
- {
545
- idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
546
- if(out_gate != nullptr){
547
- out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]);
548
- }else{
549
- out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
550
- }
551
- }
552
- }
553
-
554
-
555
- torch::Tensor butterfly_ifft_bf16_cuda(
556
- torch::Tensor x_real,
557
- torch::Tensor x_imag,
558
- torch::Tensor d_f_real,
559
- torch::Tensor d_f_imag,
560
- torch::Tensor twiddle_factors_real,
561
- torch::Tensor twiddle_factors_imag,
562
- std::optional<at::Tensor> out_gate = std::nullopt
563
- )
564
- {
565
-
566
- uint B = x_real.size(0);
567
- uint H = x_real.size(1);
568
- // uint m = x.size(1);
569
-
570
- // const int TILE_SIZE = 16;
571
-
572
- dim3 gridDim;
573
- dim3 blockDim;
574
-
575
- uint N = x_real.size(2);
576
- uint M = x_real.size(3);
577
- gridDim.y = B;
578
-
579
- blockDim.x = 32;
580
- blockDim.y = 4;
581
-
582
- torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());
583
-
584
-
585
- //set blockDims
586
- switch(N){
587
- case 128:
588
- blockDim.x = 32;
589
- blockDim.y = 8;
590
- break;
591
- default:
592
- blockDim.x = 32;
593
- blockDim.y = 4;
594
- break;
595
- }
596
-
597
- //set gridDim.x
598
- switch(N){
599
- case 128:
600
- switch (M){
601
- case 16384:
602
- gridDim.x = 128;
603
- break;
604
- case 8192:
605
- gridDim.x = 64;
606
- break;
607
- case 4096:
608
- gridDim.x = 32;
609
- break;
610
- default:
611
- gridDim.x = 256;
612
- break;
613
- }
614
- break;
615
- default:
616
- switch (M){
617
- case 16384:
618
- gridDim.x = 256;
619
- break;
620
- case 8192:
621
- gridDim.x = 128;
622
- break;
623
- case 4096:
624
- gridDim.x = 64;
625
- break;
626
- default:
627
- gridDim.x = 512;
628
- break;
629
- }
630
- break;
631
- }
632
-
633
-
634
- switch (N)
635
- {
636
- case 16:
637
- gridDim.z = H;
638
- butterfly_ifft_bf16_cuda_kernel_16<<<gridDim, blockDim>>>(
639
- static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
640
- static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
641
- static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
642
- static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
643
- static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
644
- static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
645
- static_cast<__nv_bfloat162 *>(out.data_ptr()),
646
- out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
647
- B,
648
- H,
649
- N);
650
- break;
651
-
652
- case 32:
653
- gridDim.z = H;
654
- butterfly_ifft_bf16_cuda_kernel_32<<<gridDim, blockDim>>>(
655
- static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
656
- static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
657
- static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
658
- static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
659
- static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
660
- static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
661
- static_cast<__nv_bfloat162 *>(out.data_ptr()),
662
- out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
663
- B,
664
- H,
665
- N);
666
- break;
667
- case 64:
668
- gridDim.z = H / 16;
669
- cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
670
- butterfly_ifft_bf16_cuda_kernel_64<<<gridDim, blockDim, 78000>>>(
671
- static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
672
- static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
673
- static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
674
- static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
675
- static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
676
- static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
677
- static_cast<__nv_bfloat162 *>(out.data_ptr()),
678
- out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
679
- B,
680
- H,
681
- N);
682
- break;
683
-
684
- case 128:
685
- gridDim.z = H / 16;
686
- cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
687
- butterfly_ifft_bf16_cuda_kernel_128<<<gridDim, blockDim, 65536 * 2>>>(
688
- static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
689
- static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
690
- static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
691
- static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
692
- static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
693
- static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
694
- static_cast<__nv_bfloat162 *>(out.data_ptr()),
695
- out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
696
- B,
697
- H,
698
- N);
699
- break;
700
- default:
701
- printf("Not implemented\n");
702
- }
703
-
704
- return out;
705
- }
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cuda_bf16.h>
10
+ #include <cuda_runtime.h>
11
+ #include "shared.h"
12
+
13
+ using namespace nvcuda;
14
+
15
+ __global__ void butterfly_ifft_bf16_cuda_kernel_64(
16
+ const __nv_bfloat162 *__restrict__ x_real,
17
+ const __nv_bfloat162 *__restrict__ x_imag,
18
+ const __nv_bfloat162 *__restrict__ d_f_real,
19
+ const __nv_bfloat162 *__restrict__ d_f_imag,
20
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
21
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
22
+ __nv_bfloat162 *__restrict__ out_real,
23
+ __nv_bfloat162 *__restrict__ out_gate,
24
+ uint B,
25
+ uint H,
26
+ int N)
27
+ {
28
+ const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
29
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
30
+ int idx;
31
+ int shared_offset;
32
+ const int B_Y = blockDim.y;
33
+ const int n = N / B_Y;
34
+
35
+ extern __shared__ __nv_bfloat16 x_real_shared[];
36
+ __nv_bfloat16 *x_imag_shared = &x_real_shared[N * N];
37
+ __nv_bfloat16 *d_f_real_shared = &x_imag_shared[N * N];
38
+ __nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
39
+ __nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
40
+ __nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
41
+ float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
42
+
43
+ __nv_bfloat16 tmp_real, tmp_imag;
44
+
45
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4][4];
46
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4][4];
47
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
48
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
49
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[4];
50
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[4];
51
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];
52
+
53
+ // #pragma unroll
54
+ for (int i = 0; i < n; i++)
55
+ {
56
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
57
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
58
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
59
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
60
+
61
+ // #pragma unroll
62
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
63
+ reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
64
+ reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
65
+ }
66
+
67
+ __syncthreads();
68
+
69
+ for (int i = 0; i < 4; i++)
70
+ {
71
+ #pragma unroll
72
+ for (int j = 0; j < 4; j++)
73
+ {
74
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
75
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
76
+ }
77
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
78
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
79
+ }
80
+
81
+ for (int t = 0; t < 16; t++)
82
+ {
83
+
84
+ for (int i = 0; i < n; i++)
85
+ {
86
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
87
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
88
+ reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
89
+ reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
90
+ }
91
+
92
+ __syncthreads();
93
+
94
+ for (int i = 0; i < 4; i++)
95
+ {
96
+ wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
97
+ wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
98
+ }
99
+
100
+ for (int j = 0; j < 4; j++)
101
+ {
102
+ for (int k = 0; k < tw_frag_real[j].num_elements; k++)
103
+ {
104
+ tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
105
+ tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
106
+ b_frag_real[j].x[k] = tmp_real;
107
+ b_frag_imag[j].x[k] = tmp_imag;
108
+ }
109
+ }
110
+
111
+ for (int i = 0; i < 4; i++)
112
+ {
113
+ wmma::fill_fragment(acc_frag_real[i], 0.0f);
114
+
115
+ // bd
116
+ #pragma unroll
117
+ for (int k = 0; k < 4; k++)
118
+ {
119
+ wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
120
+ }
121
+
122
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
123
+ {
124
+ acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
125
+ }
126
+ }
127
+
128
+ for (int i = 0; i < 4; i++)
129
+ {
130
+ // ac - bd
131
+ #pragma unroll
132
+ for (int k = 0; k < 4; k++)
133
+ {
134
+ wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
135
+ }
136
+ }
137
+
138
+ #pragma unroll
139
+ for (int i = 0; i < 4; i++)
140
+ {
141
+ wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
142
+ }
143
+
144
+ __syncthreads();
145
+
146
+ #pragma unroll
147
+ for (int i = 0; i < n; i++)
148
+ {
149
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
150
+ if(out_gate != nullptr){
151
+ out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]); ;
152
+ }else{
153
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
154
+ }
155
+ }
156
+
157
+ __syncthreads();
158
+ }
159
+ }
160
+
161
+ __global__ void butterfly_ifft_bf16_cuda_kernel_32(
162
+ const __nv_bfloat162 *__restrict__ x_real,
163
+ const __nv_bfloat162 *__restrict__ x_imag,
164
+ const __nv_bfloat16 *__restrict__ d_f_real,
165
+ const __nv_bfloat16 *__restrict__ d_f_imag,
166
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
167
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
168
+ __nv_bfloat162 *__restrict__ out_real,
169
+ __nv_bfloat162 *__restrict__ out_gate,
170
+ uint B,
171
+ uint H,
172
+ int N)
173
+ {
174
+ const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
175
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
176
+ int idx;
177
+ int shared_offset;
178
+ const int B_Y = blockDim.y;
179
+ const int n = N / B_Y;
180
+
181
+ __shared__ __nv_bfloat16 x_real_shared[32 * 64];
182
+ __shared__ __nv_bfloat16 x_imag_shared[32 * 64];
183
+ __shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
184
+ __shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
185
+ __shared__ float out_real_shared[32 * 64];
186
+
187
+ // #pragma unroll
188
+ for (int i = 0; i < n; i++)
189
+ {
190
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
191
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
192
+ reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
193
+ reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
194
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
195
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
196
+ }
197
+
198
+ __syncthreads();
199
+
200
+ if (threadIdx.y < N / 16)
201
+ {
202
+ __nv_bfloat16 tmp_real, tmp_imag;
203
+
204
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
205
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
206
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
207
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
208
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[2][2];
209
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[2][2];
210
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
211
+
212
+ int t = threadIdx.y * 32;
213
+
214
+ for (int i = 0; i < 2; i++)
215
+ {
216
+ for (int j = 0; j < 2; j++)
217
+ {
218
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
219
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
220
+ wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
221
+ wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
222
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
223
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
224
+ }
225
+ }
226
+
227
+ for (int i = 0; i < 2; i++)
228
+ {
229
+ for (int j = 0; j < 2; j++)
230
+ {
231
+ for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
232
+ {
233
+ tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
234
+ tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
235
+ b_frag_real[i][j].x[k] = tmp_real;
236
+ b_frag_imag[i][j].x[k] = tmp_imag;
237
+ }
238
+ }
239
+ }
240
+
241
+ for (int i = 0; i < 2; i++)
242
+ {
243
+ for (int j = 0; j < 2; j++)
244
+ {
245
+ wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
246
+
247
+ // bd
248
+ for (int k = 0; k < 2; k++)
249
+ {
250
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
251
+ }
252
+
253
+ for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
254
+ {
255
+ acc_frag_real[i][j].x[k] = - acc_frag_real[i][j].x[k];
256
+ }
257
+ }
258
+ }
259
+
260
+ for (int i = 0; i < 2; i++)
261
+ {
262
+ for (int j = 0; j < 2; j++)
263
+ {
264
+ // ac - bd
265
+ for (int k = 0; k < 2; k++)
266
+ {
267
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
268
+ }
269
+ }
270
+ }
271
+
272
+ for (int i = 0; i < 2; i++)
273
+ {
274
+ for (int j = 0; j < 2; j++)
275
+ {
276
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
277
+ }
278
+ }
279
+ }
280
+
281
+ __syncthreads();
282
+
283
+ #pragma unroll
284
+ for (int i = 0; i < n; i++)
285
+ {
286
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
287
+ if(out_gate != nullptr){
288
+ out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]);
289
+ }else{
290
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
291
+ }
292
+ }
293
+ }
294
+
295
+
296
+ __global__ void butterfly_ifft_bf16_cuda_kernel_128(
297
+ const __nv_bfloat162 *__restrict__ x_real,
298
+ const __nv_bfloat162 *__restrict__ x_imag,
299
+ const __nv_bfloat162 *__restrict__ d_f_real,
300
+ const __nv_bfloat162 *__restrict__ d_f_imag,
301
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
302
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
303
+ __nv_bfloat162 *__restrict__ out_real,
304
+ __nv_bfloat162 *__restrict__ out_gate,
305
+ uint B,
306
+ uint H,
307
+ int N)
308
+ {
309
+ const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
310
+ const int tw_offset = blockIdx.x * 64 + threadIdx.x;
311
+ int idx;
312
+ int shared_offset;
313
+ const int B_Y = blockDim.y;
314
+ const int n = N / B_Y;
315
+
316
+ extern __shared__ __nv_bfloat16 real_shared[];
317
+ __nv_bfloat16 *imag_shared = &real_shared[128 * 128];
318
+ __nv_bfloat16 *real_shared_2 = &imag_shared[128 * 128];
319
+ __nv_bfloat16 *imag_shared_2 = &real_shared_2[128 * 128];
320
+
321
+ __nv_bfloat16 tmp_real, tmp_imag;
322
+
323
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag[8][8];
324
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
325
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
326
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[8];
327
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[8];
328
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
329
+
330
+ for (int i = 0; i < n; i++)
331
+ {
332
+ for(int j=0; j< 2; j++){
333
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
334
+ reinterpret_cast<__nv_bfloat162*>(real_shared_2)[shared_offset] = d_f_real[shared_offset];
335
+ reinterpret_cast<__nv_bfloat162*>(imag_shared_2)[shared_offset] = d_f_imag[shared_offset];
336
+ }
337
+ }
338
+
339
+ for (int i = 0; i < n; i++)
340
+ {
341
+ for(int j=0; j< 2; j++){
342
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
343
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
344
+ reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
345
+ reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
346
+ }
347
+ }
348
+
349
+ __syncthreads();
350
+
351
+
352
+ for (int i = 0; i < 8; i++){
353
+ wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
354
+ wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
355
+ }
356
+
357
+ __syncthreads();
358
+
359
+ for (int t = 0; t < 16; t++)
360
+ {
361
+ for (int i = 0; i < 8; i++){
362
+ for (int j = 0; j < 8; j++){
363
+ wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
364
+ }
365
+ }
366
+
367
+ for (int i = 0; i < n; i++)
368
+ {
369
+ for(int j=0; j< 2; j++){
370
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
371
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
372
+ reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = x_real[offset + idx];
373
+ reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = x_imag[offset + idx];
374
+ }
375
+ }
376
+
377
+ __syncthreads();
378
+
379
+ for (int i = 0; i < 8; i++)
380
+ {
381
+ wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
382
+ wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
383
+ }
384
+
385
+
386
+ for (int j = 0; j < 8; j++)
387
+ {
388
+ for (int k = 0; k < tw_frag_real[j].num_elements; k++)
389
+ {
390
+ tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
391
+ tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
392
+ b_frag_real[j].x[k] = tmp_real;
393
+ b_frag_imag[j].x[k] = tmp_imag;
394
+ }
395
+ }
396
+
397
+ for (int i = 0; i < 8; i++)
398
+ {
399
+ wmma::fill_fragment(acc_frag_real[i], 0.0f);
400
+
401
+ // bd
402
+ #pragma unroll
403
+ for (int k = 0; k < 8; k++)
404
+ {
405
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
406
+ }
407
+
408
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
409
+ {
410
+ acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
411
+ }
412
+ }
413
+
414
+ for (int i = 0; i < 8; i++){
415
+ for (int j = 0; j < 8; j++){
416
+ wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
417
+ }
418
+ }
419
+
420
+ for (int i = 0; i < 8; i++)
421
+ {
422
+ // ac - bd
423
+ #pragma unroll
424
+ for (int k = 0; k < 8; k++)
425
+ {
426
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
427
+ }
428
+ }
429
+
430
+ #pragma unroll
431
+ for (int i = 0; i < 8; i++)
432
+ {
433
+ //wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
434
+ wmma::store_matrix_sync(reinterpret_cast<float*>(real_shared) + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
435
+ }
436
+
437
+ __syncthreads();
438
+
439
+ #pragma unroll
440
+ for (int i = 0; i < n; i++)
441
+ {
442
+ for(int j=0; j< 2; j++){
443
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
444
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
445
+ if(out_gate != nullptr){
446
+ out_real[offset + idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]), out_gate[offset + idx]);
447
+ }else{
448
+ out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]);
449
+ }
450
+ }
451
+ }
452
+
453
+ __syncthreads();
454
+ }
455
+ }
456
+
457
+ __global__ void butterfly_ifft_bf16_cuda_kernel_16(
458
+ const __nv_bfloat162 *__restrict__ x_real,
459
+ const __nv_bfloat162 *__restrict__ x_imag,
460
+ const __nv_bfloat16 *__restrict__ d_f_real,
461
+ const __nv_bfloat16 *__restrict__ d_f_imag,
462
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
463
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
464
+ __nv_bfloat162 *__restrict__ out_real,
465
+ __nv_bfloat162 *__restrict__ out_gate,
466
+ uint B,
467
+ uint H,
468
+ int N)
469
+ {
470
+ const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
471
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
472
+ int idx;
473
+ int shared_offset;
474
+ const int B_Y = blockDim.y;
475
+ const int n = N / B_Y;
476
+
477
+ __shared__ __nv_bfloat16 x_real_shared[16 * 64];
478
+ __shared__ __nv_bfloat16 x_imag_shared[16 * 64];
479
+ __shared__ __nv_bfloat16 twiddles_real_shared[16 * 64];
480
+ __shared__ __nv_bfloat16 twiddles_imag_shared[16 * 64];
481
+ __shared__ float out_real_shared[16 * 64];
482
+
483
+ // #pragma unroll
484
+ for (int i = 0; i < n; i++)
485
+ {
486
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
487
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
488
+ reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
489
+ reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
490
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
491
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
492
+ }
493
+
494
+ __syncthreads();
495
+
496
+ if (threadIdx.y < 4)
497
+ {
498
+ __nv_bfloat16 tmp_real, tmp_imag;
499
+
500
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
501
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
502
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
503
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
504
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real;
505
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag;
506
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
507
+
508
+ wmma::load_matrix_sync(a_frag_real, d_f_real, N);
509
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
510
+ wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
511
+ wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
512
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
513
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
514
+
515
+
516
+ for (int k = 0; k < tw_frag_real.num_elements; k++)
517
+ {
518
+ tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
519
+ tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
520
+ b_frag_real.x[k] = tmp_real;
521
+ b_frag_imag.x[k] = tmp_imag;
522
+ }
523
+
524
+
525
+
526
+ wmma::fill_fragment(acc_frag_real, 0.0f);
527
+
528
+ wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
529
+
530
+ for(int k=0; k< acc_frag_real.num_elements; k++){
531
+ acc_frag_real.x[k] = - acc_frag_real.x[k];
532
+ }
533
+
534
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
535
+
536
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
537
+
538
+ }
539
+
540
+ __syncthreads();
541
+
542
+ #pragma unroll
543
+ for (int i = 0; i < n; i++)
544
+ {
545
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
546
+ if(out_gate != nullptr){
547
+ out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]);
548
+ }else{
549
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
550
+ }
551
+ }
552
+ }
553
+
554
+
555
+ torch::Tensor butterfly_ifft_bf16_cuda(
556
+ torch::Tensor x_real,
557
+ torch::Tensor x_imag,
558
+ torch::Tensor d_f_real,
559
+ torch::Tensor d_f_imag,
560
+ torch::Tensor twiddle_factors_real,
561
+ torch::Tensor twiddle_factors_imag,
562
+ std::optional<at::Tensor> out_gate = std::nullopt
563
+ )
564
+ {
565
+
566
+ uint B = x_real.size(0);
567
+ uint H = x_real.size(1);
568
+ // uint m = x.size(1);
569
+
570
+ // const int TILE_SIZE = 16;
571
+
572
+ dim3 gridDim;
573
+ dim3 blockDim;
574
+
575
+ uint N = x_real.size(2);
576
+ uint M = x_real.size(3);
577
+ gridDim.y = B;
578
+
579
+ blockDim.x = 32;
580
+ blockDim.y = 4;
581
+
582
+ torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());
583
+
584
+
585
+ //set blockDims
586
+ switch(N){
587
+ case 128:
588
+ blockDim.x = 32;
589
+ blockDim.y = 8;
590
+ break;
591
+ default:
592
+ blockDim.x = 32;
593
+ blockDim.y = 4;
594
+ break;
595
+ }
596
+
597
+ //set gridDim.x
598
+ switch(N){
599
+ case 128:
600
+ switch (M){
601
+ case 16384:
602
+ gridDim.x = 128;
603
+ break;
604
+ case 8192:
605
+ gridDim.x = 64;
606
+ break;
607
+ case 4096:
608
+ gridDim.x = 32;
609
+ break;
610
+ default:
611
+ gridDim.x = 256;
612
+ break;
613
+ }
614
+ break;
615
+ default:
616
+ switch (M){
617
+ case 16384:
618
+ gridDim.x = 256;
619
+ break;
620
+ case 8192:
621
+ gridDim.x = 128;
622
+ break;
623
+ case 4096:
624
+ gridDim.x = 64;
625
+ break;
626
+ default:
627
+ gridDim.x = 512;
628
+ break;
629
+ }
630
+ break;
631
+ }
632
+
633
+
634
+ switch (N)
635
+ {
636
+ case 16:
637
+ gridDim.z = H;
638
+ butterfly_ifft_bf16_cuda_kernel_16<<<gridDim, blockDim>>>(
639
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
640
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
641
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
642
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
643
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
644
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
645
+ static_cast<__nv_bfloat162 *>(out.data_ptr()),
646
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
647
+ B,
648
+ H,
649
+ N);
650
+ break;
651
+
652
+ case 32:
653
+ gridDim.z = H;
654
+ butterfly_ifft_bf16_cuda_kernel_32<<<gridDim, blockDim>>>(
655
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
656
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
657
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
658
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
659
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
660
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
661
+ static_cast<__nv_bfloat162 *>(out.data_ptr()),
662
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
663
+ B,
664
+ H,
665
+ N);
666
+ break;
667
+ case 64:
668
+ gridDim.z = H / 16;
669
+ cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
670
+ butterfly_ifft_bf16_cuda_kernel_64<<<gridDim, blockDim, 78000>>>(
671
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
672
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
673
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
674
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
675
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
676
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
677
+ static_cast<__nv_bfloat162 *>(out.data_ptr()),
678
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
679
+ B,
680
+ H,
681
+ N);
682
+ break;
683
+
684
+ case 128:
685
+ gridDim.z = H / 16;
686
+ cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
687
+ butterfly_ifft_bf16_cuda_kernel_128<<<gridDim, blockDim, 65536 * 2>>>(
688
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
689
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
690
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
691
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
692
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
693
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
694
+ static_cast<__nv_bfloat162 *>(out.data_ptr()),
695
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
696
+ B,
697
+ H,
698
+ N);
699
+ break;
700
+ default:
701
+ printf("Not implemented\n");
702
+ }
703
+
704
+ return out;
705
+ }