Spaces:
Runtime error
Runtime error
Update Feather h200 training runtime image
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +16 -20
- Dockerfile +118 -122
- entrypoint.py +267 -227
- mamba_ssm_init.py +62 -94
- overlay/.dockerignore +20 -20
- overlay/configs/__init__.py +5 -5
- overlay/configs/hardware_config.py +104 -104
- overlay/configs/harness_config.py +63 -63
- overlay/configs/model_config.py +80 -80
- overlay/harness/__init__.py +21 -21
- overlay/harness/eval_agent.py +129 -257
- overlay/harness/git_utils.py +94 -94
- overlay/harness/health_monitor.py +86 -86
- overlay/harness/meta_agent.py +139 -139
- overlay/harness/orchestrator.py +281 -284
- overlay/harness/search_strategy.py +153 -153
- overlay/htm_rust/Cargo.lock +383 -383
- overlay/htm_rust/Cargo.toml +37 -37
- overlay/htm_rust/build.rs +168 -160
- overlay/htm_rust/pyproject.toml +17 -17
- overlay/htm_rust/src/gpu/fused.rs +702 -663
- overlay/htm_rust/src/gpu/kernels/htm_fused_step.cu +677 -677
- overlay/htm_rust/src/gpu/tests.rs +663 -643
- overlay/htm_rust/src/lib.rs +198 -198
- overlay/htm_rust/src/region.rs +94 -94
- overlay/htm_rust/src/sp.rs +302 -302
- overlay/htm_rust/src/tm.rs +545 -545
- overlay/hydra/__init__.py +37 -31
- overlay/hydra/config.py +225 -220
- overlay/hydra/data_module.py +288 -288
- overlay/hydra/diffusion_loss.py +236 -236
- overlay/hydra/engram.py +160 -175
- overlay/hydra/eval.py +210 -217
- overlay/hydra/gdn_block.py +126 -126
- overlay/hydra/hyena_block.py +68 -68
- overlay/hydra/lightning_module.py +326 -326
- overlay/hydra/model.py +0 -0
- overlay/hydra/optimizer.py +252 -252
- overlay/hydra/reality_bridge.py +71 -0
- overlay/hydra/training.py +965 -946
- overlay/kernels/cuda/decode_kernels.cu +10 -10
- overlay/kernels/cuda/flashfftconv/LICENSE +201 -201
- overlay/kernels/cuda/flashfftconv/README.md +57 -57
- overlay/kernels/cuda/flashfftconv/UPSTREAM_COMMIT +1 -1
- overlay/kernels/cuda/flashfftconv/csrc/.gitignore +9 -9
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly.h +373 -373
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda.cu +698 -698
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda_bf16.cu +724 -724
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda.cu +723 -723
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda_bf16.cu +705 -705
.dockerignore
CHANGED
|
@@ -1,20 +1,16 @@
|
|
| 1 |
-
.
|
| 2 |
-
|
| 3 |
-
.
|
| 4 |
-
.
|
| 5 |
-
.
|
| 6 |
-
.
|
| 7 |
-
|
| 8 |
-
*
|
| 9 |
-
*
|
| 10 |
-
*.
|
| 11 |
-
*.
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
*.
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
overlay/data/
|
| 18 |
-
overlay/state_store/
|
| 19 |
-
overlay/htm_rust/target/
|
| 20 |
-
overlay/hydra-core/target/
|
|
|
|
| 1 |
+
# Keep HF runtime image context deterministic and small.
|
| 2 |
+
**/__pycache__/
|
| 3 |
+
**/*.py[cod]
|
| 4 |
+
**/.pytest_cache/
|
| 5 |
+
**/.mypy_cache/
|
| 6 |
+
**/.ruff_cache/
|
| 7 |
+
**/.venv/
|
| 8 |
+
**/target/
|
| 9 |
+
**/logs/
|
| 10 |
+
**/*.log
|
| 11 |
+
**/*.out
|
| 12 |
+
**/*.pt
|
| 13 |
+
**/*.safetensors
|
| 14 |
+
**/*.parquet
|
| 15 |
+
**/*.npz
|
| 16 |
+
**/.git/
|
|
|
|
|
|
|
|
|
|
|
|
Dockerfile
CHANGED
|
@@ -1,128 +1,124 @@
|
|
| 1 |
-
FROM pytorch/pytorch:2.
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
#
|
| 39 |
-
#
|
| 40 |
-
#
|
| 41 |
-
#
|
| 42 |
-
#
|
| 43 |
-
#
|
| 44 |
-
#
|
| 45 |
-
#
|
| 46 |
-
#
|
| 47 |
-
#
|
| 48 |
-
#
|
| 49 |
-
#
|
| 50 |
-
#
|
| 51 |
-
#
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
#
|
| 63 |
-
#
|
| 64 |
-
#
|
| 65 |
-
#
|
| 66 |
-
#
|
| 67 |
-
#
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
#
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
#
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
#
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
if norm != raw:
|
| 116 |
-
sh.write_bytes(norm)
|
| 117 |
-
PY
|
| 118 |
|
| 119 |
RUN python -m py_compile hydra/training.py prepare.py train.py && \
|
| 120 |
bash -n scripts/run_domain_expanded_pretrain.sh
|
| 121 |
-
|
| 122 |
RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
maturin build --release -j 1 --features gpu --manifest-path htm_rust/Cargo.toml && \
|
| 126 |
pip install htm_rust/target/wheels/htm_rust-*.whl
|
| 127 |
-
|
| 128 |
-
CMD ["python", "/app/entrypoint.py"]
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel
|
| 2 |
|
| 3 |
+
# Default target is HF Jobs a10g-large (NVIDIA A10G, Ampere GA102, sm_86).
|
| 4 |
+
# Override at build time for other cards, e.g. --build-arg FEATHER_GPU_ARCH=sm_90a.
|
| 5 |
+
ARG FEATHER_GPU_ARCH=sm_86
|
| 6 |
+
ARG FEATHER_TORCH_CUDA_ARCH_LIST=8.6
|
| 7 |
+
|
| 8 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 9 |
+
PIP_NO_CACHE_DIR=1 \
|
| 10 |
+
PYTHONUNBUFFERED=1 \
|
| 11 |
+
CARGO_HOME=/root/.cargo \
|
| 12 |
+
RUSTUP_HOME=/root/.rustup \
|
| 13 |
+
HTM_CUDA_ARCH=${FEATHER_GPU_ARCH} \
|
| 14 |
+
TORCH_CUDA_ARCH_LIST=${FEATHER_TORCH_CUDA_ARCH_LIST} \
|
| 15 |
+
PATH=/root/.cargo/bin:${PATH}
|
| 16 |
+
|
| 17 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 18 |
+
git curl ca-certificates build-essential pkg-config libssl-dev && \
|
| 19 |
+
rm -rf /var/lib/apt/lists/*
|
| 20 |
+
|
| 21 |
+
RUN curl https://sh.rustup.rs -sSf | bash -s -- -y --profile minimal --default-toolchain stable
|
| 22 |
+
|
| 23 |
+
RUN pip install --upgrade pip setuptools wheel && \
|
| 24 |
+
pip install \
|
| 25 |
+
maturin \
|
| 26 |
+
huggingface_hub \
|
| 27 |
+
datasets \
|
| 28 |
+
requests \
|
| 29 |
+
pyarrow \
|
| 30 |
+
rustbpe \
|
| 31 |
+
pandas \
|
| 32 |
+
tiktoken \
|
| 33 |
+
pydantic \
|
| 34 |
+
ninja \
|
| 35 |
+
packaging \
|
| 36 |
+
einops
|
| 37 |
+
|
| 38 |
+
# Mamba-3 fused CUDA kernel stack (mandatory — NO fallback allowed).
|
| 39 |
+
#
|
| 40 |
+
# We install PRE-BUILT manylinux wheels from the official state-spaces/mamba
|
| 41 |
+
# and Dao-AILab/causal-conv1d GitHub releases. Compiling mamba_ssm from source
|
| 42 |
+
# on HF Spaces' cpu-basic builder (~16GB RAM) OOMKills even with MAX_JOBS=1 —
|
| 43 |
+
# nvcc on the templated selective-scan/chunk-scan kernels needs 8–12GB per TU.
|
| 44 |
+
#
|
| 45 |
+
# Wheel selection for base image pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel:
|
| 46 |
+
# - Python 3.11 (cp311) — matches PyTorch 2.5.1 image
|
| 47 |
+
# - CUDA 12.x wheels (cu12) — compatible with CUDA 12.1 base
|
| 48 |
+
# - PyTorch 2.5 ABI (torch2.5) — exact torch match
|
| 49 |
+
# - cxx11abiFALSE — standard PyTorch pip build
|
| 50 |
+
#
|
| 51 |
+
# Versions: mamba_ssm 2.3.0 + causal_conv1d 1.6.0 (matching torch2.5 ABI).
|
| 52 |
+
# Both are CUDA-compiled, no build toolchain needed
|
| 53 |
+
# on the Space builder.
|
| 54 |
+
#
|
| 55 |
+
# Step A: install the published v2.3.0 prebuilt wheel (compiled CUDA ops
|
| 56 |
+
# for selective_scan, layernorm_gated, ssd_*, causal_conv1d, etc).
|
| 57 |
+
RUN pip install \
|
| 58 |
+
'https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.0/causal_conv1d-1.6.0+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \
|
| 59 |
+
'https://github.com/state-spaces/mamba/releases/download/v2.3.0/mamba_ssm-2.3.0+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
|
| 60 |
+
python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))"
|
| 61 |
+
|
| 62 |
+
#
|
| 63 |
+
# Step B: graft the Mamba3 class + its pure-Triton ops subtree from mamba-ssm
|
| 64 |
+
# main. v2.3.1 is the latest release but Mamba3 landed post-release; the new
|
| 65 |
+
# files under ops/triton/mamba3/ are ALL pure Python @triton.jit kernels with
|
| 66 |
+
# zero compiled-CUDA dependencies (verified: every import in that subtree is
|
| 67 |
+
# triton/torch/python — no .so files, no nvcc). So we install the v2.3.0 wheel
|
| 68 |
+
# (for its compiled ops) and overlay the main-branch Mamba3 sources on top.
|
| 69 |
+
#
|
| 70 |
+
# This avoids the source-build OOM on the cpu-basic HF Space builder and the
|
| 71 |
+
# missing-file error the smoke hit on the last attempt.
|
| 72 |
+
# Download grafted mamba3 module + triton ops subtree
|
| 73 |
+
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
|
| 74 |
+
BASE=https://raw.githubusercontent.com/state-spaces/mamba/main && \
|
| 75 |
+
curl -fsSL "$BASE/mamba_ssm/modules/mamba3.py" -o "$SITE/modules/mamba3.py" && \
|
| 76 |
+
mkdir -p "$SITE/ops/triton/mamba3" && \
|
| 77 |
+
for f in __init__.py angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
|
| 78 |
+
curl -fsSL "$BASE/mamba_ssm/ops/triton/mamba3/$f" -o "$SITE/ops/triton/mamba3/$f"; \
|
| 79 |
+
done
|
| 80 |
+
|
| 81 |
+
# Replace mamba_ssm/__init__.py with a minimal one that only imports Mamba3
|
| 82 |
+
# (pure-Triton, works). The shipped __init__.py eagerly imports
|
| 83 |
+
# selective_scan_cuda.so which has a libtorch C++ ABI mismatch on this base
|
| 84 |
+
# image ("undefined symbol: _ZN3c107WarningC1E..."). Since training only needs
|
| 85 |
+
# Mamba3 (grafted from main), we skip all compiled-CUDA imports.
|
| 86 |
+
COPY mamba_ssm_init.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py
|
| 87 |
+
|
| 88 |
+
# Structural check (no triton init — triton has no GPU on the builder)
|
| 89 |
+
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
|
| 90 |
+
test -f "$SITE/modules/mamba3.py" && \
|
| 91 |
+
test -f "$SITE/ops/triton/mamba3/mamba3_siso_combined.py" && \
|
| 92 |
+
test -s "$SITE/__init__.py" && \
|
| 93 |
+
echo "mamba3 graft + __init__ override verified"
|
| 94 |
+
|
| 95 |
+
# Optional tilelang for MIMO path — pure-python, cheap; SISO Mamba3 works without.
|
| 96 |
+
RUN pip install tilelang || echo "[dockerfile] tilelang optional install failed — continuing"
|
| 97 |
+
|
| 98 |
+
# Triton version decision: FORCE 3.4.0 — first line with both mamba3
|
| 99 |
+
# APIs (set_allocator + tl.make_tensor_descriptor) while avoiding the 3.5.x
|
| 100 |
+
# driver-discovery regression seen on HF A10G (`0 active drivers` despite
|
| 101 |
+
# torch.cuda being available). torch 2.5's _inductor expects older Triton
|
| 102 |
+
# internals, but mamba_ssm/__init__.py shims AttrsDescriptor as a stub
|
| 103 |
+
# before any torch._inductor import path runs, so the incompatibility is
|
| 104 |
+
# neutralized. Build-time assert verifies mamba3's two required APIs.
|
| 105 |
+
RUN pip install --force-reinstall --no-deps 'triton==3.4.0' && \
|
| 106 |
+
python -c "import triton; from triton import language as tl; \
|
| 107 |
+
assert hasattr(triton, 'set_allocator'), 'missing triton.set_allocator'; \
|
| 108 |
+
assert hasattr(tl, 'make_tensor_descriptor'), 'missing tl.make_tensor_descriptor'; \
|
| 109 |
+
print(f'triton={triton.__version__} set_allocator+make_tensor_descriptor OK, AttrsDescriptor shimmed in mamba_ssm/__init__.py')"
|
| 110 |
+
|
| 111 |
+
WORKDIR /workspace
|
| 112 |
+
COPY overlay /workspace/feather
|
| 113 |
+
COPY entrypoint.py /app/entrypoint.py
|
| 114 |
+
WORKDIR /workspace/feather
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
RUN python -m py_compile hydra/training.py prepare.py train.py && \
|
| 117 |
bash -n scripts/run_domain_expanded_pretrain.sh
|
| 118 |
+
|
| 119 |
RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
|
| 120 |
+
echo "building htm_rust GPU kernels for HTM_CUDA_ARCH=${HTM_CUDA_ARCH} TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}" && \
|
| 121 |
+
maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml && \
|
|
|
|
| 122 |
pip install htm_rust/target/wheels/htm_rust-*.whl
|
| 123 |
+
|
| 124 |
+
CMD ["python", "/app/entrypoint.py"]
|
entrypoint.py
CHANGED
|
@@ -1,227 +1,267 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
import json
|
| 5 |
-
import os
|
| 6 |
-
import subprocess
|
| 7 |
-
import sys
|
| 8 |
-
import time
|
| 9 |
-
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 10 |
-
from pathlib import Path
|
| 11 |
-
from threading import Thread
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
#
|
| 29 |
-
#
|
| 30 |
-
#
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from threading import Thread
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _prepend_library_path(*paths: str) -> None:
|
| 15 |
+
"""Expose injected NVIDIA driver libraries before torch/triton imports."""
|
| 16 |
+
existing = [p for p in os.environ.get('LD_LIBRARY_PATH', '').split(':') if p]
|
| 17 |
+
merged = []
|
| 18 |
+
for p in paths:
|
| 19 |
+
if p and p not in merged:
|
| 20 |
+
merged.append(p)
|
| 21 |
+
for p in existing:
|
| 22 |
+
if p not in merged:
|
| 23 |
+
merged.append(p)
|
| 24 |
+
os.environ['LD_LIBRARY_PATH'] = ':'.join(merged)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
_prepend_library_path(
|
| 28 |
+
# HF Jobs injects the host driver under /usr/local/nvidia. Prefer that
|
| 29 |
+
# over CUDA toolkit/compat libcuda stubs; using /usr/local/cuda/compat here
|
| 30 |
+
# made A10G PyTorch report Error 803 despite nvidia-smi working.
|
| 31 |
+
'/usr/local/nvidia/lib64',
|
| 32 |
+
'/usr/local/nvidia/lib',
|
| 33 |
+
'/usr/lib/x86_64-linux-gnu',
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# =============================================================================
|
| 38 |
+
# EARLY CUDA FABRIC MANAGER KICK (before ANY CUDA-touching imports)
|
| 39 |
+
# =============================================================================
|
| 40 |
+
# On HF GPU hosts, cudaGetDeviceCount can transiently return not-ready errors
|
| 41 |
+
# on first use. H200 fabric-manager is the worst case; A10G is usually ready
|
| 42 |
+
# immediately, but the same early kick keeps the runtime deterministic.
|
| 43 |
+
# Driver bring-up synchronizes with the container's first driver call. Once any NVML/CUDA
|
| 44 |
+
# call succeeds once (even just nvidia-smi), the fabric is up for the rest
|
| 45 |
+
# of the container lifetime.
|
| 46 |
+
#
|
| 47 |
+
# Our previous approach (wait in a subprocess before training) didn't work
|
| 48 |
+
# because the "initialization failed" state persisted across calls in the
|
| 49 |
+
# same container. The real fix: kick the driver exactly once with
|
| 50 |
+
# nvidia-smi, which is what successfully-working baseline containers do
|
| 51 |
+
# implicitly via their first torch.cuda call.
|
| 52 |
+
#
|
| 53 |
+
# Must happen BEFORE `import torch` (because any import that eagerly calls
|
| 54 |
+
# cudaGetDeviceCount will cache the Error 802 state).
|
| 55 |
+
def _gpu_summary_line(smi_output: str) -> str:
    """Return the first nvidia-smi line naming a known GPU family, else 'gpu=unknown'."""
    families = ('A10', 'A100', 'H100', 'H200', 'RTX')
    for line in smi_output.splitlines():
        if any(name in line for name in families):
            return line.strip()
    return 'gpu=unknown'


def _early_cuda_kick() -> None:
    """Poke the NVIDIA driver, then wait for torch to see CUDA in a subprocess.

    Phase 1 runs ``nvidia-smi`` every 2 s for up to 120 s until it exits 0.
    Phase 2 runs ``torch.cuda.is_available()`` in a fresh interpreter every
    2 s for up to 120 s, so a failed probe never initialises CUDA in this
    process. The function always returns; failures are only logged.

    A missing ``nvidia-smi`` binary or a probe that exceeds its timeout is
    treated as a failed attempt instead of propagating, because this runs at
    module import and an uncaught exception would abort the whole entrypoint.
    """
    deadline = time.time() + 120.0
    attempt = 0
    while time.time() < deadline:
        attempt += 1
        try:
            r = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=30)
        except FileNotFoundError:
            # Retrying cannot make the binary appear; go straight to the torch probe.
            print('[boot] nvidia-smi not found on PATH; skipping driver kick', flush=True)
            break
        except subprocess.TimeoutExpired:
            print(f'[boot] nvidia-smi attempt {attempt} timed out after 30s', flush=True)
            continue
        if r.returncode == 0:
            gpu_line = _gpu_summary_line(r.stdout or '')
            print(f'[boot] nvidia-smi OK on attempt {attempt}: {gpu_line}', flush=True)
            break
        print(f'[boot] nvidia-smi attempt {attempt} rc={r.returncode} stderr={(r.stderr or "")[:120]}',
              flush=True)
        time.sleep(2)
    else:
        # Only reached when the deadline expired without a break.
        print('[boot] WARNING: nvidia-smi never succeeded within 120s', flush=True)

    # After nvidia-smi, probe torch in a subprocess so any latent error state
    # doesn't leak into the main process's CUDA context.
    probe = 'import torch; import sys; sys.exit(0 if torch.cuda.is_available() else 1)'
    torch_deadline = time.time() + 120.0
    t_attempt = 0
    while time.time() < torch_deadline:
        t_attempt += 1
        try:
            r = subprocess.run([sys.executable, '-c', probe], capture_output=True, text=True, timeout=60)
        except subprocess.TimeoutExpired:
            if t_attempt == 1:
                print(f'[boot] torch cuda probe {t_attempt}: timed out after 60s', flush=True)
            continue
        if r.returncode == 0:
            print(f'[boot] torch.cuda.is_available() = True after {t_attempt} probe(s)', flush=True)
            return
        if t_attempt == 1:
            # Log the error once; later attempts would repeat the same text.
            print(f'[boot] torch cuda probe {t_attempt}: {(r.stderr or "")[:200]}', flush=True)
        time.sleep(2)
    print('[boot] WARNING: torch.cuda never became ready — training will likely fail', flush=True)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
_early_cuda_kick()
|
| 86 |
+
|
| 87 |
+
# Hydrate triton compilation cache from HF Hub before any triton/mamba_ssm import.
|
| 88 |
+
# triton_cache_setup.py is copied next to this file by the job bash command.
|
| 89 |
+
try:
|
| 90 |
+
import triton_cache_setup as _tcs
|
| 91 |
+
_tcs.setup()
|
| 92 |
+
except ImportError:
|
| 93 |
+
print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)
|
| 94 |
+
|
| 95 |
+
from huggingface_hub import HfApi # noqa: E402 (import after cuda kick)
|
| 96 |
+
|
| 97 |
+
REPO_ROOT = Path('/workspace/feather')
|
| 98 |
+
CACHE_ROOT = Path.home() / '.cache' / 'autoresearch'
|
| 99 |
+
LOG_FILE = REPO_ROOT / 'run_domain_expanded.log'
|
| 100 |
+
JOB_ID = os.environ.get('JOB_ID', 'local-job')
|
| 101 |
+
OUTPUT_REPO = os.environ.get('HF_REPO_ID', 'icarus112/feather-pretrain-checkpoints')
|
| 102 |
+
TOKEN = os.environ.get('HF_TOKEN')
|
| 103 |
+
RUNTIME_MODE = os.environ.get('FEATHER_RUNTIME_MODE', 'space')
|
| 104 |
+
APP_PORT = int(os.environ.get('PORT', '7860'))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class _HealthHandler(BaseHTTPRequestHandler):
    """HTTP handler that answers liveness probes with a small JSON document."""

    def do_GET(self):
        """Reply 200 with status JSON on the health paths, 404 on anything else."""
        if self.path not in ('/', '/health', '/healthz', '/ready'):
            self.send_response(404)
            self.end_headers()
            return

        body = json.dumps({
            'status': 'ok',
            'mode': RUNTIME_MODE,
            'job_id': JOB_ID,
        }).encode('utf-8')
        self.send_response(200)
        for header, value in (
            ('Content-Type', 'application/json'),
            ('Content-Length', str(len(body))),
        ):
            self.send_header(header, value)
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        """Discard per-request access logging."""
        return None
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _start_health_server() -> HTTPServer:
    """Bind the health endpoint on all interfaces and serve it from a daemon thread.

    Returns:
        The running ``HTTPServer``, so the caller can shut it down later.
    """
    httpd = HTTPServer(('0.0.0.0', APP_PORT), _HealthHandler)
    # Daemon thread: it must not keep the process alive on its own.
    Thread(target=httpd.serve_forever, daemon=True).start()
    print(f'[space] health server listening on 0.0.0.0:{APP_PORT}', flush=True)
    return httpd
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
    """Upload one local file to ``OUTPUT_REPO`` (a model repo) if it exists.

    Args:
        api: Hub client used for the upload.
        path: Local file to send; a missing file is logged and skipped.
        dest: Destination path inside the repository.
    """
    if path.exists():
        api.upload_file(
            path_or_fileobj=str(path),
            path_in_repo=dest,
            repo_id=OUTPUT_REPO,
            repo_type='model',
        )
        print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
    else:
        print(f'[upload] skip missing {path}', flush=True)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
|
| 151 |
+
"""Block until CUDA is fully initialized or timeout.
|
| 152 |
+
|
| 153 |
+
On H200 hosts with NVSwitch/fabric manager, nvidia driver setup can race
|
| 154 |
+
with container start. cudaGetDeviceCount can return CUDA_ERROR_SYSTEM_NOT_READY
|
| 155 |
+
(error 802) for the first few seconds, and any import that triggers
|
| 156 |
+
@triton.autotune (e.g. mamba_ssm, torch amp utilities) blows up with
|
| 157 |
+
"0 active drivers" if it happens during that window.
|
| 158 |
+
|
| 159 |
+
We pre-init CUDA in a throwaway Python subprocess (so any error state does
|
| 160 |
+
not leak into the main training process) and retry until torch.cuda
|
| 161 |
+
reports ready.
|
| 162 |
+
"""
|
| 163 |
+
import time as _t
|
| 164 |
+
probe = (
|
| 165 |
+
"import torch; "
|
| 166 |
+
"import sys; "
|
| 167 |
+
"avail = torch.cuda.is_available(); "
|
| 168 |
+
"count = torch.cuda.device_count() if avail else 0; "
|
| 169 |
+
"torch.empty(1, device='cuda') if (avail and count > 0) else None; "
|
| 170 |
+
"from triton.runtime import driver; "
|
| 171 |
+
"driver.active.get_current_device(); "
|
| 172 |
+
"sys.exit(0 if (avail and count > 0) else 1)"
|
| 173 |
+
)
|
| 174 |
+
deadline = _t.time() + timeout_s
|
| 175 |
+
attempt = 0
|
| 176 |
+
while _t.time() < deadline:
|
| 177 |
+
attempt += 1
|
| 178 |
+
r = subprocess.run(['python', '-c', probe], capture_output=True, text=True)
|
| 179 |
+
if r.returncode == 0:
|
| 180 |
+
print(f'[job] CUDA/Triton ready after {attempt} probe(s)', flush=True)
|
| 181 |
+
return
|
| 182 |
+
if attempt == 1:
|
| 183 |
+
print(f'[job] CUDA not ready yet (will retry up to {timeout_s}s): {r.stderr.strip()[:200]}', flush=True)
|
| 184 |
+
_t.sleep(2)
|
| 185 |
+
print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def run_job_mode() -> int:
    """Run the pretraining script and push its artifacts to the Hub.

    Steps, in order: change into ``REPO_ROOT``, fill in environment defaults
    (values already set are left alone), wait for CUDA, attempt the retina
    bootstrap, run ``scripts/run_domain_expanded_pretrain.sh``, call the
    triton cache teardown, and upload the log and checkpoints when
    ``HF_TOKEN`` is set.

    Each artifact is uploaded in its own ``try`` so that one failed upload
    (for example the log) cannot prevent the checkpoints from being sent.

    Returns:
        The exit code of the training script. Bootstrap, teardown and upload
        failures are logged but do not change it.
    """
    os.chdir(REPO_ROOT)
    env_defaults = (
        ('HYDRA_TIME_BUDGET', '43200'),
        ('HYDRA_TARGET_SHARDS', '2048'),
        ('HYDRA_DOWNLOAD_WORKERS', '16'),
        ('HYDRA_CKPT_INTERVAL', '1000'),
        ('HYDRA_RESUME_CKPT', str(CACHE_ROOT / 'latest.pt')),
        ('FEATHER_GPU_PROFILE', 'a10g-large'),
        ('HTM_CUDA_ARCH', 'sm_86'),
        ('TORCH_CUDA_ARCH_LIST', '8.6'),
    )
    for key, value in env_defaults:
        os.environ.setdefault(key, value)
    # These two derive from the (possibly overridden) GPU profile, so they
    # must be set after FEATHER_GPU_PROFILE has its final value.
    os.environ.setdefault('TRITON_CACHE_DIR', f"/workspace/triton_cache/{os.environ['FEATHER_GPU_PROFILE']}")
    os.environ.setdefault('TRITON_CACHE_REPO', f"icarus112/feather-triton-cache-{os.environ['FEATHER_GPU_PROFILE']}")
    print(f"[job] gpu_profile={os.environ['FEATHER_GPU_PROFILE']} htm_cuda_arch={os.environ['HTM_CUDA_ARCH']} torch_cuda_arch={os.environ['TORCH_CUDA_ARCH_LIST']}", flush=True)

    # CUDA readiness was kicked at module import via _early_cuda_kick. Keep
    # the wait as a second safety net — no-op if CUDA already ready.
    _wait_for_cuda_ready()

    cmd = [
        'bash',
        './scripts/run_domain_expanded_pretrain.sh',
        '--target-shards', os.environ['HYDRA_TARGET_SHARDS'],
        '--download-workers', os.environ['HYDRA_DOWNLOAD_WORKERS'],
    ]
    print('[job] ensuring retina.npz before training...', flush=True)
    try:
        sys.path.insert(0, str(REPO_ROOT))
        from subsystems.sdr_retina import build_retina
        build_retina()
    except Exception as _retina_err:
        # Deliberately best-effort: a bootstrap failure only produces a warning.
        print(f'[job] retina bootstrap warning (train.py may still build it): {_retina_err}', flush=True)
    print('[job] starting Feather domain-expanded pretrain', flush=True)
    print(f'[job] command={cmd}', flush=True)
    proc = subprocess.run(cmd, check=False)

    # Push triton compilation cache back to HF Hub for next run.
    try:
        import triton_cache_setup as _tcs
        _tcs.teardown()
    except Exception as _tcs_err:
        print(f'[triton_cache] teardown error (non-fatal): {_tcs_err}', flush=True)

    if TOKEN:
        api = HfApi(token=TOKEN)
        try:
            api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
        except Exception as e:
            print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
        prefix = f'jobs/{JOB_ID}'
        artifacts = (
            (LOG_FILE, 'run_domain_expanded.log'),
            (CACHE_ROOT / 'latest.pt', 'latest.pt'),
            (CACHE_ROOT / 'pretrain_final.pt', 'pretrain_final.pt'),
        )
        for local_path, remote_name in artifacts:
            # Isolate each upload so one failure does not skip the rest.
            try:
                upload_artifact(api, local_path, f'{prefix}/{remote_name}')
            except Exception as e:
                print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
    else:
        print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)

    return proc.returncode
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def run_space_mode() -> int:
    """Serve the health endpoint and idle until the process is interrupted.

    The loop never exits normally, so no value is actually returned. The
    ``finally`` clause stops and closes the HTTP server when an exception
    (for example ``KeyboardInterrupt``) unwinds the loop.
    """
    health_server = _start_health_server()
    print('[space] Feather runtime image ready', flush=True)
    try:
        # Idle in one-hour slices; serving happens on the server's own thread.
        while True:
            time.sleep(3600)
    finally:
        health_server.shutdown()
        health_server.server_close()
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def main() -> int:
    """Run job mode when ``RUNTIME_MODE`` is ``'job'``, otherwise space mode.

    Returns:
        The exit status produced by the selected mode.
    """
    runner = run_job_mode if RUNTIME_MODE == 'job' else run_space_mode
    return runner()
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
if __name__ == '__main__':
|
| 267 |
+
raise SystemExit(main())
|
mamba_ssm_init.py
CHANGED
|
@@ -1,101 +1,69 @@
|
|
| 1 |
-
# mamba_ssm package init — minimal override to avoid broken selective_scan_cuda.so
|
| 2 |
-
# ABI mismatch with the base image's libtorch.
|
| 3 |
-
#
|
| 4 |
-
# The upstream __init__.py eagerly imports selective_scan_cuda which fails on
|
| 5 |
-
# pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
|
| 6 |
-
# symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
|
| 7 |
-
# all compiled-CUDA imports here and let Mamba3 load directly.
|
| 8 |
-
|
| 9 |
-
__version__ = "2.3.1+feather-graft"
|
| 10 |
-
|
| 11 |
-
# selective_scan_fn / mamba_inner_fn are shimmed to None — they are NOT used
|
| 12 |
-
# by the Feather training path (which is Mamba3-only). If any import path
|
| 13 |
-
# hits this, it will get a clear AttributeError instead of an obscure ImportError.
|
| 14 |
-
selective_scan_fn = None
|
| 15 |
-
mamba_inner_fn = None
|
| 16 |
-
|
| 17 |
-
# --- triton API compatibility shims -----------------------------------------
|
| 18 |
-
# Version matrix is hostile: torch 2.6 pins triton==3.2.0 because torch._inductor
|
| 19 |
-
# imports AttrsDescriptor from triton.compiler.compiler — removed in triton 3.4+.
|
| 20 |
-
# Grafted Mamba3 (from mamba-ssm main) needs triton.set_allocator and
|
| 21 |
-
# tl.make_tensor_descriptor, both added in triton 3.3+. No single triton version
|
| 22 |
-
# satisfies both simultaneously. We run on triton 3.5.1 (latest, has both mamba3
|
| 23 |
-
# APIs) and shim AttrsDescriptor as a stub dataclass for torch._inductor. The
|
| 24 |
-
# stub is never actually invoked at runtime because the codebase does not use
|
| 25 |
-
# torch.compile — but importing torch._inductor.* still requires the symbol to
|
| 26 |
-
# exist at module load time.
|
| 27 |
import triton as _triton # noqa: E402
|
| 28 |
if not hasattr(_triton, "set_allocator"):
|
| 29 |
-
def _noop_set_allocator(_fn): # pragma: no cover
|
| 30 |
-
return None
|
| 31 |
-
_triton.set_allocator = _noop_set_allocator
|
| 32 |
-
|
| 33 |
-
import triton.compiler.compiler as _tcc # noqa: E402
|
| 34 |
-
if not hasattr(_tcc, "AttrsDescriptor"):
|
| 35 |
-
class _AttrsDescriptorShim:
|
| 36 |
-
"""Stub for torch._inductor compatibility on triton >= 3.4.
|
| 37 |
-
torch._inductor.runtime.hints imports this at module load but the
|
| 38 |
-
constructor is only called inside torch.compile paths. Accept any
|
| 39 |
-
args/kwargs so the import itself succeeds."""
|
| 40 |
-
def __init__(self, *args, **kwargs):
|
| 41 |
-
self.args = args
|
| 42 |
-
self.kwargs = kwargs
|
| 43 |
-
|
| 44 |
-
@classmethod
|
| 45 |
-
def from_hints(cls, *args, **kwargs):
|
| 46 |
-
return cls(*args, **kwargs)
|
| 47 |
-
|
| 48 |
-
_tcc.AttrsDescriptor = _AttrsDescriptorShim
|
| 49 |
-
|
| 50 |
-
# triton_key: removed in triton 3.5, used by torch._inductor.codecache for
|
| 51 |
-
# FxGraphCache key derivation. Return a stable string so caching still works.
|
| 52 |
-
if not hasattr(_tcc, "triton_key"):
|
| 53 |
-
def _triton_key_shim():
|
| 54 |
-
import triton as _t
|
| 55 |
-
return f"triton-{_t.__version__}-shim"
|
| 56 |
-
_tcc.triton_key = _triton_key_shim
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
# Patch _create_driver to directly select CudaDriver when registry discovery
|
| 62 |
-
# returns empty.
|
| 63 |
-
import importlib as _importlib # noqa: E402
|
| 64 |
-
_triton_driver_mod = _importlib.import_module("triton.runtime.driver")
|
| 65 |
-
if getattr(_triton_driver_mod, "backends", None) == {}:
|
| 66 |
-
from triton.backends.nvidia import driver as _nvidia_driver # noqa: E402
|
| 67 |
|
| 68 |
-
|
| 69 |
-
if hasattr(_nvidia_driver, "CudaDriver") and _nvidia_driver.CudaDriver.is_active():
|
| 70 |
-
return _nvidia_driver.CudaDriver()
|
| 71 |
-
raise RuntimeError(
|
| 72 |
-
"Triton backend registry is empty and NVIDIA CudaDriver is not active"
|
| 73 |
-
)
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
driver=_CudaDriver,
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
# Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
|
| 91 |
-
# for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
|
| 92 |
-
# so fall back to eager on any dynamo failure rather than crashing. This is
|
| 93 |
-
# defense-in-depth against further triton API drift.
|
| 94 |
-
try:
|
| 95 |
-
import torch._dynamo # noqa: F401 — triggers dynamo module init
|
| 96 |
-
torch._dynamo.config.suppress_errors = True
|
| 97 |
-
except Exception: # pragma: no cover
|
| 98 |
-
pass
|
| 99 |
-
|
| 100 |
-
# Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
|
| 101 |
-
from mamba_ssm.modules.mamba3 import Mamba3 # noqa: E402
|
|
|
|
| 1 |
+
# mamba_ssm package init — minimal override to avoid broken selective_scan_cuda.so
|
| 2 |
+
# ABI mismatch with the base image's libtorch.
|
| 3 |
+
#
|
| 4 |
+
# The upstream __init__.py eagerly imports selective_scan_cuda which fails on
|
| 5 |
+
# pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
|
| 6 |
+
# symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
|
| 7 |
+
# all compiled-CUDA imports here and let Mamba3 load directly.
|
| 8 |
+
|
| 9 |
+
__version__ = "2.3.1+feather-graft"
|
| 10 |
+
|
| 11 |
+
# selective_scan_fn / mamba_inner_fn are shimmed to None — they are NOT used
|
| 12 |
+
# by the Feather training path (which is Mamba3-only). If any import path
|
| 13 |
+
# hits this, it will fail on first use (TypeError: 'NoneType' object is not callable) instead of an obscure ImportError.
|
| 14 |
+
selective_scan_fn = None
|
| 15 |
+
mamba_inner_fn = None
|
| 16 |
+
|
| 17 |
+
# --- triton API compatibility shims -----------------------------------------
|
| 18 |
+
# Version matrix is hostile: torch 2.6 pins triton==3.2.0 because torch._inductor
|
| 19 |
+
# imports AttrsDescriptor from triton.compiler.compiler — removed in triton 3.4+.
|
| 20 |
+
# Grafted Mamba3 (from mamba-ssm main) needs triton.set_allocator and
|
| 21 |
+
# tl.make_tensor_descriptor, both added in triton 3.3+. No single triton version
|
| 22 |
+
# satisfies both simultaneously. We run on triton 3.5.1 (latest, has both mamba3
|
| 23 |
+
# APIs) and shim AttrsDescriptor as a plain stub class for torch._inductor. The
|
| 24 |
+
# stub is never actually invoked at runtime because the codebase does not use
|
| 25 |
+
# torch.compile — but importing torch._inductor.* still requires the symbol to
|
| 26 |
+
# exist at module load time.
|
| 27 |
import triton as _triton # noqa: E402
|
| 28 |
if not hasattr(_triton, "set_allocator"):
|
| 29 |
+
def _noop_set_allocator(_fn): # pragma: no cover
|
| 30 |
+
return None
|
| 31 |
+
_triton.set_allocator = _noop_set_allocator
|
| 32 |
+
|
| 33 |
+
import triton.compiler.compiler as _tcc # noqa: E402
|
| 34 |
+
if not hasattr(_tcc, "AttrsDescriptor"):
|
| 35 |
+
class _AttrsDescriptorShim:
|
| 36 |
+
"""Stub for torch._inductor compatibility on triton >= 3.4.
|
| 37 |
+
torch._inductor.runtime.hints imports this at module load but the
|
| 38 |
+
constructor is only called inside torch.compile paths. Accept any
|
| 39 |
+
args/kwargs so the import itself succeeds."""
|
| 40 |
+
def __init__(self, *args, **kwargs):
|
| 41 |
+
self.args = args
|
| 42 |
+
self.kwargs = kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
+
@classmethod
|
| 45 |
+
def from_hints(cls, *args, **kwargs):
|
| 46 |
+
return cls(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
_tcc.AttrsDescriptor = _AttrsDescriptorShim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
# triton_key: removed in triton 3.5, used by torch._inductor.codecache for
|
| 51 |
+
# FxGraphCache key derivation. Return a stable string so caching still works.
|
| 52 |
+
if not hasattr(_tcc, "triton_key"):
|
| 53 |
+
def _triton_key_shim():
|
| 54 |
+
import triton as _t
|
| 55 |
+
return f"triton-{_t.__version__}-shim"
|
| 56 |
+
_tcc.triton_key = _triton_key_shim
|
| 57 |
|
| 58 |
+
# Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
|
| 59 |
+
# for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
|
| 60 |
+
# so fall back to eager on any dynamo failure rather than crashing. This is
|
| 61 |
+
# defense-in-depth against further triton API drift.
|
| 62 |
+
try:
|
| 63 |
+
import torch._dynamo # noqa: F401 — triggers dynamo module init
|
| 64 |
+
torch._dynamo.config.suppress_errors = True
|
| 65 |
+
except Exception: # pragma: no cover
|
| 66 |
+
pass
|
| 67 |
|
| 68 |
+
# Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
|
| 69 |
+
from mamba_ssm.modules.mamba3 import Mamba3 # noqa: E402
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
overlay/.dockerignore
CHANGED
|
@@ -1,20 +1,20 @@
|
|
| 1 |
-
.git
|
| 2 |
-
.github
|
| 3 |
-
.venv
|
| 4 |
-
.remember
|
| 5 |
-
.letta
|
| 6 |
-
.claude
|
| 7 |
-
__pycache__
|
| 8 |
-
*.pyc
|
| 9 |
-
*.pyo
|
| 10 |
-
*.pyd
|
| 11 |
-
*.log
|
| 12 |
-
run_*.log
|
| 13 |
-
run*.log
|
| 14 |
-
*.txt
|
| 15 |
-
WORKER_COMPLETE
|
| 16 |
-
autoresearch_loop.log
|
| 17 |
-
data/
|
| 18 |
-
state_store/
|
| 19 |
-
htm_rust/target/
|
| 20 |
-
hydra-core/target/
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.github
|
| 3 |
+
.venv
|
| 4 |
+
.remember
|
| 5 |
+
.letta
|
| 6 |
+
.claude
|
| 7 |
+
__pycache__
|
| 8 |
+
*.pyc
|
| 9 |
+
*.pyo
|
| 10 |
+
*.pyd
|
| 11 |
+
*.log
|
| 12 |
+
run_*.log
|
| 13 |
+
run*.log
|
| 14 |
+
*.txt
|
| 15 |
+
WORKER_COMPLETE
|
| 16 |
+
autoresearch_loop.log
|
| 17 |
+
data/
|
| 18 |
+
state_store/
|
| 19 |
+
htm_rust/target/
|
| 20 |
+
hydra-core/target/
|
overlay/configs/__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
from configs.hardware_config import HardwareConfig
|
| 2 |
-
from configs.harness_config import HarnessConfig
|
| 3 |
-
from configs.model_config import PostSemClawConfig
|
| 4 |
-
|
| 5 |
-
__all__ = ["PostSemClawConfig", "HarnessConfig", "HardwareConfig"]
|
|
|
|
| 1 |
+
from configs.hardware_config import HardwareConfig
|
| 2 |
+
from configs.harness_config import HarnessConfig
|
| 3 |
+
from configs.model_config import PostSemClawConfig
|
| 4 |
+
|
| 5 |
+
__all__ = ["PostSemClawConfig", "HarnessConfig", "HardwareConfig"]
|
overlay/configs/hardware_config.py
CHANGED
|
@@ -1,104 +1,104 @@
|
|
| 1 |
-
"""Hardware detection and memory budget configuration."""
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from pydantic import BaseModel, Field
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class HardwareConfig(BaseModel):
|
| 9 |
-
"""Auto-detected hardware configuration with memory budgets."""
|
| 10 |
-
|
| 11 |
-
gpu_name: str = Field(default="unknown", description="GPU device name")
|
| 12 |
-
gpu_memory_mb: int = Field(default=0, description="Total GPU memory in MB")
|
| 13 |
-
gpu_vram_mb: int = Field(default=0, description="Alias for gpu_memory_mb (legacy compat)")
|
| 14 |
-
compute_capability: tuple[int, int] = Field(
|
| 15 |
-
default=(0, 0), description="CUDA compute capability"
|
| 16 |
-
)
|
| 17 |
-
peak_flops: float = Field(
|
| 18 |
-
default=12.74e12, description="Peak FP32 FLOPS for MFU calculation"
|
| 19 |
-
)
|
| 20 |
-
bf16_peak_flops: float = Field(
|
| 21 |
-
default=38.1e12, description="Peak BF16 FLOPS (RTX 3060 default)"
|
| 22 |
-
)
|
| 23 |
-
|
| 24 |
-
# Memory budget
|
| 25 |
-
model_budget_mb: int = Field(
|
| 26 |
-
default=1500, description="Max MB for model params + optimizer"
|
| 27 |
-
)
|
| 28 |
-
activation_budget_mb: int = Field(
|
| 29 |
-
default=3000, description="Max MB for activations"
|
| 30 |
-
)
|
| 31 |
-
overhead_mb: int = Field(
|
| 32 |
-
default=500, description="Reserved for CUDA context + PyTorch overhead"
|
| 33 |
-
)
|
| 34 |
-
max_vram_usage_pct: float = Field(
|
| 35 |
-
default=90.0, description="Max VRAM usage as % of total"
|
| 36 |
-
)
|
| 37 |
-
gradient_checkpointing: bool = Field(
|
| 38 |
-
default=False, description="Enable gradient checkpointing to save VRAM"
|
| 39 |
-
)
|
| 40 |
-
|
| 41 |
-
@classmethod
|
| 42 |
-
def detect(cls) -> HardwareConfig:
|
| 43 |
-
"""Auto-detect hardware from current CUDA device."""
|
| 44 |
-
if not torch.cuda.is_available():
|
| 45 |
-
return cls()
|
| 46 |
-
|
| 47 |
-
device = torch.cuda.current_device()
|
| 48 |
-
props = torch.cuda.get_device_properties(device)
|
| 49 |
-
cap = (props.major, props.minor)
|
| 50 |
-
mem_mb = props.total_memory // (1024 * 1024)
|
| 51 |
-
gpu_name = props.name
|
| 52 |
-
|
| 53 |
-
# Peak FP32 FLOPS lookup by compute capability (approximate)
|
| 54 |
-
fp32_flops_table: dict[tuple[int, int], float] = {
|
| 55 |
-
(8, 6): 12.74e12, # RTX 3060
|
| 56 |
-
(8, 9): 40.09e12, # RTX 4090
|
| 57 |
-
(9, 0): 989.5e12, # H100 (BF16)
|
| 58 |
-
}
|
| 59 |
-
peak = fp32_flops_table.get(cap, 12.74e12)
|
| 60 |
-
|
| 61 |
-
# BF16 peak FLOPS lookup by GPU name substring
|
| 62 |
-
bf16_flops_table: dict[str, float] = {
|
| 63 |
-
"3060": 38.1e12,
|
| 64 |
-
"3090": 71.0e12,
|
| 65 |
-
"4090": 165.2e12,
|
| 66 |
-
"A100": 312e12,
|
| 67 |
-
"H100": 989.5e12,
|
| 68 |
-
"A10G": 70.0e12,
|
| 69 |
-
}
|
| 70 |
-
bf16_peak = 38.1e12 # default to RTX 3060
|
| 71 |
-
for key, val in bf16_flops_table.items():
|
| 72 |
-
if key in gpu_name:
|
| 73 |
-
bf16_peak = val
|
| 74 |
-
break
|
| 75 |
-
|
| 76 |
-
# Memory budget: leave overhead_mb for CUDA context
|
| 77 |
-
overhead = 500
|
| 78 |
-
available = mem_mb - overhead
|
| 79 |
-
model_budget = int(available * 0.3) # 30% for params + optimizer
|
| 80 |
-
activation_budget = int(available * 0.7) # 70% for activations
|
| 81 |
-
|
| 82 |
-
return cls(
|
| 83 |
-
gpu_name=gpu_name,
|
| 84 |
-
gpu_memory_mb=mem_mb,
|
| 85 |
-
gpu_vram_mb=mem_mb,
|
| 86 |
-
compute_capability=cap,
|
| 87 |
-
peak_flops=peak,
|
| 88 |
-
bf16_peak_flops=bf16_peak,
|
| 89 |
-
model_budget_mb=model_budget,
|
| 90 |
-
activation_budget_mb=activation_budget,
|
| 91 |
-
)
|
| 92 |
-
|
| 93 |
-
def suggest_batch_size(self, d_model: int, seq_len: int, n_layer: int) -> int:
|
| 94 |
-
"""Suggest batch size based on activation budget.
|
| 95 |
-
|
| 96 |
-
Uses rough estimate: per-sample activation ~= n_layer * seq_len * d_model
|
| 97 |
-
* 4 bytes * 2 (fwd + bwd).
|
| 98 |
-
"""
|
| 99 |
-
per_sample_mb = n_layer * seq_len * d_model * 4 * 2 / (1024 * 1024)
|
| 100 |
-
if per_sample_mb <= 0:
|
| 101 |
-
return 1
|
| 102 |
-
batch = max(1, int(self.activation_budget_mb / per_sample_mb))
|
| 103 |
-
# Round down to power of 2
|
| 104 |
-
return 2 ** (batch.bit_length() - 1) if batch > 1 else 1
|
|
|
|
| 1 |
+
"""Hardware detection and memory budget configuration."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class HardwareConfig(BaseModel):
|
| 9 |
+
"""Auto-detected hardware configuration with memory budgets."""
|
| 10 |
+
|
| 11 |
+
gpu_name: str = Field(default="unknown", description="GPU device name")
|
| 12 |
+
gpu_memory_mb: int = Field(default=0, description="Total GPU memory in MB")
|
| 13 |
+
gpu_vram_mb: int = Field(default=0, description="Alias for gpu_memory_mb (legacy compat)")
|
| 14 |
+
compute_capability: tuple[int, int] = Field(
|
| 15 |
+
default=(0, 0), description="CUDA compute capability"
|
| 16 |
+
)
|
| 17 |
+
peak_flops: float = Field(
|
| 18 |
+
default=12.74e12, description="Peak FP32 FLOPS for MFU calculation"
|
| 19 |
+
)
|
| 20 |
+
bf16_peak_flops: float = Field(
|
| 21 |
+
default=38.1e12, description="Peak BF16 FLOPS (RTX 3060 default)"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Memory budget
|
| 25 |
+
model_budget_mb: int = Field(
|
| 26 |
+
default=1500, description="Max MB for model params + optimizer"
|
| 27 |
+
)
|
| 28 |
+
activation_budget_mb: int = Field(
|
| 29 |
+
default=3000, description="Max MB for activations"
|
| 30 |
+
)
|
| 31 |
+
overhead_mb: int = Field(
|
| 32 |
+
default=500, description="Reserved for CUDA context + PyTorch overhead"
|
| 33 |
+
)
|
| 34 |
+
max_vram_usage_pct: float = Field(
|
| 35 |
+
default=90.0, description="Max VRAM usage as % of total"
|
| 36 |
+
)
|
| 37 |
+
gradient_checkpointing: bool = Field(
|
| 38 |
+
default=False, description="Enable gradient checkpointing to save VRAM"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
@classmethod
|
| 42 |
+
def detect(cls) -> HardwareConfig:
|
| 43 |
+
"""Auto-detect hardware from current CUDA device."""
|
| 44 |
+
if not torch.cuda.is_available():
|
| 45 |
+
return cls()
|
| 46 |
+
|
| 47 |
+
device = torch.cuda.current_device()
|
| 48 |
+
props = torch.cuda.get_device_properties(device)
|
| 49 |
+
cap = (props.major, props.minor)
|
| 50 |
+
mem_mb = props.total_memory // (1024 * 1024)
|
| 51 |
+
gpu_name = props.name
|
| 52 |
+
|
| 53 |
+
# Peak FP32 FLOPS lookup by compute capability (approximate)
|
| 54 |
+
fp32_flops_table: dict[tuple[int, int], float] = {
|
| 55 |
+
(8, 6): 12.74e12, # RTX 3060
|
| 56 |
+
(8, 9): 40.09e12, # RTX 4090
|
| 57 |
+
(9, 0): 989.5e12, # H100 (BF16)
|
| 58 |
+
}
|
| 59 |
+
peak = fp32_flops_table.get(cap, 12.74e12)
|
| 60 |
+
|
| 61 |
+
# BF16 peak FLOPS lookup by GPU name substring
|
| 62 |
+
bf16_flops_table: dict[str, float] = {
|
| 63 |
+
"3060": 38.1e12,
|
| 64 |
+
"3090": 71.0e12,
|
| 65 |
+
"4090": 165.2e12,
|
| 66 |
+
"A100": 312e12,
|
| 67 |
+
"H100": 989.5e12,
|
| 68 |
+
"A10G": 70.0e12,
|
| 69 |
+
}
|
| 70 |
+
bf16_peak = 38.1e12 # default to RTX 3060
|
| 71 |
+
for key, val in bf16_flops_table.items():
|
| 72 |
+
if key in gpu_name:
|
| 73 |
+
bf16_peak = val
|
| 74 |
+
break
|
| 75 |
+
|
| 76 |
+
# Memory budget: reserve a fixed 500 MB for the CUDA context (hard-coded below; the overhead_mb field is not read here)
|
| 77 |
+
overhead = 500
|
| 78 |
+
available = mem_mb - overhead
|
| 79 |
+
model_budget = int(available * 0.3) # 30% for params + optimizer
|
| 80 |
+
activation_budget = int(available * 0.7) # 70% for activations
|
| 81 |
+
|
| 82 |
+
return cls(
|
| 83 |
+
gpu_name=gpu_name,
|
| 84 |
+
gpu_memory_mb=mem_mb,
|
| 85 |
+
gpu_vram_mb=mem_mb,
|
| 86 |
+
compute_capability=cap,
|
| 87 |
+
peak_flops=peak,
|
| 88 |
+
bf16_peak_flops=bf16_peak,
|
| 89 |
+
model_budget_mb=model_budget,
|
| 90 |
+
activation_budget_mb=activation_budget,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def suggest_batch_size(self, d_model: int, seq_len: int, n_layer: int) -> int:
|
| 94 |
+
"""Suggest batch size based on activation budget.
|
| 95 |
+
|
| 96 |
+
Uses rough estimate: per-sample activation ~= n_layer * seq_len * d_model
|
| 97 |
+
* 4 bytes * 2 (fwd + bwd).
|
| 98 |
+
"""
|
| 99 |
+
per_sample_mb = n_layer * seq_len * d_model * 4 * 2 / (1024 * 1024)
|
| 100 |
+
if per_sample_mb <= 0:
|
| 101 |
+
return 1
|
| 102 |
+
batch = max(1, int(self.activation_budget_mb / per_sample_mb))
|
| 103 |
+
# Round down to power of 2
|
| 104 |
+
return 2 ** (batch.bit_length() - 1) if batch > 1 else 1
|
overlay/configs/harness_config.py
CHANGED
|
@@ -3,53 +3,53 @@ from typing import Literal
|
|
| 3 |
|
| 4 |
from pydantic import BaseModel, Field
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
class HarnessConfig(BaseModel):
|
| 11 |
-
"""Configuration for the HYDRA harness behavior."""
|
| 12 |
-
|
| 13 |
-
# Inner loop
|
| 14 |
-
time_budget_seconds: int = Field(
|
| 15 |
-
default=300, ge=60, description="Training time budget per experiment in seconds"
|
| 16 |
-
)
|
| 17 |
-
max_experiments: int = Field(
|
| 18 |
-
default=1000, ge=0, description="Max experiments before stopping (0=infinite)"
|
| 19 |
-
)
|
| 20 |
-
|
| 21 |
-
# Meta-agent
|
| 22 |
-
meta_interval: int = Field(
|
| 23 |
-
default=20, ge=5, description="Run meta-agent every N experiments"
|
| 24 |
-
)
|
| 25 |
-
max_meta_changes: int = Field(
|
| 26 |
-
default=3, ge=1, le=10, description="Max changes per meta-iteration"
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
# Search strategy
|
| 30 |
-
exploration_mode: Literal["conservative", "balanced", "bold"] = "balanced"
|
| 31 |
-
exploration_budget: int = Field(
|
| 32 |
-
default=5, ge=1, description="Consecutive bold experiments when stuck"
|
| 33 |
-
)
|
| 34 |
-
stuck_threshold: int = Field(
|
| 35 |
-
default=10, ge=3, description="No improvement for N experiments = stuck"
|
| 36 |
-
)
|
| 37 |
-
crash_threshold: float = Field(
|
| 38 |
-
default=0.5,
|
| 39 |
-
ge=0.1,
|
| 40 |
-
le=1.0,
|
| 41 |
-
description="Crash rate threshold for BROKEN state",
|
| 42 |
-
)
|
| 43 |
-
regression_tolerance: float = Field(
|
| 44 |
-
default=0.05,
|
| 45 |
-
ge=0,
|
| 46 |
-
le=0.2,
|
| 47 |
-
description="Max val_bpb regression from best (fraction)",
|
| 48 |
-
)
|
| 49 |
-
max_regression_pct: float = Field(
|
| 50 |
-
default=5.0, description="Max % regression from best known val_bpb"
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
# Keep/discard criteria
|
| 54 |
primary_metric: str = "val_bpb"
|
| 55 |
secondary_metrics: GateConfig = Field(
|
|
@@ -63,23 +63,23 @@ class HarnessConfig(BaseModel):
|
|
| 63 |
"hestia_quant_error": {"max": 0.05},
|
| 64 |
}
|
| 65 |
)
|
| 66 |
-
|
| 67 |
-
# Experiment execution
|
| 68 |
-
experiment_timeout: int = Field(
|
| 69 |
-
default=600, ge=300, description="Kill experiment after N seconds"
|
| 70 |
-
)
|
| 71 |
-
warmup_steps: int = Field(
|
| 72 |
-
default=10, ge=0, description="Steps to exclude from timing"
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
# Git
|
| 76 |
-
branch_prefix: str = Field(default="autoresearch", description="Branch naming prefix")
|
| 77 |
-
results_file: str = Field(default="results.tsv", description="Experiment log file")
|
| 78 |
-
|
| 79 |
-
# Secondary metric gates (optional keep/discard criteria)
|
| 80 |
-
gate_mhc_spectral_norm: float | None = Field(
|
| 81 |
-
default=None, description="Max mhc_spectral_norm for keep (None=disabled)"
|
| 82 |
-
)
|
| 83 |
gate_engram_hit_rate: float | None = Field(
|
| 84 |
default=None, description="Min engram_hit_rate for keep (None=disabled)"
|
| 85 |
)
|
|
|
|
| 3 |
|
| 4 |
from pydantic import BaseModel, Field
|
| 5 |
|
| 6 |
+
GateThresholds = dict[str, float]
|
| 7 |
+
GateConfig = dict[str, GateThresholds]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
class HarnessConfig(BaseModel):
|
| 11 |
+
"""Configuration for the HYDRA harness behavior."""
|
| 12 |
+
|
| 13 |
+
# Inner loop
|
| 14 |
+
time_budget_seconds: int = Field(
|
| 15 |
+
default=300, ge=60, description="Training time budget per experiment in seconds"
|
| 16 |
+
)
|
| 17 |
+
max_experiments: int = Field(
|
| 18 |
+
default=1000, ge=0, description="Max experiments before stopping (0=infinite)"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# Meta-agent
|
| 22 |
+
meta_interval: int = Field(
|
| 23 |
+
default=20, ge=5, description="Run meta-agent every N experiments"
|
| 24 |
+
)
|
| 25 |
+
max_meta_changes: int = Field(
|
| 26 |
+
default=3, ge=1, le=10, description="Max changes per meta-iteration"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# Search strategy
|
| 30 |
+
exploration_mode: Literal["conservative", "balanced", "bold"] = "balanced"
|
| 31 |
+
exploration_budget: int = Field(
|
| 32 |
+
default=5, ge=1, description="Consecutive bold experiments when stuck"
|
| 33 |
+
)
|
| 34 |
+
stuck_threshold: int = Field(
|
| 35 |
+
default=10, ge=3, description="No improvement for N experiments = stuck"
|
| 36 |
+
)
|
| 37 |
+
crash_threshold: float = Field(
|
| 38 |
+
default=0.5,
|
| 39 |
+
ge=0.1,
|
| 40 |
+
le=1.0,
|
| 41 |
+
description="Crash rate threshold for BROKEN state",
|
| 42 |
+
)
|
| 43 |
+
regression_tolerance: float = Field(
|
| 44 |
+
default=0.05,
|
| 45 |
+
ge=0,
|
| 46 |
+
le=0.2,
|
| 47 |
+
description="Max val_bpb regression from best (fraction)",
|
| 48 |
+
)
|
| 49 |
+
max_regression_pct: float = Field(
|
| 50 |
+
default=5.0, description="Max % regression from best known val_bpb"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
# Keep/discard criteria
|
| 54 |
primary_metric: str = "val_bpb"
|
| 55 |
secondary_metrics: GateConfig = Field(
|
|
|
|
| 63 |
"hestia_quant_error": {"max": 0.05},
|
| 64 |
}
|
| 65 |
)
|
| 66 |
+
|
| 67 |
+
# Experiment execution
|
| 68 |
+
experiment_timeout: int = Field(
|
| 69 |
+
default=600, ge=300, description="Kill experiment after N seconds"
|
| 70 |
+
)
|
| 71 |
+
warmup_steps: int = Field(
|
| 72 |
+
default=10, ge=0, description="Steps to exclude from timing"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# Git
|
| 76 |
+
branch_prefix: str = Field(default="autoresearch", description="Branch naming prefix")
|
| 77 |
+
results_file: str = Field(default="results.tsv", description="Experiment log file")
|
| 78 |
+
|
| 79 |
+
# Secondary metric gates (optional keep/discard criteria)
|
| 80 |
+
gate_mhc_spectral_norm: float | None = Field(
|
| 81 |
+
default=None, description="Max mhc_spectral_norm for keep (None=disabled)"
|
| 82 |
+
)
|
| 83 |
gate_engram_hit_rate: float | None = Field(
|
| 84 |
default=None, description="Min engram_hit_rate for keep (None=disabled)"
|
| 85 |
)
|
overlay/configs/model_config.py
CHANGED
|
@@ -1,80 +1,80 @@
|
|
| 1 |
-
"""Post-SEM-Claw model configuration with Pydantic validation."""
|
| 2 |
-
from pydantic import BaseModel, Field, field_validator
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class PostSemClawConfig(BaseModel):
|
| 6 |
-
"""Configuration for the Post-SEM-Claw architecture.
|
| 7 |
-
|
| 8 |
-
Default values mirror the @dataclass in train.py exactly.
|
| 9 |
-
train.py is the source of truth — this file must stay in sync with it.
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
# Sequence
|
| 13 |
-
sequence_len: int = Field(default=2048, description="Context length (from prepare.py MAX_SEQ_LEN)")
|
| 14 |
-
vocab_size: int = Field(default=8192, description="Vocabulary size (from prepare.py VOCAB_SIZE)")
|
| 15 |
-
|
| 16 |
-
# Mamba-3 SSM
|
| 17 |
-
n_layer: int = Field(default=4, ge=1, le=48, description="Number of Mamba-3 blocks")
|
| 18 |
-
d_model: int = Field(default=256, ge=64, description="Model embedding dimension")
|
| 19 |
-
d_state: int = Field(default=64, ge=16, description="SSM state dimension")
|
| 20 |
-
headdim: int = Field(default=32, ge=16, description="SSM head dimension")
|
| 21 |
-
n_heads: int = Field(default=8, ge=1, description="Number of SSM heads (d_model // headdim)")
|
| 22 |
-
expand: int = Field(default=2, ge=1, le=4, description="Inner dim multiplier (inner_dim = expand * d_model)")
|
| 23 |
-
|
| 24 |
-
# mHC (Manifold Hyper-Connection)
|
| 25 |
-
mhc_n_streams: int = Field(default=4, ge=2, le=8, description="Number of residual streams")
|
| 26 |
-
mhc_sinkhorn_iters: int = Field(default=5, ge=1, le=100, description="Sinkhorn-Knopp iterations")
|
| 27 |
-
|
| 28 |
-
# Engram (conditional memory)
|
| 29 |
-
engram_n_columns: int = Field(default=4096, ge=256, description="Hash table columns")
|
| 30 |
-
engram_key_dim: int = Field(default=64, ge=16, description="Engram key dimension")
|
| 31 |
-
engram_layer_idx: int = Field(default=1, ge=0, description="Which layer gets engram (0-indexed)")
|
| 32 |
-
|
| 33 |
-
# Hestia QAT (disabled Phase 1, skeleton only)
|
| 34 |
-
hestia_enabled: bool = Field(default=False, description="Enable Hestia quantization")
|
| 35 |
-
hestia_bits: float = Field(default=1.58, gt=0, description="Target quantization bits (1.58 = 1.58-bit ternary)")
|
| 36 |
-
|
| 37 |
-
# SDR (bypass-only in Phase 1)
|
| 38 |
-
sdr_enabled: bool = Field(default=False, description="Enable stochastic resonance")
|
| 39 |
-
sdr_k: int = Field(default=64, ge=1, description="Top-K sparsification")
|
| 40 |
-
sdr_noise_std: float = Field(default=0.1, ge=0.0, description="SR noise standard deviation")
|
| 41 |
-
|
| 42 |
-
@field_validator("n_heads")
|
| 43 |
-
@classmethod
|
| 44 |
-
def validate_heads(cls, v: int, info: "FieldValidationInfo") -> int:
|
| 45 |
-
"""Ensure n_heads equals d_model // headdim."""
|
| 46 |
-
d_model = info.data.get("d_model", 256)
|
| 47 |
-
headdim = info.data.get("headdim", 32)
|
| 48 |
-
expected = d_model // headdim
|
| 49 |
-
if v != expected:
|
| 50 |
-
raise ValueError(
|
| 51 |
-
f"n_heads ({v}) must equal d_model // headdim ({expected})"
|
| 52 |
-
)
|
| 53 |
-
return v
|
| 54 |
-
|
| 55 |
-
def estimate_params(self) -> int:
|
| 56 |
-
"""Rough parameter count estimate based on train.py architecture."""
|
| 57 |
-
inner = self.expand * self.d_model
|
| 58 |
-
# in_proj: d_model -> inner + inner + d_state + d_state + n_heads
|
| 59 |
-
in_proj = self.d_model * (inner + inner + self.d_state + self.d_state + self.n_heads)
|
| 60 |
-
out_proj = inner * self.d_model
|
| 61 |
-
# conv1d (kernel=4, groups=inner_dim)
|
| 62 |
-
conv = inner * 4
|
| 63 |
-
# A_log, lambda_theta, D: n_heads each (3 vectors)
|
| 64 |
-
ssm_params = self.n_heads * 3
|
| 65 |
-
# bc_norm: d_state * 2 (weight + bias)
|
| 66 |
-
bc_norm = self.d_state * 2
|
| 67 |
-
per_block = in_proj + out_proj + conv + ssm_params + bc_norm
|
| 68 |
-
blocks = per_block * self.n_layer
|
| 69 |
-
|
| 70 |
-
# Embedding + lm_head (tied or untied)
|
| 71 |
-
embed = self.vocab_size * self.d_model * 2
|
| 72 |
-
|
| 73 |
-
# Engram: one instance at engram_layer_idx
|
| 74 |
-
# columns * d_model keys + d_model * engram_key_dim projection
|
| 75 |
-
engram = self.engram_n_columns * self.d_model + self.d_model * self.engram_key_dim
|
| 76 |
-
|
| 77 |
-
# mHC mixing matrices: n_layer * mhc_n_streams^2
|
| 78 |
-
mhc = self.n_layer * self.mhc_n_streams ** 2
|
| 79 |
-
|
| 80 |
-
return embed + blocks + engram + mhc
|
|
|
|
| 1 |
+
"""Post-SEM-Claw model configuration with Pydantic validation."""
|
| 2 |
+
from pydantic import BaseModel, Field, field_validator
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PostSemClawConfig(BaseModel):
|
| 6 |
+
"""Configuration for the Post-SEM-Claw architecture.
|
| 7 |
+
|
| 8 |
+
Default values mirror the @dataclass in train.py exactly.
|
| 9 |
+
train.py is the source of truth — this file must stay in sync with it.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
# Sequence
|
| 13 |
+
sequence_len: int = Field(default=2048, description="Context length (from prepare.py MAX_SEQ_LEN)")
|
| 14 |
+
vocab_size: int = Field(default=8192, description="Vocabulary size (from prepare.py VOCAB_SIZE)")
|
| 15 |
+
|
| 16 |
+
# Mamba-3 SSM
|
| 17 |
+
n_layer: int = Field(default=4, ge=1, le=48, description="Number of Mamba-3 blocks")
|
| 18 |
+
d_model: int = Field(default=256, ge=64, description="Model embedding dimension")
|
| 19 |
+
d_state: int = Field(default=64, ge=16, description="SSM state dimension")
|
| 20 |
+
headdim: int = Field(default=32, ge=16, description="SSM head dimension")
|
| 21 |
+
n_heads: int = Field(default=8, ge=1, description="Number of SSM heads (d_model // headdim)")
|
| 22 |
+
expand: int = Field(default=2, ge=1, le=4, description="Inner dim multiplier (inner_dim = expand * d_model)")
|
| 23 |
+
|
| 24 |
+
# mHC (Manifold Hyper-Connection)
|
| 25 |
+
mhc_n_streams: int = Field(default=4, ge=2, le=8, description="Number of residual streams")
|
| 26 |
+
mhc_sinkhorn_iters: int = Field(default=5, ge=1, le=100, description="Sinkhorn-Knopp iterations")
|
| 27 |
+
|
| 28 |
+
# Engram (conditional memory)
|
| 29 |
+
engram_n_columns: int = Field(default=4096, ge=256, description="Hash table columns")
|
| 30 |
+
engram_key_dim: int = Field(default=64, ge=16, description="Engram key dimension")
|
| 31 |
+
engram_layer_idx: int = Field(default=1, ge=0, description="Which layer gets engram (0-indexed)")
|
| 32 |
+
|
| 33 |
+
# Hestia QAT (disabled Phase 1, skeleton only)
|
| 34 |
+
hestia_enabled: bool = Field(default=False, description="Enable Hestia quantization")
|
| 35 |
+
hestia_bits: float = Field(default=1.58, gt=0, description="Target quantization bits (1.58 = 1.58-bit ternary)")
|
| 36 |
+
|
| 37 |
+
# SDR (bypass-only in Phase 1)
|
| 38 |
+
sdr_enabled: bool = Field(default=False, description="Enable stochastic resonance")
|
| 39 |
+
sdr_k: int = Field(default=64, ge=1, description="Top-K sparsification")
|
| 40 |
+
sdr_noise_std: float = Field(default=0.1, ge=0.0, description="SR noise standard deviation")
|
| 41 |
+
|
| 42 |
+
@field_validator("n_heads")
|
| 43 |
+
@classmethod
|
| 44 |
+
def validate_heads(cls, v: int, info: "FieldValidationInfo") -> int:
|
| 45 |
+
"""Ensure n_heads equals d_model // headdim."""
|
| 46 |
+
d_model = info.data.get("d_model", 256)
|
| 47 |
+
headdim = info.data.get("headdim", 32)
|
| 48 |
+
expected = d_model // headdim
|
| 49 |
+
if v != expected:
|
| 50 |
+
raise ValueError(
|
| 51 |
+
f"n_heads ({v}) must equal d_model // headdim ({expected})"
|
| 52 |
+
)
|
| 53 |
+
return v
|
| 54 |
+
|
| 55 |
+
def estimate_params(self) -> int:
|
| 56 |
+
"""Rough parameter count estimate based on train.py architecture."""
|
| 57 |
+
inner = self.expand * self.d_model
|
| 58 |
+
# in_proj: d_model -> inner + inner + d_state + d_state + n_heads
|
| 59 |
+
in_proj = self.d_model * (inner + inner + self.d_state + self.d_state + self.n_heads)
|
| 60 |
+
out_proj = inner * self.d_model
|
| 61 |
+
# conv1d (kernel=4, groups=inner_dim)
|
| 62 |
+
conv = inner * 4
|
| 63 |
+
# A_log, lambda_theta, D: n_heads each (3 vectors)
|
| 64 |
+
ssm_params = self.n_heads * 3
|
| 65 |
+
# bc_norm: d_state * 2 (weight + bias)
|
| 66 |
+
bc_norm = self.d_state * 2
|
| 67 |
+
per_block = in_proj + out_proj + conv + ssm_params + bc_norm
|
| 68 |
+
blocks = per_block * self.n_layer
|
| 69 |
+
|
| 70 |
+
# Embedding + lm_head, counted as two separate vocab_size x d_model matrices (i.e. untied)
|
| 71 |
+
embed = self.vocab_size * self.d_model * 2
|
| 72 |
+
|
| 73 |
+
# Engram: one instance at engram_layer_idx
|
| 74 |
+
# columns * d_model keys + d_model * engram_key_dim projection
|
| 75 |
+
engram = self.engram_n_columns * self.d_model + self.d_model * self.engram_key_dim
|
| 76 |
+
|
| 77 |
+
# mHC mixing matrices: n_layer * mhc_n_streams^2
|
| 78 |
+
mhc = self.n_layer * self.mhc_n_streams ** 2
|
| 79 |
+
|
| 80 |
+
return embed + blocks + engram + mhc
|
overlay/harness/__init__.py
CHANGED
|
@@ -1,21 +1,21 @@
|
|
| 1 |
-
"""HYDRA harness package: orchestration infrastructure for autoresearch."""
|
| 2 |
-
from harness.eval_agent import ExperimentResult, parse_run_log, should_keep
|
| 3 |
-
from harness.git_utils import current_branch, current_commit_short
|
| 4 |
-
from harness.health_monitor import check_health, get_gpu_stats
|
| 5 |
-
from harness.meta_agent import run_meta_iteration
|
| 6 |
-
from harness.orchestrator import run_loop
|
| 7 |
-
from harness.search_strategy import ResearchState, diagnose
|
| 8 |
-
|
| 9 |
-
__all__ = [
|
| 10 |
-
"run_loop",
|
| 11 |
-
"parse_run_log",
|
| 12 |
-
"ExperimentResult",
|
| 13 |
-
"should_keep",
|
| 14 |
-
"run_meta_iteration",
|
| 15 |
-
"diagnose",
|
| 16 |
-
"ResearchState",
|
| 17 |
-
"check_health",
|
| 18 |
-
"get_gpu_stats",
|
| 19 |
-
"current_branch",
|
| 20 |
-
"current_commit_short",
|
| 21 |
-
]
|
|
|
|
| 1 |
+
"""HYDRA harness package: orchestration infrastructure for autoresearch."""
|
| 2 |
+
from harness.eval_agent import ExperimentResult, parse_run_log, should_keep
|
| 3 |
+
from harness.git_utils import current_branch, current_commit_short
|
| 4 |
+
from harness.health_monitor import check_health, get_gpu_stats
|
| 5 |
+
from harness.meta_agent import run_meta_iteration
|
| 6 |
+
from harness.orchestrator import run_loop
|
| 7 |
+
from harness.search_strategy import ResearchState, diagnose
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"run_loop",
|
| 11 |
+
"parse_run_log",
|
| 12 |
+
"ExperimentResult",
|
| 13 |
+
"should_keep",
|
| 14 |
+
"run_meta_iteration",
|
| 15 |
+
"diagnose",
|
| 16 |
+
"ResearchState",
|
| 17 |
+
"check_health",
|
| 18 |
+
"get_gpu_stats",
|
| 19 |
+
"current_branch",
|
| 20 |
+
"current_commit_short",
|
| 21 |
+
]
|
overlay/harness/eval_agent.py
CHANGED
|
@@ -1,300 +1,172 @@
|
|
| 1 |
"""Eval agent: parse run.log and extract metrics from training runs."""
|
| 2 |
import re
|
| 3 |
-
import
|
| 4 |
-
from dataclasses import dataclass
|
| 5 |
|
| 6 |
|
| 7 |
-
|
| 8 |
-
type GateConfig = dict[str, GateThresholds]
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
@dataclass
|
| 12 |
class ExperimentResult:
|
| 13 |
-
"""Parsed result from a single experiment run.
|
| 14 |
-
|
| 15 |
-
All float fields default to 0.0; integer fields default to 0.
|
| 16 |
-
The ``crashed`` flag is set when the log indicates a failure or the
|
| 17 |
-
log file is missing entirely.
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
# Primary metric
|
| 21 |
-
val_bpb: float = 0.0
|
| 22 |
-
|
| 23 |
-
# Timing
|
| 24 |
-
training_seconds: float = 0.0
|
| 25 |
-
total_seconds: float = 0.0
|
| 26 |
-
|
| 27 |
-
# Hardware
|
| 28 |
-
peak_vram_mb: float = 0.0
|
| 29 |
-
mfu_percent: float = 0.0
|
| 30 |
-
|
| 31 |
# Throughput
|
| 32 |
total_tokens_m: float = 0.0
|
| 33 |
num_steps: int = 0
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
# Model shape (echoed by train.py summary block)
|
| 41 |
-
num_params_m: float = 0.0
|
| 42 |
-
n_layer: int = 0
|
| 43 |
-
d_model: int = 0
|
| 44 |
-
|
| 45 |
# Secondary health metrics
|
| 46 |
mhc_spectral_norm: float = 0.0
|
| 47 |
engram_hit_rate: float = 0.0
|
| 48 |
sr_bypass_rate: float = 0.0
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
# Regex patterns keyed by ExperimentResult attribute name.
|
| 70 |
-
# Format must match the ``--- Summary ---`` block printed by train.py.
|
| 71 |
-
_PATTERNS: dict[str, str] = {
|
| 72 |
-
"val_bpb": r"^val_bpb:\s+([\d.]+)",
|
| 73 |
-
"training_seconds": r"^training_seconds:\s+([\d.]+)",
|
| 74 |
-
"total_seconds": r"^total_seconds:\s+([\d.]+)",
|
| 75 |
-
"peak_vram_mb": r"^peak_vram_mb:\s+([\d.]+)",
|
| 76 |
-
"mfu_percent": r"^mfu_percent:\s+([\d.]+)",
|
| 77 |
-
"total_tokens_m": r"^total_tokens_M:\s+([\d.]+)",
|
| 78 |
-
"num_steps": r"^num_steps:\s+(\d+)",
|
| 79 |
-
"num_params_m": r"^num_params_M:\s+([\d.]+)",
|
| 80 |
-
"n_layer": r"^n_layer:\s+(\d+)",
|
| 81 |
-
"d_model": r"^d_model:\s+(\d+)",
|
| 82 |
-
"mhc_spectral_norm": r"^mhc_spectral_norm:\s+([\d.]+)",
|
| 83 |
"engram_hit_rate": r"^engram_hit_rate:\s+([\d.]+)",
|
| 84 |
"sr_bypass_rate": r"^sr_bypass_rate:\s+([\d.]+)",
|
| 85 |
-
"factual_english_score": r"^factual_english_score:\s+([\d.]+)",
|
| 86 |
-
"instruction_following_score": r"^instruction_following_score:\s+([\d.]+)",
|
| 87 |
-
"distinct_1": r"^distinct_1:\s+([\d.]+)",
|
| 88 |
-
"distinct_2": r"^distinct_2:\s+([\d.]+)",
|
| 89 |
-
"repetition_rate": r"^repetition_rate:\s+([\d.]+)",
|
| 90 |
-
"repetition_bigram_rate": r"^repetition_bigram_rate:\s+([\d.]+)",
|
| 91 |
-
"calibration_ece": r"^calibration_ece:\s+([\d.]+)",
|
| 92 |
-
"calibration_brier": r"^calibration_brier:\s*([\d.]+)",
|
| 93 |
-
"calibration_accuracy": r"^calibration_accuracy:\s+([\d.]+)",
|
| 94 |
-
"calibration_tokens": r"^calibration_tokens:\s+(\d+)",
|
| 95 |
-
"eval_seed": r"^eval_seed:\s+(\d+)",
|
| 96 |
-
"eval_seed_group": r"^eval_seed_group:\s+(.+)",
|
| 97 |
}
|
| 98 |
-
|
| 99 |
-
# Attributes that should be parsed as int rather than float.
|
| 100 |
-
_INT_ATTRS: frozenset[str] = frozenset(
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def parse_run_log(log_path: str) -> ExperimentResult:
|
| 128 |
-
"""Parse a run.log file and extract all training metrics.
|
| 129 |
-
|
| 130 |
-
Args:
|
| 131 |
-
log_path: Absolute path to the run.log file.
|
| 132 |
-
|
| 133 |
-
Returns:
|
| 134 |
-
Populated ExperimentResult; sets ``crashed=True`` when the log
|
| 135 |
-
contains a traceback or the file is missing.
|
| 136 |
-
"""
|
| 137 |
-
result = ExperimentResult()
|
| 138 |
-
|
| 139 |
-
try:
|
| 140 |
-
with open(log_path) as fh:
|
| 141 |
-
content = fh.read()
|
| 142 |
-
except FileNotFoundError:
|
| 143 |
-
result.crashed = True
|
| 144 |
-
result.error_message = f"Log file not found: {log_path}"
|
| 145 |
-
return result
|
| 146 |
-
|
| 147 |
-
# Detect crash signals in output. Keep this strict to avoid false positives
|
| 148 |
-
# from benign log lines that include "error" in a non-fatal context.
|
| 149 |
-
if (
|
| 150 |
-
"Traceback" in content
|
| 151 |
-
or "\nFAIL\n" in content
|
| 152 |
-
or "[TPS_GUARD] FAIL" in content
|
| 153 |
-
or "raise SystemExit(1)" in content
|
| 154 |
-
):
|
| 155 |
result.crashed = True
|
| 156 |
lines = content.strip().splitlines()
|
| 157 |
result.error_message = "\n".join(lines[-20:])
|
| 158 |
-
|
| 159 |
for attr, pattern in _PATTERNS.items():
|
| 160 |
match = re.search(pattern, content, re.MULTILINE)
|
| 161 |
if match:
|
| 162 |
raw = match.group(1)
|
| 163 |
-
if attr in _INT_ATTRS
|
| 164 |
-
setattr(result, attr, int(raw))
|
| 165 |
-
elif attr in _STR_ATTRS:
|
| 166 |
-
setattr(result, attr, raw.strip())
|
| 167 |
-
else:
|
| 168 |
-
setattr(result, attr, float(raw))
|
| 169 |
-
|
| 170 |
-
warmup_steps = 10
|
| 171 |
-
warmup_match = re.search(r"\[TPS_GUARD\] enabled .*?warmup_steps=(\d+)", content)
|
| 172 |
-
if warmup_match:
|
| 173 |
-
warmup_steps = int(warmup_match.group(1))
|
| 174 |
-
|
| 175 |
-
step_tps_samples: list[tuple[int, int]] = []
|
| 176 |
-
for m in _STEP_TPS_PATTERN.finditer(content):
|
| 177 |
-
step_tps_samples.append((int(m.group(1)), int(m.group(2))))
|
| 178 |
-
|
| 179 |
-
tps_values: list[float] = []
|
| 180 |
-
if step_tps_samples:
|
| 181 |
-
for step, tps in step_tps_samples:
|
| 182 |
-
if step >= warmup_steps:
|
| 183 |
-
tps_values.append(float(tps))
|
| 184 |
-
if not tps_values:
|
| 185 |
-
tps_values = [float(tps) for _, tps in step_tps_samples]
|
| 186 |
-
else:
|
| 187 |
-
tps_values = [float(m.group(1)) for m in _TPS_PATTERN.finditer(content)]
|
| 188 |
-
|
| 189 |
-
if tps_values:
|
| 190 |
-
sorted_tps = sorted(tps_values)
|
| 191 |
-
result.tps_samples = len(tps_values)
|
| 192 |
-
result.tps_median = float(statistics.median(tps_values))
|
| 193 |
-
result.tps_p10 = float(_percentile_linear(sorted_tps, 10.0))
|
| 194 |
-
result.tps_min = float(sorted_tps[0])
|
| 195 |
-
result.tps_max = float(sorted_tps[-1])
|
| 196 |
|
| 197 |
return result
|
| 198 |
-
|
| 199 |
-
|
| 200 |
def check_secondary_alarms(result: ExperimentResult) -> list[str]:
|
| 201 |
-
"""Check secondary metrics against fixed alarm thresholds.
|
| 202 |
-
|
| 203 |
-
Args:
|
| 204 |
-
result: Parsed experiment result.
|
| 205 |
-
|
| 206 |
-
Returns:
|
| 207 |
-
List of human-readable alarm strings (empty if all clear).
|
| 208 |
-
"""
|
| 209 |
-
alarms: list[str] = []
|
| 210 |
-
|
| 211 |
-
if result.mhc_spectral_norm > 2.0:
|
| 212 |
-
alarms.append(
|
| 213 |
-
f"mhc_spectral_norm={result.mhc_spectral_norm:.4f} > 2.0 (ALARM)"
|
| 214 |
-
)
|
| 215 |
-
if 0 < result.engram_hit_rate < 0.1:
|
| 216 |
-
alarms.append(
|
| 217 |
-
f"engram_hit_rate={result.engram_hit_rate:.4f} < 0.1 (memory underused)"
|
| 218 |
-
)
|
| 219 |
-
if 0 < result.mfu_percent < 10:
|
| 220 |
alarms.append(
|
| 221 |
-
f"
|
| 222 |
)
|
| 223 |
-
if result.
|
| 224 |
alarms.append(
|
| 225 |
-
f"
|
| 226 |
)
|
| 227 |
-
if
|
| 228 |
alarms.append(
|
| 229 |
-
f"
|
| 230 |
)
|
| 231 |
-
|
| 232 |
return alarms
|
| 233 |
|
| 234 |
|
| 235 |
-
def _check_gate(
|
| 236 |
-
result: ExperimentResult,
|
| 237 |
-
gates: GateConfig,
|
| 238 |
-
metric: str,
|
| 239 |
-
) -> tuple[bool, str] | None:
|
| 240 |
-
"""Evaluate a single min/max gate against an ExperimentResult metric."""
|
| 241 |
-
gate = gates.get(metric, {})
|
| 242 |
-
value = getattr(result, metric)
|
| 243 |
-
max_value = gate.get("max")
|
| 244 |
-
if max_value is not None and value > max_value:
|
| 245 |
-
return False, f"{metric} {value:.4f} > gate {max_value}"
|
| 246 |
-
min_value = gate.get("min")
|
| 247 |
-
if min_value is not None and value < min_value:
|
| 248 |
-
return False, f"{metric} {value:.4f} < gate {min_value}"
|
| 249 |
-
return None
|
| 250 |
-
|
| 251 |
-
|
| 252 |
def should_keep(
|
| 253 |
result: ExperimentResult,
|
| 254 |
best_bpb: float,
|
| 255 |
-
gates:
|
| 256 |
) -> tuple[bool, str]:
|
| 257 |
-
"""Decide whether to keep or discard an experiment.
|
| 258 |
-
|
| 259 |
-
The primary criterion is strictly lower val_bpb than the current best.
|
| 260 |
-
Optional secondary gates (passed from HarnessConfig.secondary_metrics)
|
| 261 |
-
can reject an otherwise-improving result.
|
| 262 |
-
|
| 263 |
-
Args:
|
| 264 |
-
result: Parsed experiment result.
|
| 265 |
-
best_bpb: Current best val_bpb across all experiments.
|
| 266 |
-
gates: Optional dict mapping metric name to threshold dict with
|
| 267 |
-
``"max"`` or ``"min"`` keys, e.g.
|
| 268 |
-
``{"mhc_spectral_norm": {"max": 2.0}}``.
|
| 269 |
-
|
| 270 |
-
Returns:
|
| 271 |
-
Tuple of (keep: bool, reason: str).
|
| 272 |
-
"""
|
| 273 |
-
if result.crashed:
|
| 274 |
-
return False, "crash"
|
| 275 |
-
if result.val_bpb <= 0:
|
| 276 |
-
return False, "invalid val_bpb"
|
| 277 |
-
if result.val_bpb >= best_bpb:
|
| 278 |
-
return False, "discard"
|
| 279 |
-
|
| 280 |
# Secondary gate checks.
|
| 281 |
if gates:
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
)
|
| 295 |
-
for metric in gate_metrics:
|
| 296 |
-
gate_result = _check_gate(result, gates, metric)
|
| 297 |
-
if gate_result is not None:
|
| 298 |
-
return gate_result
|
| 299 |
|
| 300 |
return True, "keep"
|
|
|
|
| 1 |
"""Eval agent: parse run.log and extract metrics from training runs."""
|
| 2 |
import re
|
| 3 |
+
from dataclasses import dataclass, field
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
+
@dataclass
class ExperimentResult:
    """Metrics and status extracted from one experiment's ``run.log``.

    Numeric fields start at zero (``0.0`` for floats, ``0`` for ints) and
    keep that value when the matching line is absent from the log.
    ``crashed`` is switched on when the log signals a failure or the log
    file does not exist; ``error_message`` then holds the detail text.
    """

    # --- primary metric ---
    val_bpb: float = 0.0

    # --- timing ---
    training_seconds: float = 0.0
    total_seconds: float = 0.0

    # --- hardware ---
    peak_vram_mb: float = 0.0
    mfu_percent: float = 0.0

    # --- throughput ---
    total_tokens_m: float = 0.0
    num_steps: int = 0

    # --- model shape (echoed by the train.py summary block) ---
    num_params_m: float = 0.0
    n_layer: int = 0
    d_model: int = 0

    # --- secondary health metrics ---
    mhc_spectral_norm: float = 0.0
    engram_hit_rate: float = 0.0
    sr_bypass_rate: float = 0.0

    # --- run status ---
    crashed: bool = False
    error_message: str = ""
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Regex patterns keyed by ExperimentResult attribute name.
|
| 46 |
+
# Format must match the ``--- Summary ---`` block printed by train.py.
|
| 47 |
+
_PATTERNS: dict[str, str] = {
|
| 48 |
+
"val_bpb": r"^val_bpb:\s+([\d.]+)",
|
| 49 |
+
"training_seconds": r"^training_seconds:\s+([\d.]+)",
|
| 50 |
+
"total_seconds": r"^total_seconds:\s+([\d.]+)",
|
| 51 |
+
"peak_vram_mb": r"^peak_vram_mb:\s+([\d.]+)",
|
| 52 |
+
"mfu_percent": r"^mfu_percent:\s+([\d.]+)",
|
| 53 |
+
"total_tokens_m": r"^total_tokens_M:\s+([\d.]+)",
|
| 54 |
+
"num_steps": r"^num_steps:\s+(\d+)",
|
| 55 |
+
"num_params_m": r"^num_params_M:\s+([\d.]+)",
|
| 56 |
+
"n_layer": r"^n_layer:\s+(\d+)",
|
| 57 |
+
"d_model": r"^d_model:\s+(\d+)",
|
| 58 |
+
"mhc_spectral_norm": r"^mhc_spectral_norm:\s+([\d.]+)",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
"engram_hit_rate": r"^engram_hit_rate:\s+([\d.]+)",
|
| 60 |
"sr_bypass_rate": r"^sr_bypass_rate:\s+([\d.]+)",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
}
|
| 62 |
+
|
| 63 |
+
# Attributes that should be parsed as int rather than float.
|
| 64 |
+
_INT_ATTRS: frozenset[str] = frozenset({"num_steps", "n_layer", "d_model"})
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def parse_run_log(log_path: str) -> ExperimentResult:
    """Parse a run.log file and extract all training metrics.

    Each entry of ``_PATTERNS`` is searched in the log (multi-line mode);
    the first match is converted to ``int`` for names in ``_INT_ATTRS`` and
    to ``float`` otherwise, then stored on the result. Metrics that are
    absent or malformed keep their default value.

    Args:
        log_path: Absolute path to the run.log file.

    Returns:
        Populated ExperimentResult; sets ``crashed=True`` when the log
        contains a crash marker or the file is missing / unreadable.
    """
    result = ExperimentResult()

    try:
        # Decode leniently: a stray undecodable byte in the log must not
        # abort parsing of an otherwise valid summary block.
        with open(log_path, encoding="utf-8", errors="replace") as fh:
            content = fh.read()
    except FileNotFoundError:
        result.crashed = True
        result.error_message = f"Log file not found: {log_path}"
        return result
    except OSError as exc:
        # Permission problems, a directory at log_path, I/O errors, ...
        result.crashed = True
        result.error_message = f"Could not read log file {log_path}: {exc}"
        return result

    # Detect crash signals in output.
    # NOTE(review): these are plain substring tests, so any line that merely
    # contains "Error" or "FAIL" marks the run as crashed — confirm this
    # breadth is intended.
    if any(marker in content for marker in ("Traceback", "FAIL", "Error")):
        result.crashed = True
        lines = content.strip().splitlines()
        # Keep only the tail of the log as the error detail.
        result.error_message = "\n".join(lines[-20:])

    for attr, pattern in _PATTERNS.items():
        match = re.search(pattern, content, re.MULTILINE)
        if match is None:
            continue
        raw = match.group(1)
        try:
            value = int(raw) if attr in _INT_ATTRS else float(raw)
        except ValueError:
            # The ``[\d.]+`` capture also accepts tokens such as "1.2.3" or
            # "."; treat those like a missing metric and keep the default.
            continue
        setattr(result, attr, value)

    return result
|
| 100 |
+
|
| 101 |
+
|
| 102 |
def check_secondary_alarms(result: ExperimentResult) -> list[str]:
    """Compare secondary metrics with their fixed alarm thresholds.

    Args:
        result: Parsed experiment result.

    Returns:
        Human-readable alarm strings in a fixed order (spectral norm,
        engram hit rate, MFU); an empty list when nothing fired.
    """
    spectral = result.mhc_spectral_norm
    hit_rate = result.engram_hit_rate
    mfu = result.mfu_percent

    # The two lower-bound alarms require a strictly positive value, so a
    # metric still at its 0 default does not fire them.
    checks = (
        (
            spectral > 2.0,
            f"mhc_spectral_norm={spectral:.4f} > 2.0 (ALARM)",
        ),
        (
            hit_rate > 0 and hit_rate < 0.1,
            f"engram_hit_rate={hit_rate:.4f} < 0.1 (memory underused)",
        ),
        (
            mfu > 0 and mfu < 10,
            f"mfu_percent={mfu:.2f}% < 10% (GPU underutilized)",
        ),
    )
    return [message for fired, message in checks if fired]
|
| 127 |
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
def should_keep(
    result: ExperimentResult,
    best_bpb: float,
    gates: dict | None = None,
) -> tuple[bool, str]:
    """Decide whether to keep or discard an experiment.

    The primary criterion is strictly lower val_bpb than the current best.
    Optional secondary gates (passed from HarnessConfig.secondary_metrics)
    can reject an otherwise-improving result.

    Args:
        result: Parsed experiment result.
        best_bpb: Current best val_bpb across all experiments.
        gates: Optional dict mapping metric name to threshold dict with
            ``"max"`` or ``"min"`` keys, e.g.
            ``{"mhc_spectral_norm": {"max": 2.0}}``. Only the ``"max"`` gate
            of ``mhc_spectral_norm`` and the ``"min"`` gate of
            ``engram_hit_rate`` are evaluated; a missing or ``None`` entry
            disables that gate.

    Returns:
        Tuple of (keep: bool, reason: str).
    """
    import math  # local import keeps this function self-contained

    if result.crashed:
        return False, "crash"
    # NaN compares False against everything, so without the isfinite() guard
    # a NaN val_bpb would pass both comparisons below and be kept as "best".
    if not math.isfinite(result.val_bpb) or result.val_bpb <= 0:
        return False, "invalid val_bpb"
    if result.val_bpb >= best_bpb:
        return False, "discard"

    # Secondary gate checks.
    if gates:
        # ``or {}`` tolerates a gate entry that is present but set to None.
        gate_mhc = (gates.get("mhc_spectral_norm") or {}).get("max")
        if gate_mhc is not None and result.mhc_spectral_norm > gate_mhc:
            return (
                False,
                f"mhc_spectral_norm {result.mhc_spectral_norm:.4f} > gate {gate_mhc}",
            )
        gate_engram = (gates.get("engram_hit_rate") or {}).get("min")
        if gate_engram is not None and result.engram_hit_rate < gate_engram:
            return (
                False,
                f"engram_hit_rate {result.engram_hit_rate:.4f} < gate {gate_engram}",
            )

    return True, "keep"
|
overlay/harness/git_utils.py
CHANGED
|
@@ -1,94 +1,94 @@
|
|
| 1 |
-
"""Git utilities for HYDRA autoresearch branch management."""
|
| 2 |
-
import os
|
| 3 |
-
import subprocess
|
| 4 |
-
|
| 5 |
-
REPO_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def run_git(*args: str, check: bool = True) -> subprocess.CompletedProcess:
|
| 9 |
-
"""Run a git command in the repo directory.
|
| 10 |
-
|
| 11 |
-
Args:
|
| 12 |
-
*args: Git command arguments.
|
| 13 |
-
check: Whether to raise on non-zero exit code.
|
| 14 |
-
|
| 15 |
-
Returns:
|
| 16 |
-
Completed process with stdout/stderr captured.
|
| 17 |
-
"""
|
| 18 |
-
return subprocess.run(
|
| 19 |
-
["git"] + list(args),
|
| 20 |
-
cwd=REPO_DIR,
|
| 21 |
-
capture_output=True,
|
| 22 |
-
text=True,
|
| 23 |
-
check=check,
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def current_branch() -> str:
|
| 28 |
-
"""Return the current git branch name.
|
| 29 |
-
|
| 30 |
-
Returns:
|
| 31 |
-
Branch name string.
|
| 32 |
-
"""
|
| 33 |
-
result = run_git("rev-parse", "--abbrev-ref", "HEAD")
|
| 34 |
-
return result.stdout.strip()
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def current_commit_short() -> str:
|
| 38 |
-
"""Return the current HEAD commit short hash (7 chars).
|
| 39 |
-
|
| 40 |
-
Returns:
|
| 41 |
-
7-character commit hash.
|
| 42 |
-
"""
|
| 43 |
-
result = run_git("rev-parse", "--short=7", "HEAD")
|
| 44 |
-
return result.stdout.strip()
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def create_branch(name: str) -> None:
|
| 48 |
-
"""Create and switch to a new branch.
|
| 49 |
-
|
| 50 |
-
Args:
|
| 51 |
-
name: Branch name to create.
|
| 52 |
-
"""
|
| 53 |
-
run_git("checkout", "-b", name)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def commit_all(message: str) -> str:
|
| 57 |
-
"""Stage all changes, commit, and return short hash.
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
message: Commit message.
|
| 61 |
-
|
| 62 |
-
Returns:
|
| 63 |
-
Short commit hash after committing.
|
| 64 |
-
"""
|
| 65 |
-
run_git("add", "-A")
|
| 66 |
-
run_git("commit", "-m", message, check=False)
|
| 67 |
-
return current_commit_short()
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def reset_to(commit: str) -> None:
|
| 71 |
-
"""Hard reset to a specific commit, discarding all changes.
|
| 72 |
-
|
| 73 |
-
Args:
|
| 74 |
-
commit: Commit hash (short or full) to reset to.
|
| 75 |
-
"""
|
| 76 |
-
run_git("reset", "--hard", commit)
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def get_last_n_diffs(n: int = 3) -> list[str]:
|
| 80 |
-
"""Get the last N commit diffs (--stat format) for meta-agent context.
|
| 81 |
-
|
| 82 |
-
Args:
|
| 83 |
-
n: Number of recent commits to retrieve.
|
| 84 |
-
|
| 85 |
-
Returns:
|
| 86 |
-
List of diff stat strings, one per commit (truncated to 500 chars).
|
| 87 |
-
"""
|
| 88 |
-
result = run_git("log", f"-{n}", "--format=%H", check=False)
|
| 89 |
-
hashes = [h for h in result.stdout.strip().split("\n") if h]
|
| 90 |
-
diffs: list[str] = []
|
| 91 |
-
for h in hashes:
|
| 92 |
-
diff_result = run_git("show", "--stat", h, check=False)
|
| 93 |
-
diffs.append(diff_result.stdout[:500])
|
| 94 |
-
return diffs
|
|
|
|
| 1 |
+
"""Git utilities for HYDRA autoresearch branch management."""
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
|
| 5 |
+
REPO_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def run_git(*args: str, check: bool = True) -> subprocess.CompletedProcess:
    """Execute ``git <args...>`` with the repository root as working directory.

    Args:
        *args: Arguments placed after ``git`` on the command line.
        check: When true, a non-zero exit status raises
            ``subprocess.CalledProcessError``.

    Returns:
        The completed process; stdout and stderr are captured as text.
    """
    command = ["git", *args]
    return subprocess.run(
        command,
        check=check,
        cwd=REPO_DIR,
        text=True,
        capture_output=True,
    )
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def current_branch() -> str:
    """Look up the name of the branch that HEAD currently points at.

    Returns:
        Output of ``git rev-parse --abbrev-ref HEAD`` with surrounding
        whitespace removed.
    """
    return run_git("rev-parse", "--abbrev-ref", "HEAD").stdout.strip()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def current_commit_short() -> str:
    """Look up the abbreviated hash of the current HEAD commit.

    Returns:
        Output of ``git rev-parse --short=7 HEAD`` with surrounding
        whitespace removed.
    """
    return run_git("rev-parse", "--short=7", "HEAD").stdout.strip()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def create_branch(name: str) -> None:
    """Create branch ``name`` and check it out in one step.

    Args:
        name: Name of the branch to create.
    """
    checkout_args = ("checkout", "-b", name)
    run_git(*checkout_args)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def commit_all(message: str) -> str:
    """Stage every change in the work tree, commit, and report HEAD.

    Args:
        message: Commit message.

    Returns:
        Short hash of HEAD once the commit attempt has finished.
    """
    stage_args = ("add", "-A")
    commit_args = ("commit", "-m", message)

    run_git(*stage_args)
    # check=False: a non-zero exit from ``git commit`` does not raise, so
    # the hash of whatever HEAD is afterwards is returned either way.
    run_git(*commit_args, check=False)
    return current_commit_short()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def reset_to(commit: str) -> None:
    """Run ``git reset --hard`` onto the given commit.

    A hard reset throws away uncommitted changes in the work tree.

    Args:
        commit: Commit hash (short or full) to reset to.
    """
    reset_args = ("reset", "--hard", commit)
    run_git(*reset_args)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_last_n_diffs(n: int = 3) -> list[str]:
    """Collect ``git show --stat`` output for the most recent commits.

    Args:
        n: Number of recent commits to retrieve.

    Returns:
        One string per commit, each cut to its first 500 characters.
    """
    # Neither git call raises on failure (check=False); a failed ``git log``
    # produces no hashes and therefore an empty list.
    log = run_git("log", f"-{n}", "--format=%H", check=False)
    commit_ids = filter(None, log.stdout.strip().split("\n"))
    return [
        run_git("show", "--stat", commit_id, check=False).stdout[:500]
        for commit_id in commit_ids
    ]
|
overlay/harness/health_monitor.py
CHANGED
|
@@ -1,86 +1,86 @@
|
|
| 1 |
-
"""Hardware health monitoring for HYDRA experiments.
|
| 2 |
-
|
| 3 |
-
Provides lightweight checks that the orchestrator runs before each
|
| 4 |
-
experiment to avoid launching training into a degraded GPU state.
|
| 5 |
-
"""
|
| 6 |
-
import os
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def get_gpu_stats() -> dict:
|
| 12 |
-
"""Return current GPU memory statistics.
|
| 13 |
-
|
| 14 |
-
Returns:
|
| 15 |
-
Dict with keys: available (bool), and when available:
|
| 16 |
-
name, memory_allocated_mb, memory_reserved_mb,
|
| 17 |
-
max_memory_allocated_mb, memory_total_mb.
|
| 18 |
-
"""
|
| 19 |
-
if not torch.cuda.is_available():
|
| 20 |
-
return {"available": False}
|
| 21 |
-
|
| 22 |
-
props = torch.cuda.get_device_properties(0)
|
| 23 |
-
return {
|
| 24 |
-
"available": True,
|
| 25 |
-
"name": torch.cuda.get_device_name(0),
|
| 26 |
-
"memory_allocated_mb": torch.cuda.memory_allocated(0) / (1024 * 1024),
|
| 27 |
-
"memory_reserved_mb": torch.cuda.memory_reserved(0) / (1024 * 1024),
|
| 28 |
-
"max_memory_allocated_mb": torch.cuda.max_memory_allocated(0) / (1024 * 1024),
|
| 29 |
-
"memory_total_mb": props.total_mem / (1024 * 1024),
|
| 30 |
-
}
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def check_health(
|
| 34 |
-
vram_pressure_pct: float = 90.0,
|
| 35 |
-
min_free_disk_gb: float = 1.0,
|
| 36 |
-
) -> tuple[bool, list[str]]:
|
| 37 |
-
"""Check GPU and disk health before launching an experiment.
|
| 38 |
-
|
| 39 |
-
Args:
|
| 40 |
-
vram_pressure_pct: Warn when GPU memory allocation exceeds this
|
| 41 |
-
percentage of total VRAM.
|
| 42 |
-
min_free_disk_gb: Warn when free disk space falls below this.
|
| 43 |
-
|
| 44 |
-
Returns:
|
| 45 |
-
Tuple of (healthy: bool, warnings: list[str]).
|
| 46 |
-
``healthy`` is True when there are no warnings.
|
| 47 |
-
"""
|
| 48 |
-
warnings: list[str] = []
|
| 49 |
-
stats = get_gpu_stats()
|
| 50 |
-
|
| 51 |
-
if not stats["available"]:
|
| 52 |
-
return False, ["No CUDA GPU available"]
|
| 53 |
-
|
| 54 |
-
# Memory pressure check.
|
| 55 |
-
used_pct = (
|
| 56 |
-
stats["memory_allocated_mb"] / stats["memory_total_mb"] * 100
|
| 57 |
-
if stats["memory_total_mb"] > 0
|
| 58 |
-
else 0.0
|
| 59 |
-
)
|
| 60 |
-
if used_pct > vram_pressure_pct:
|
| 61 |
-
warnings.append(
|
| 62 |
-
f"GPU memory pressure: {used_pct:.1f}% allocated "
|
| 63 |
-
f"({stats['memory_allocated_mb']:.0f} / {stats['memory_total_mb']:.0f} MB)"
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
# Disk space check.
|
| 67 |
-
try:
|
| 68 |
-
statvfs = os.statvfs(os.path.dirname(os.path.abspath(__file__)))
|
| 69 |
-
free_gb = (statvfs.f_bavail * statvfs.f_frsize) / (1024**3)
|
| 70 |
-
if free_gb < min_free_disk_gb:
|
| 71 |
-
warnings.append(f"Low disk space: {free_gb:.2f} GB free")
|
| 72 |
-
except (AttributeError, OSError):
|
| 73 |
-
# os.statvfs not available on all platforms (e.g. Windows).
|
| 74 |
-
pass
|
| 75 |
-
|
| 76 |
-
return len(warnings) == 0, warnings
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def reset_peak_stats() -> None:
|
| 80 |
-
"""Reset GPU peak memory tracking for the next experiment.
|
| 81 |
-
|
| 82 |
-
Should be called immediately before launching each training run so
|
| 83 |
-
that peak_vram_mb reported in run.log reflects only that experiment.
|
| 84 |
-
"""
|
| 85 |
-
if torch.cuda.is_available():
|
| 86 |
-
torch.cuda.reset_peak_memory_stats()
|
|
|
|
| 1 |
+
"""Hardware health monitoring for HYDRA experiments.
|
| 2 |
+
|
| 3 |
+
Provides lightweight checks that the orchestrator runs before each
|
| 4 |
+
experiment to avoid launching training into a degraded GPU state.
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_gpu_stats() -> dict:
    """Return current memory statistics for CUDA device 0.

    Returns:
        ``{"available": False}`` when CUDA is not available. Otherwise a
        dict with keys: available (True), name, memory_allocated_mb,
        memory_reserved_mb, max_memory_allocated_mb, memory_total_mb.
        All ``*_mb`` values are byte counts divided by 1024 * 1024.
    """
    if not torch.cuda.is_available():
        return {"available": False}

    bytes_per_mb = 1024 * 1024
    props = torch.cuda.get_device_properties(0)
    return {
        "available": True,
        "name": torch.cuda.get_device_name(0),
        "memory_allocated_mb": torch.cuda.memory_allocated(0) / bytes_per_mb,
        "memory_reserved_mb": torch.cuda.memory_reserved(0) / bytes_per_mb,
        "max_memory_allocated_mb": torch.cuda.max_memory_allocated(0) / bytes_per_mb,
        # The device-properties attribute is ``total_memory``; the previous
        # ``total_mem`` does not exist and raised AttributeError on GPU hosts.
        "memory_total_mb": props.total_memory / bytes_per_mb,
    }
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def check_health(
    vram_pressure_pct: float = 90.0,
    min_free_disk_gb: float = 1.0,
) -> tuple[bool, list[str]]:
    """Run the pre-launch GPU memory and disk space checks.

    Args:
        vram_pressure_pct: Warn when allocated GPU memory exceeds this
            percentage of total VRAM.
        min_free_disk_gb: Warn when free disk space falls below this.

    Returns:
        Tuple of (healthy: bool, warnings: list[str]); ``healthy`` is True
        exactly when the warning list is empty. Without a CUDA GPU the
        result is ``(False, ["No CUDA GPU available"])``.
    """
    gpu = get_gpu_stats()
    if not gpu["available"]:
        return False, ["No CUDA GPU available"]

    issues: list[str] = []

    # Memory pressure check (a zero total is treated as 0% used).
    if gpu["memory_total_mb"] > 0:
        used_pct = gpu["memory_allocated_mb"] / gpu["memory_total_mb"] * 100
    else:
        used_pct = 0.0
    if used_pct > vram_pressure_pct:
        issues.append(
            f"GPU memory pressure: {used_pct:.1f}% allocated "
            f"({gpu['memory_allocated_mb']:.0f} / {gpu['memory_total_mb']:.0f} MB)"
        )

    # Disk space check on the filesystem that holds this module.
    try:
        module_dir = os.path.dirname(os.path.abspath(__file__))
        fs = os.statvfs(module_dir)
        free_gb = (fs.f_bavail * fs.f_frsize) / (1024**3)
        if free_gb < min_free_disk_gb:
            issues.append(f"Low disk space: {free_gb:.2f} GB free")
    except (AttributeError, OSError):
        # os.statvfs not available on all platforms (e.g. Windows).
        pass

    return not issues, issues
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def reset_peak_stats() -> None:
    """Clear the CUDA peak-memory counters; a no-op without CUDA.

    Intended to be called right before each training run is launched so
    the peak_vram_mb reported in run.log covers that run only.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()
|
overlay/harness/meta_agent.py
CHANGED
|
@@ -1,139 +1,139 @@
|
|
| 1 |
-
"""Meta-agent: evolves program.md based on experiment history.
|
| 2 |
-
|
| 3 |
-
Runs every ``meta_interval`` inner-loop experiments (configured in
|
| 4 |
-
HarnessConfig). Reads the current research state from results.tsv,
|
| 5 |
-
decides whether guidance is needed, and appends a directive to
|
| 6 |
-
program.md. Any previous auto-generated directive is replaced so
|
| 7 |
-
the file stays clean.
|
| 8 |
-
"""
|
| 9 |
-
import os
|
| 10 |
-
|
| 11 |
-
from harness.git_utils import REPO_DIR
|
| 12 |
-
from harness.search_strategy import ResearchState, diagnose
|
| 13 |
-
|
| 14 |
-
PROGRAM_PATH = os.path.join(REPO_DIR, "program.md")
|
| 15 |
-
RESULTS_PATH = os.path.join(REPO_DIR, "results.tsv")
|
| 16 |
-
|
| 17 |
-
# Sentinel that marks auto-generated content so it can be cleanly replaced.
|
| 18 |
-
_DIRECTIVE_MARKER = "## Meta-Agent Directive (auto-generated)"
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def generate_directive(state: ResearchState) -> str | None:
|
| 22 |
-
"""Generate a directive string to append to program.md, or None.
|
| 23 |
-
|
| 24 |
-
A directive is only produced when the research state is not EXPLORING
|
| 25 |
-
(i.e., something needs to change).
|
| 26 |
-
|
| 27 |
-
Args:
|
| 28 |
-
state: Current ResearchState diagnosis.
|
| 29 |
-
|
| 30 |
-
Returns:
|
| 31 |
-
Formatted directive string, or None when no change is needed.
|
| 32 |
-
"""
|
| 33 |
-
if state.label == "EXPLORING":
|
| 34 |
-
return None
|
| 35 |
-
|
| 36 |
-
if state.label == "BROKEN":
|
| 37 |
-
return (
|
| 38 |
-
f"\n{_DIRECTIVE_MARKER}\n"
|
| 39 |
-
f"ALERT: Crash rate is {state.crash_rate:.0%} in the recent window. "
|
| 40 |
-
"Revert to the last stable commit. Reduce model complexity before "
|
| 41 |
-
"proposing further changes. Suggested actions:\n"
|
| 42 |
-
"- Reduce d_model or n_layer\n"
|
| 43 |
-
"- Reduce batch_size\n"
|
| 44 |
-
"- Disable experimental modules (Engram, mHC, Hestia) one at a time\n"
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
if state.label == "STUCK":
|
| 48 |
-
stale = state.total_experiments - state.last_improvement_at
|
| 49 |
-
return (
|
| 50 |
-
f"\n{_DIRECTIVE_MARKER}\n"
|
| 51 |
-
f"ALERT: No improvement for {stale} experiments "
|
| 52 |
-
f"(best_bpb={state.best_bpb:.6f}). "
|
| 53 |
-
"Apply BOLD changes for the next 5 experiments:\n"
|
| 54 |
-
"- Dramatically change d_model or n_layer (2× or ½)\n"
|
| 55 |
-
"- Toggle Engram or mHC on/off entirely\n"
|
| 56 |
-
"- Change optimizer hyperparameters by 3–5×\n"
|
| 57 |
-
"- Temporarily accept results within 0.5% of baseline\n"
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
if state.label == "EXPLOITING":
|
| 61 |
-
return (
|
| 62 |
-
f"\n{_DIRECTIVE_MARKER}\n"
|
| 63 |
-
"Search is converging too early. Inject diversity:\n"
|
| 64 |
-
"- If recent experiments tune LR, try architecture changes instead\n"
|
| 65 |
-
"- If tuning architecture, try optimizer or regularisation changes\n"
|
| 66 |
-
"- Try removing complexity (simplification wins are valuable)\n"
|
| 67 |
-
"- Explore a subsystem not touched in the last 10 experiments\n"
|
| 68 |
-
)
|
| 69 |
-
|
| 70 |
-
return None
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def _strip_previous_directive(content: str) -> str:
|
| 74 |
-
"""Remove any prior auto-generated directive block from content.
|
| 75 |
-
|
| 76 |
-
Args:
|
| 77 |
-
content: Full text of program.md.
|
| 78 |
-
|
| 79 |
-
Returns:
|
| 80 |
-
Content with any previous directive stripped and trailing
|
| 81 |
-
whitespace normalised.
|
| 82 |
-
"""
|
| 83 |
-
if _DIRECTIVE_MARKER in content:
|
| 84 |
-
content = content[: content.index(_DIRECTIVE_MARKER)].rstrip() + "\n"
|
| 85 |
-
return content
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def run_meta_iteration(
|
| 89 |
-
program_path: str = PROGRAM_PATH,
|
| 90 |
-
results_path: str = RESULTS_PATH,
|
| 91 |
-
) -> dict:
|
| 92 |
-
"""Run one meta-agent iteration.
|
| 93 |
-
|
| 94 |
-
Diagnoses the current research state and optionally rewrites
|
| 95 |
-
program.md with a new directive.
|
| 96 |
-
|
| 97 |
-
Args:
|
| 98 |
-
program_path: Path to program.md.
|
| 99 |
-
results_path: Path to results.tsv.
|
| 100 |
-
|
| 101 |
-
Returns:
|
| 102 |
-
Summary dict with keys: state, total_experiments, best_bpb,
|
| 103 |
-
crash_rate, changed, and optionally directive.
|
| 104 |
-
"""
|
| 105 |
-
state = diagnose(results_path)
|
| 106 |
-
|
| 107 |
-
summary: dict = {
|
| 108 |
-
"state": state.label,
|
| 109 |
-
"total_experiments": state.total_experiments,
|
| 110 |
-
"best_bpb": state.best_bpb,
|
| 111 |
-
"crash_rate": state.crash_rate,
|
| 112 |
-
"changed": False,
|
| 113 |
-
}
|
| 114 |
-
|
| 115 |
-
directive = generate_directive(state)
|
| 116 |
-
if directive is None:
|
| 117 |
-
return summary
|
| 118 |
-
|
| 119 |
-
try:
|
| 120 |
-
with open(program_path) as fh:
|
| 121 |
-
content = fh.read()
|
| 122 |
-
except FileNotFoundError:
|
| 123 |
-
content = ""
|
| 124 |
-
|
| 125 |
-
content = _strip_previous_directive(content)
|
| 126 |
-
content = content + "\n" + directive
|
| 127 |
-
|
| 128 |
-
tmp_path = program_path + ".tmp"
|
| 129 |
-
try:
|
| 130 |
-
with open(tmp_path, "w") as fh:
|
| 131 |
-
fh.write(content)
|
| 132 |
-
os.replace(tmp_path, program_path) # atomic on POSIX
|
| 133 |
-
finally:
|
| 134 |
-
if os.path.exists(tmp_path):
|
| 135 |
-
os.unlink(tmp_path)
|
| 136 |
-
|
| 137 |
-
summary["changed"] = True
|
| 138 |
-
summary["directive"] = directive.strip()
|
| 139 |
-
return summary
|
|
|
|
| 1 |
+
"""Meta-agent: evolves program.md based on experiment history.
|
| 2 |
+
|
| 3 |
+
Runs every ``meta_interval`` inner-loop experiments (configured in
|
| 4 |
+
HarnessConfig). Reads the current research state from results.tsv,
|
| 5 |
+
decides whether guidance is needed, and appends a directive to
|
| 6 |
+
program.md. Any previous auto-generated directive is replaced so
|
| 7 |
+
the file stays clean.
|
| 8 |
+
"""
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
from harness.git_utils import REPO_DIR
|
| 12 |
+
from harness.search_strategy import ResearchState, diagnose
|
| 13 |
+
|
| 14 |
+
PROGRAM_PATH = os.path.join(REPO_DIR, "program.md")
|
| 15 |
+
RESULTS_PATH = os.path.join(REPO_DIR, "results.tsv")
|
| 16 |
+
|
| 17 |
+
# Sentinel that marks auto-generated content so it can be cleanly replaced.
|
| 18 |
+
_DIRECTIVE_MARKER = "## Meta-Agent Directive (auto-generated)"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def generate_directive(state: ResearchState) -> str | None:
    """Build the directive text for program.md from a research diagnosis.

    Only the BROKEN, STUCK and EXPLOITING labels yield a directive; the
    EXPLORING label (and any label not listed here) yields nothing.

    Args:
        state: Current ResearchState diagnosis.

    Returns:
        The directive block, starting with the auto-generated marker
        heading, or None when no guidance is required.
    """
    label = state.label
    # Every directive opens with the sentinel heading so that it can be
    # located and replaced on a later iteration.
    header = f"\n{_DIRECTIVE_MARKER}\n"

    if label == "BROKEN":
        body = (
            f"ALERT: Crash rate is {state.crash_rate:.0%} in the recent window. "
            "Revert to the last stable commit. Reduce model complexity before "
            "proposing further changes. Suggested actions:\n"
            "- Reduce d_model or n_layer\n"
            "- Reduce batch_size\n"
            "- Disable experimental modules (Engram, mHC, Hestia) one at a time\n"
        )
    elif label == "STUCK":
        # Number of experiments run since the best result was recorded.
        stale = state.total_experiments - state.last_improvement_at
        body = (
            f"ALERT: No improvement for {stale} experiments "
            f"(best_bpb={state.best_bpb:.6f}). "
            "Apply BOLD changes for the next 5 experiments:\n"
            "- Dramatically change d_model or n_layer (2× or ½)\n"
            "- Toggle Engram or mHC on/off entirely\n"
            "- Change optimizer hyperparameters by 3–5×\n"
            "- Temporarily accept results within 0.5% of baseline\n"
        )
    elif label == "EXPLOITING":
        body = (
            "Search is converging too early. Inject diversity:\n"
            "- If recent experiments tune LR, try architecture changes instead\n"
            "- If tuning architecture, try optimizer or regularisation changes\n"
            "- Try removing complexity (simplification wins are valuable)\n"
            "- Explore a subsystem not touched in the last 10 experiments\n"
        )
    else:
        # EXPLORING, or a label this function does not recognise.
        return None

    return header + body
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _strip_previous_directive(content: str) -> str:
    """Drop an earlier auto-generated directive from the program text.

    Everything from the first occurrence of the directive marker to the end
    of the text is discarded.

    Args:
        content: Full text of program.md.

    Returns:
        The text before the marker with trailing whitespace removed and a
        single newline appended, or ``content`` unchanged when no marker is
        present.
    """
    marker_at = content.find(_DIRECTIVE_MARKER)
    if marker_at == -1:
        return content
    kept = content[:marker_at]
    return kept.rstrip() + "\n"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def run_meta_iteration(
    program_path: str = PROGRAM_PATH,
    results_path: str = RESULTS_PATH,
) -> dict:
    """Run one meta-agent iteration.

    Diagnoses the current research state and, when a directive is
    warranted, rewrites program.md so that it ends with the new directive
    (any earlier auto-generated directive is removed first).

    Args:
        program_path: Path to program.md.
        results_path: Path to results.tsv.

    Returns:
        Summary dict with keys: state, total_experiments, best_bpb,
        crash_rate, changed, and -- only when the file was rewritten --
        directive.
    """
    state = diagnose(results_path)

    summary: dict = {
        "state": state.label,
        "total_experiments": state.total_experiments,
        "best_bpb": state.best_bpb,
        "crash_rate": state.crash_rate,
        "changed": False,
    }

    directive = generate_directive(state)
    if directive is None:
        return summary

    # The directives contain non-ASCII characters, so the encoding is pinned
    # to UTF-8 rather than left to the locale default.
    try:
        with open(program_path, encoding="utf-8") as fh:
            content = fh.read()
    except FileNotFoundError:
        # A missing program.md is treated as an empty document.
        content = ""

    content = _strip_previous_directive(content)
    content = content + "\n" + directive

    # Write to a sibling temp file and swap it in, so a failed write never
    # leaves a half-written program.md behind.
    tmp_path = program_path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as fh:
            fh.write(content)
        os.replace(tmp_path, program_path)  # atomic on POSIX
    finally:
        # Only still present when the write or the replace failed.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)

    summary["changed"] = True
    summary["directive"] = directive.strip()
    return summary
|
overlay/harness/orchestrator.py
CHANGED
|
@@ -1,296 +1,293 @@
|
|
| 1 |
-
"""HYDRA Orchestrator: main loop for autonomous research.
|
| 2 |
-
|
| 3 |
-
Usage::
|
| 4 |
-
|
| 5 |
-
python -m harness.orchestrator [--meta-interval N] [--max-experiments N]
|
| 6 |
-
|
| 7 |
-
Loop:
|
| 8 |
-
1. Read current state (branch, results.tsv, program.md)
|
| 9 |
-
2. [Architect Agent] proposes and applies changes to train.py (external)
|
| 10 |
-
3. Git commit the changes
|
| 11 |
-
4. Run training: ``uv run train.py`` captured to run.log
|
| 12 |
-
5. [Eval Agent] extract metrics from run.log
|
| 13 |
-
6. Keep or discard based on val_bpb + secondary metric gates
|
| 14 |
-
7. Log to results.tsv
|
| 15 |
-
8. Every ``meta_interval`` experiments: [Meta Agent] evolves program.md
|
| 16 |
-
9. Repeat
|
| 17 |
-
|
| 18 |
-
The orchestrator intentionally does NOT modify train.py itself -- it
|
| 19 |
-
provides the infrastructure ("rails") that the autoresearch loop runs on.
|
| 20 |
-
"""
|
| 21 |
-
import argparse
|
| 22 |
-
import csv
|
| 23 |
import os
|
| 24 |
import subprocess
|
| 25 |
import time
|
| 26 |
|
| 27 |
-
from configs.harness_config import HarnessConfig
|
| 28 |
from harness.eval_agent import ExperimentResult, check_secondary_alarms, parse_run_log, should_keep
|
| 29 |
-
from harness.git_utils import REPO_DIR, commit_all, current_commit_short, reset_to
|
| 30 |
-
from harness.health_monitor import check_health, reset_peak_stats
|
| 31 |
-
from harness.meta_agent import run_meta_iteration
|
| 32 |
-
from harness.search_strategy import diagnose
|
| 33 |
-
|
| 34 |
-
# ---------------------------------------------------------------------------
|
| 35 |
-
# Paths
|
| 36 |
-
# ---------------------------------------------------------------------------
|
| 37 |
-
|
| 38 |
-
RESULTS_FILE = os.path.join(REPO_DIR, "results.tsv")
|
| 39 |
-
RUN_LOG = os.path.join(REPO_DIR, "run.log")
|
| 40 |
-
|
| 41 |
-
_TSV_HEADER = "commit\tval_bpb\tmemory_gb\tstatus\tdescription\n"
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
# ---------------------------------------------------------------------------
|
| 45 |
-
# TSV helpers
|
| 46 |
-
# ---------------------------------------------------------------------------
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def init_results_tsv() -> None:
|
| 50 |
-
"""Create results.tsv with header row if it does not yet exist."""
|
| 51 |
-
if not os.path.exists(RESULTS_FILE):
|
| 52 |
-
with open(RESULTS_FILE, "w") as fh:
|
| 53 |
-
fh.write(_TSV_HEADER)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def log_result(
|
| 57 |
-
commit: str,
|
| 58 |
-
val_bpb: float,
|
| 59 |
-
memory_gb: float,
|
| 60 |
-
status: str,
|
| 61 |
-
description: str,
|
| 62 |
-
) -> None:
|
| 63 |
-
"""Append one row to results.tsv.
|
| 64 |
-
|
| 65 |
-
Args:
|
| 66 |
-
commit: Short git hash for this experiment.
|
| 67 |
-
val_bpb: Validation bits-per-byte (0.0 for crashes).
|
| 68 |
-
memory_gb: Peak VRAM usage in gigabytes.
|
| 69 |
-
status: One of keep / discard / crash / timeout.
|
| 70 |
-
description: Short human-readable description.
|
| 71 |
-
"""
|
| 72 |
-
with open(RESULTS_FILE, "a") as fh:
|
| 73 |
-
fh.write(
|
| 74 |
-
f"{commit}\t{val_bpb:.6f}\t{memory_gb:.2f}\t{status}\t{description}\n"
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def count_experiments() -> int:
|
| 79 |
-
"""Count the number of experiment rows in results.tsv.
|
| 80 |
-
|
| 81 |
-
Returns:
|
| 82 |
-
Row count excluding the header line (0 when file does not exist).
|
| 83 |
-
"""
|
| 84 |
-
if not os.path.exists(RESULTS_FILE):
|
| 85 |
-
return 0
|
| 86 |
-
with open(RESULTS_FILE) as fh:
|
| 87 |
-
return max(0, sum(1 for _ in fh) - 1)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def _load_best_bpb() -> float:
|
| 91 |
-
"""Scan results.tsv for the best (lowest positive) val_bpb seen so far.
|
| 92 |
-
|
| 93 |
-
Returns:
|
| 94 |
-
Best val_bpb, or ``float("inf")`` when no valid result exists.
|
| 95 |
-
"""
|
| 96 |
-
if not os.path.exists(RESULTS_FILE):
|
| 97 |
-
return float("inf")
|
| 98 |
-
best = float("inf")
|
| 99 |
-
with open(RESULTS_FILE) as fh:
|
| 100 |
-
reader = csv.DictReader(fh, delimiter="\t")
|
| 101 |
-
for row in reader:
|
| 102 |
-
try:
|
| 103 |
-
bpb = float(row.get("val_bpb", "0") or "0")
|
| 104 |
-
except ValueError:
|
| 105 |
-
continue
|
| 106 |
-
if 0 < bpb < best:
|
| 107 |
-
best = bpb
|
| 108 |
-
return best
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
# ---------------------------------------------------------------------------
|
| 112 |
-
# Experiment execution
|
| 113 |
-
# ---------------------------------------------------------------------------
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
def run_experiment(timeout: int = 600) -> str:
|
| 117 |
-
"""Launch ``uv run train.py`` and capture all output to run.log.
|
| 118 |
-
|
| 119 |
-
Args:
|
| 120 |
-
timeout: Kill the process after this many seconds.
|
| 121 |
-
|
| 122 |
-
Returns:
|
| 123 |
-
One of ``"ok"``, ``"timeout"``, or ``"error"``.
|
| 124 |
-
"""
|
| 125 |
-
try:
|
| 126 |
-
with open(RUN_LOG, "w") as log_file:
|
| 127 |
-
proc = subprocess.run(
|
| 128 |
-
["uv", "run", "train.py"],
|
| 129 |
-
cwd=REPO_DIR,
|
| 130 |
-
stdout=log_file,
|
| 131 |
-
stderr=subprocess.STDOUT,
|
| 132 |
-
timeout=timeout,
|
| 133 |
-
)
|
| 134 |
-
return "ok" if proc.returncode == 0 else "error"
|
| 135 |
-
except subprocess.TimeoutExpired:
|
| 136 |
-
return "timeout"
|
| 137 |
-
except Exception as exc: # noqa: BLE001
|
| 138 |
-
with open(RUN_LOG, "a") as log_file:
|
| 139 |
-
log_file.write(f"\nOrchestrator error: {exc}\n")
|
| 140 |
-
return "error"
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
# ---------------------------------------------------------------------------
|
| 144 |
-
# Main loop
|
| 145 |
-
# ---------------------------------------------------------------------------
|
| 146 |
-
|
| 147 |
-
|
| 148 |
def run_loop(
|
| 149 |
meta_interval: int = 20,
|
| 150 |
max_experiments: int | None = None,
|
| 151 |
experiment_timeout: int = 600,
|
| 152 |
-
secondary_gates: dict
|
| 153 |
) -> None:
|
| 154 |
-
"""Run the HYDRA autoresearch loop.
|
| 155 |
-
|
| 156 |
-
This function runs indefinitely (or until ``max_experiments`` is reached
|
| 157 |
-
or the user interrupts with Ctrl-C).
|
| 158 |
-
|
| 159 |
-
Args:
|
| 160 |
-
meta_interval: Run the meta-agent every N experiments.
|
| 161 |
-
max_experiments: Hard stop after this many experiments (None = infinite).
|
| 162 |
-
experiment_timeout: Seconds before a training run is killed.
|
| 163 |
-
secondary_gates: Optional gate thresholds forwarded to
|
| 164 |
-
:func:`~harness.eval_agent.should_keep`.
|
| 165 |
-
"""
|
| 166 |
init_results_tsv()
|
| 167 |
-
if secondary_gates is None:
|
| 168 |
-
secondary_gates = HarnessConfig().to_secondary_gates()
|
| 169 |
best_bpb = _load_best_bpb()
|
| 170 |
-
experiment_num = count_experiments()
|
| 171 |
-
|
| 172 |
-
print(
|
| 173 |
-
f"HYDRA Orchestrator starting. "
|
| 174 |
-
f"Experiments so far: {experiment_num}, Best BPB: {best_bpb:.6f}"
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
while max_experiments is None or experiment_num < max_experiments:
|
| 178 |
-
experiment_num += 1
|
| 179 |
-
|
| 180 |
-
# ------------------------------------------------------------------
|
| 181 |
-
# Pre-flight health check
|
| 182 |
-
# ------------------------------------------------------------------
|
| 183 |
-
healthy, hw_warnings = check_health()
|
| 184 |
-
if hw_warnings:
|
| 185 |
-
print(f" [health] {hw_warnings}")
|
| 186 |
-
|
| 187 |
-
# ------------------------------------------------------------------
|
| 188 |
-
# Periodic meta-agent update
|
| 189 |
-
# ------------------------------------------------------------------
|
| 190 |
-
if experiment_num > 1 and experiment_num % meta_interval == 0:
|
| 191 |
-
print(f"\n=== Meta-agent iteration at experiment {experiment_num} ===")
|
| 192 |
-
meta_result = run_meta_iteration()
|
| 193 |
-
print(
|
| 194 |
-
f" state={meta_result['state']} "
|
| 195 |
-
f"best_bpb={meta_result['best_bpb']:.6f} "
|
| 196 |
-
f"changed={meta_result['changed']}"
|
| 197 |
-
)
|
| 198 |
-
if meta_result.get("directive"):
|
| 199 |
-
print(f" directive: {meta_result['directive'][:120]}")
|
| 200 |
-
|
| 201 |
-
# ------------------------------------------------------------------
|
| 202 |
-
# Record baseline commit so we can reset on failure / discard
|
| 203 |
-
# ------------------------------------------------------------------
|
| 204 |
-
pre_commit = current_commit_short()
|
| 205 |
-
|
| 206 |
-
# ------------------------------------------------------------------
|
| 207 |
-
# Run experiment
|
| 208 |
-
# ------------------------------------------------------------------
|
| 209 |
-
print(f"\n--- Experiment {experiment_num} ---")
|
| 210 |
-
reset_peak_stats()
|
| 211 |
-
t0 = time.time()
|
| 212 |
-
run_status = run_experiment(timeout=experiment_timeout)
|
| 213 |
-
elapsed = time.time() - t0
|
| 214 |
-
print(f" run_status={run_status} elapsed={elapsed:.1f}s")
|
| 215 |
-
|
| 216 |
-
# ------------------------------------------------------------------
|
| 217 |
-
# Parse results
|
| 218 |
-
# ------------------------------------------------------------------
|
| 219 |
-
result: ExperimentResult = parse_run_log(RUN_LOG)
|
| 220 |
-
|
| 221 |
-
if result.crashed or run_status != "ok":
|
| 222 |
-
commit = current_commit_short()
|
| 223 |
-
err_short = (
|
| 224 |
-
"timeout"
|
| 225 |
-
if run_status == "timeout"
|
| 226 |
-
else result.error_message[:80].replace("\n", " ")
|
| 227 |
-
)
|
| 228 |
-
log_result(commit, 0.0, 0.0, "crash", err_short)
|
| 229 |
-
print(f" CRASH: {err_short}")
|
| 230 |
-
reset_to(pre_commit)
|
| 231 |
-
continue
|
| 232 |
-
|
| 233 |
-
# ------------------------------------------------------------------
|
| 234 |
-
# Secondary alarms (non-blocking -- logged but do not abort)
|
| 235 |
-
# ------------------------------------------------------------------
|
| 236 |
-
alarms = check_secondary_alarms(result)
|
| 237 |
-
if alarms:
|
| 238 |
-
for alarm in alarms:
|
| 239 |
-
print(f" [alarm] {alarm}")
|
| 240 |
-
|
| 241 |
-
# ------------------------------------------------------------------
|
| 242 |
-
# Keep / discard
|
| 243 |
-
# ------------------------------------------------------------------
|
| 244 |
-
keep, reason = should_keep(result, best_bpb, gates=secondary_gates)
|
| 245 |
-
commit = current_commit_short()
|
| 246 |
-
memory_gb = result.peak_vram_mb / 1024.0
|
| 247 |
-
|
| 248 |
-
if keep:
|
| 249 |
-
best_bpb = result.val_bpb
|
| 250 |
-
description = f"val_bpb improved to {result.val_bpb:.6f}"
|
| 251 |
-
log_result(commit, result.val_bpb, memory_gb, "keep", description)
|
| 252 |
-
print(f" KEEP: val_bpb={result.val_bpb:.6f} (new best)")
|
| 253 |
-
else:
|
| 254 |
-
description = f"{reason} val_bpb={result.val_bpb:.6f}"
|
| 255 |
-
log_result(commit, result.val_bpb, memory_gb, "discard", description)
|
| 256 |
-
print(f" DISCARD: val_bpb={result.val_bpb:.6f} ({reason})")
|
| 257 |
-
reset_to(pre_commit)
|
| 258 |
-
|
| 259 |
-
print(f"\nHYDRA finished after {experiment_num} experiments. Best BPB: {best_bpb:.6f}")
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
# ---------------------------------------------------------------------------
|
| 263 |
-
# CLI entry point
|
| 264 |
-
# ---------------------------------------------------------------------------
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
if __name__ == "__main__":
|
| 268 |
-
parser = argparse.ArgumentParser(description="HYDRA Autoresearch Orchestrator")
|
| 269 |
-
parser.add_argument(
|
| 270 |
-
"--meta-interval",
|
| 271 |
-
type=int,
|
| 272 |
-
default=20,
|
| 273 |
-
help="Run meta-agent every N experiments (default: 20)",
|
| 274 |
-
)
|
| 275 |
-
parser.add_argument(
|
| 276 |
-
"--max-experiments",
|
| 277 |
-
type=int,
|
| 278 |
-
default=None,
|
| 279 |
-
help="Stop after N experiments; omit for infinite (default: infinite)",
|
| 280 |
-
)
|
| 281 |
-
parser.add_argument(
|
| 282 |
-
"--experiment-timeout",
|
| 283 |
-
type=int,
|
| 284 |
-
default=600,
|
| 285 |
-
help="Kill training run after N seconds (default: 600)",
|
| 286 |
-
)
|
| 287 |
-
args = parser.parse_args()
|
| 288 |
-
|
| 289 |
-
try:
|
| 290 |
-
run_loop(
|
| 291 |
-
meta_interval=args.meta_interval,
|
| 292 |
-
max_experiments=args.max_experiments,
|
| 293 |
-
experiment_timeout=args.experiment_timeout,
|
| 294 |
-
)
|
| 295 |
-
except KeyboardInterrupt:
|
| 296 |
-
print("\nOrchestrator stopped by user.")
|
|
|
|
| 1 |
+
"""HYDRA Orchestrator: main loop for autonomous research.
|
| 2 |
+
|
| 3 |
+
Usage::
|
| 4 |
+
|
| 5 |
+
python -m harness.orchestrator [--meta-interval N] [--max-experiments N]
|
| 6 |
+
|
| 7 |
+
Loop:
|
| 8 |
+
1. Read current state (branch, results.tsv, program.md)
|
| 9 |
+
2. [Architect Agent] proposes and applies changes to train.py (external)
|
| 10 |
+
3. Git commit the changes
|
| 11 |
+
4. Run training: ``uv run train.py`` captured to run.log
|
| 12 |
+
5. [Eval Agent] extract metrics from run.log
|
| 13 |
+
6. Keep or discard based on val_bpb + secondary metric gates
|
| 14 |
+
7. Log to results.tsv
|
| 15 |
+
8. Every ``meta_interval`` experiments: [Meta Agent] evolves program.md
|
| 16 |
+
9. Repeat
|
| 17 |
+
|
| 18 |
+
The orchestrator intentionally does NOT modify train.py itself -- it
|
| 19 |
+
provides the infrastructure ("rails") that the autoresearch loop runs on.
|
| 20 |
+
"""
|
| 21 |
+
import argparse
|
| 22 |
+
import csv
|
| 23 |
import os
|
| 24 |
import subprocess
|
| 25 |
import time
|
| 26 |
|
|
|
|
| 27 |
from harness.eval_agent import ExperimentResult, check_secondary_alarms, parse_run_log, should_keep
|
| 28 |
+
from harness.git_utils import REPO_DIR, commit_all, current_commit_short, reset_to
|
| 29 |
+
from harness.health_monitor import check_health, reset_peak_stats
|
| 30 |
+
from harness.meta_agent import run_meta_iteration
|
| 31 |
+
from harness.search_strategy import diagnose
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Paths
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
RESULTS_FILE = os.path.join(REPO_DIR, "results.tsv")
|
| 38 |
+
RUN_LOG = os.path.join(REPO_DIR, "run.log")
|
| 39 |
+
|
| 40 |
+
_TSV_HEADER = "commit\tval_bpb\tmemory_gb\tstatus\tdescription\n"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
# TSV helpers
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def init_results_tsv() -> None:
    """Ensure results.tsv exists, seeding it with the header row if absent.

    An existing file is left untouched.
    """
    if os.path.exists(RESULTS_FILE):
        return
    with open(RESULTS_FILE, "w") as tsv:
        tsv.write(_TSV_HEADER)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def log_result(
    commit: str,
    val_bpb: float,
    memory_gb: float,
    status: str,
    description: str,
) -> None:
    """Append one row to results.tsv.

    Tabs, carriage returns and newlines inside the text fields are replaced
    with spaces so that every experiment occupies exactly one line with
    exactly five tab-separated columns.

    Args:
        commit: Short git hash for this experiment.
        val_bpb: Validation bits-per-byte (0.0 for crashes).
        memory_gb: Peak VRAM usage in gigabytes.
        status: One of keep / discard / crash / timeout.
        description: Short human-readable description.
    """
    # Characters that would split or shift a TSV row if written verbatim.
    flatten = str.maketrans({"\t": " ", "\r": " ", "\n": " "})
    safe_commit = str(commit).translate(flatten)
    safe_status = str(status).translate(flatten)
    safe_description = str(description).translate(flatten)

    with open(RESULTS_FILE, "a") as fh:
        fh.write(
            f"{safe_commit}\t{val_bpb:.6f}\t{memory_gb:.2f}"
            f"\t{safe_status}\t{safe_description}\n"
        )
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def count_experiments() -> int:
    """Return how many experiment rows results.tsv currently holds.

    Returns:
        Number of lines in the file minus one for the header, never below
        zero; 0 when the file does not exist.
    """
    if not os.path.exists(RESULTS_FILE):
        return 0

    line_total = 0
    with open(RESULTS_FILE) as tsv:
        for _ in tsv:
            line_total += 1

    # The first line is the header, not an experiment.
    data_rows = line_total - 1
    return data_rows if data_rows > 0 else 0
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _load_best_bpb() -> float:
    """Find the lowest strictly positive val_bpb recorded in results.tsv.

    Rows whose val_bpb cannot be parsed as a float are ignored, as are
    non-positive values (crash rows are logged with 0.0).

    Returns:
        Best val_bpb, or ``float("inf")`` when the file is missing or holds
        no valid result.
    """
    lowest = float("inf")
    if not os.path.exists(RESULTS_FILE):
        return lowest

    with open(RESULTS_FILE) as tsv:
        for record in csv.DictReader(tsv, delimiter="\t"):
            # Missing or empty cells are read as zero and therefore skipped.
            raw = record.get("val_bpb", "0") or "0"
            try:
                candidate = float(raw)
            except ValueError:
                continue
            if candidate > 0 and candidate < lowest:
                lowest = candidate
    return lowest
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
# Experiment execution
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def run_experiment(timeout: int = 600) -> str:
    """Run ``uv run train.py`` in the repo and send its output to run.log.

    stdout and stderr are both written to run.log, which is truncated at
    the start of every run.

    Args:
        timeout: Kill the process after this many seconds.

    Returns:
        ``"ok"`` on a zero exit status, ``"timeout"`` when the time limit
        is hit, and ``"error"`` for a non-zero exit status or any other
        failure while launching the run.
    """
    command = ["uv", "run", "train.py"]
    try:
        with open(RUN_LOG, "w") as sink:
            completed = subprocess.run(
                command,
                cwd=REPO_DIR,
                stdout=sink,
                stderr=subprocess.STDOUT,
                timeout=timeout,
            )
    except subprocess.TimeoutExpired:
        return "timeout"
    except Exception as exc:  # noqa: BLE001
        # Record the launcher failure in run.log alongside any output.
        with open(RUN_LOG, "a") as sink:
            sink.write(f"\nOrchestrator error: {exc}\n")
        return "error"

    if completed.returncode == 0:
        return "ok"
    return "error"
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ---------------------------------------------------------------------------
|
| 143 |
+
# Main loop
|
| 144 |
+
# ---------------------------------------------------------------------------
|
| 145 |
+
|
| 146 |
+
|
| 147 |
def run_loop(
    meta_interval: int = 20,
    max_experiments: int | None = None,
    experiment_timeout: int = 600,
    secondary_gates: dict | None = None,
) -> None:
    """Run the HYDRA autoresearch loop.

    This function runs indefinitely (or until ``max_experiments`` is reached
    or the user interrupts with Ctrl-C). The experiment counter resumes from
    the number of rows already present in results.tsv.

    Args:
        meta_interval: Run the meta-agent every N experiments. A value of
            zero or less disables the meta-agent.
        max_experiments: Hard stop after this many experiments (None = infinite).
        experiment_timeout: Seconds before a training run is killed.
        secondary_gates: Optional gate thresholds forwarded to
            :func:`~harness.eval_agent.should_keep`.
    """
    init_results_tsv()
    best_bpb = _load_best_bpb()
    experiment_num = count_experiments()

    print(
        f"HYDRA Orchestrator starting. "
        f"Experiments so far: {experiment_num}, Best BPB: {best_bpb:.6f}"
    )

    while max_experiments is None or experiment_num < max_experiments:
        experiment_num += 1

        # ------------------------------------------------------------------
        # Pre-flight health check (warnings are reported, never blocking)
        # ------------------------------------------------------------------
        _healthy, hw_warnings = check_health()
        if hw_warnings:
            print(f" [health] {hw_warnings}")

        # ------------------------------------------------------------------
        # Periodic meta-agent update
        # ------------------------------------------------------------------
        # The positivity check keeps the modulo from dividing by zero.
        if (
            meta_interval > 0
            and experiment_num > 1
            and experiment_num % meta_interval == 0
        ):
            print(f"\n=== Meta-agent iteration at experiment {experiment_num} ===")
            meta_result = run_meta_iteration()
            print(
                f" state={meta_result['state']} "
                f"best_bpb={meta_result['best_bpb']:.6f} "
                f"changed={meta_result['changed']}"
            )
            if meta_result.get("directive"):
                print(f" directive: {meta_result['directive'][:120]}")

        # ------------------------------------------------------------------
        # Record baseline commit so we can reset on failure / discard
        # ------------------------------------------------------------------
        # NOTE(review): nothing in this loop creates a commit; presumably an
        # external agent commits its changes -- confirm.
        pre_commit = current_commit_short()

        # ------------------------------------------------------------------
        # Run experiment
        # ------------------------------------------------------------------
        print(f"\n--- Experiment {experiment_num} ---")
        reset_peak_stats()
        t0 = time.time()
        run_status = run_experiment(timeout=experiment_timeout)
        elapsed = time.time() - t0
        print(f" run_status={run_status} elapsed={elapsed:.1f}s")

        # ------------------------------------------------------------------
        # Parse results
        # ------------------------------------------------------------------
        result: ExperimentResult = parse_run_log(RUN_LOG)

        if result.crashed or run_status != "ok":
            commit = current_commit_short()
            if run_status == "timeout":
                err_short = "timeout"
            else:
                # Tolerate a missing message and never log a blank description.
                err_short = (result.error_message or "")[:80].replace("\n", " ")
                err_short = err_short or "no error message captured"
            log_result(commit, 0.0, 0.0, "crash", err_short)
            print(f" CRASH: {err_short}")
            reset_to(pre_commit)
            continue

        # ------------------------------------------------------------------
        # Secondary alarms (non-blocking -- logged but do not abort)
        # ------------------------------------------------------------------
        for alarm in check_secondary_alarms(result) or ():
            print(f" [alarm] {alarm}")

        # ------------------------------------------------------------------
        # Keep / discard
        # ------------------------------------------------------------------
        keep, reason = should_keep(result, best_bpb, gates=secondary_gates)
        commit = current_commit_short()
        memory_gb = result.peak_vram_mb / 1024.0

        if keep:
            best_bpb = result.val_bpb
            description = f"val_bpb improved to {result.val_bpb:.6f}"
            log_result(commit, result.val_bpb, memory_gb, "keep", description)
            print(f" KEEP: val_bpb={result.val_bpb:.6f} (new best)")
        else:
            description = f"{reason} val_bpb={result.val_bpb:.6f}"
            log_result(commit, result.val_bpb, memory_gb, "discard", description)
            print(f" DISCARD: val_bpb={result.val_bpb:.6f} ({reason})")
            # Return the working tree to the pre-experiment baseline.
            reset_to(pre_commit)

    print(f"\nHYDRA finished after {experiment_num} experiments. Best BPB: {best_bpb:.6f}")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# ---------------------------------------------------------------------------
|
| 260 |
+
# CLI entry point
|
| 261 |
+
# ---------------------------------------------------------------------------
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
    # Command-line entry point: every option is an integer flag, so they are
    # registered from one table of (flag, default, help text).
    parser = argparse.ArgumentParser(description="HYDRA Autoresearch Orchestrator")
    for flag, default, help_text in (
        (
            "--meta-interval",
            20,
            "Run meta-agent every N experiments (default: 20)",
        ),
        (
            "--max-experiments",
            None,
            "Stop after N experiments; omit for infinite (default: infinite)",
        ),
        (
            "--experiment-timeout",
            600,
            "Kill training run after N seconds (default: 600)",
        ),
    ):
        parser.add_argument(flag, type=int, default=default, help=help_text)
    args = parser.parse_args()

    try:
        run_loop(
            meta_interval=args.meta_interval,
            max_experiments=args.max_experiments,
            experiment_timeout=args.experiment_timeout,
        )
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop an unbounded run.
        print("\nOrchestrator stopped by user.")
|
overlay/harness/search_strategy.py
CHANGED
|
@@ -1,153 +1,153 @@
|
|
| 1 |
-
"""Search strategy for HYDRA's meta-evolution loop.
|
| 2 |
-
|
| 3 |
-
Reads results.tsv and diagnoses the current research state as one of:
|
| 4 |
-
EXPLORING -- active improvement trend with diverse experiments
|
| 5 |
-
EXPLOITING -- narrowing in on a local optimum (low diversity)
|
| 6 |
-
STUCK -- no improvement for >= stuck_threshold experiments
|
| 7 |
-
BROKEN -- crash rate exceeds crash_threshold
|
| 8 |
-
"""
|
| 9 |
-
import csv
|
| 10 |
-
import os
|
| 11 |
-
from dataclasses import dataclass
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
@dataclass
class ResearchState:
    """Snapshot describing where the research loop currently stands.

    Attributes:
        label: State name; EXPLORING, EXPLOITING, STUCK or BROKEN.
        trend_improving: Set when the later half of the recent window has a
            lower (better) BPB than the earlier half.
        experiment_diversity: Approximate 0-1 score computed from the distinct
            description prefixes found in the recent window.
        crash_rate: Share of the recent experiments that crashed.
        best_bpb: Smallest val_bpb observed over every experiment.
        last_improvement_at: Ordinal of the experiment that produced best_bpb.
        total_experiments: Row count of results.tsv, header not included.
    """

    label: str
    trend_improving: bool
    experiment_diversity: float
    crash_rate: float
    best_bpb: float
    last_improvement_at: int
    total_experiments: int
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def _empty_state() -> ResearchState:
    """Return the neutral EXPLORING state used when there are no results yet."""
    return ResearchState(
        label="EXPLORING",
        trend_improving=False,
        experiment_diversity=0.0,
        crash_rate=0.0,
        best_bpb=float("inf"),
        last_improvement_at=0,
        total_experiments=0,
    )


def _usable_bpb(row: dict) -> float:
    """Return the row's val_bpb as a positive finite float, or 0.0 if unusable."""
    try:
        bpb = float(row.get("val_bpb") or "0")
    except ValueError:
        return 0.0
    # NaN fails both comparisons, so it is rejected together with <= 0 and inf.
    return bpb if 0.0 < bpb < float("inf") else 0.0


def diagnose(
    results_path: str,
    window: int = 20,
    stuck_threshold: int = 10,
    crash_threshold: float = 0.5,
) -> ResearchState:
    """Diagnose current research state from results.tsv.

    Rows whose ``val_bpb`` is missing, unparseable, non-positive, NaN or
    infinite are ignored for the best-BPB and trend statistics, but still
    count towards the totals, crash rate and diversity.

    Args:
        results_path: Path to the tab-separated results file.
        window: Number of recent experiments to consider for crash rate,
            trend and diversity. Values <= 0 select an empty window.
        stuck_threshold: Experiments without improvement before labelling STUCK.
        crash_threshold: Crash fraction above which state becomes BROKEN.

    Returns:
        ResearchState with diagnosis label and supporting statistics. A
        missing or row-less file yields an EXPLORING state with zeroed
        statistics and ``best_bpb`` of infinity.
    """
    try:
        # newline="" lets the csv module do its own line-ending handling.
        with open(results_path, newline="") as fh:
            rows = list(csv.DictReader(fh, delimiter="\t"))
    except FileNotFoundError:
        # Opening directly avoids the exists-then-open race.
        return _empty_state()

    if not rows:
        return _empty_state()

    total = len(rows)
    # rows[-0:] would be the whole list, so guard non-positive windows.
    recent = rows[-window:] if window > 0 else []

    # Crash rate in the recent window.
    crashes = sum(1 for r in recent if r.get("status") == "crash")
    crash_rate = crashes / len(recent) if recent else 0.0

    # Best BPB overall and the 1-based ordinal of the row that achieved it.
    best_bpb = float("inf")
    last_improvement_at = 0
    for ordinal, row in enumerate(rows, start=1):
        bpb = _usable_bpb(row)
        if 0.0 < bpb < best_bpb:
            best_bpb = bpb
            last_improvement_at = ordinal

    # Trend: is the second half of the recent window better than the first?
    # Uses the same tolerant parser as above so one malformed row cannot raise.
    valid_bpbs = [bpb for bpb in map(_usable_bpb, recent) if bpb > 0.0]
    trend_improving = False
    if len(valid_bpbs) >= 4:
        mid = len(valid_bpbs) // 2
        first_half_mean = sum(valid_bpbs[:mid]) / mid
        second_half_mean = sum(valid_bpbs[mid:]) / (len(valid_bpbs) - mid)
        trend_improving = second_half_mean < first_half_mean

    # Diversity: fraction of unique description prefixes (first 20 chars).
    # DictReader fills missing trailing columns with None, hence the `or ""`.
    descriptions = {(r.get("description") or "")[:20] for r in recent}
    diversity = min(1.0, len(descriptions) / max(1, len(recent)))

    # Classify state; BROKEN takes precedence over STUCK.
    stale = total - last_improvement_at
    if crash_rate > crash_threshold:
        label = "BROKEN"
    elif stale >= stuck_threshold:
        label = "STUCK"
    elif trend_improving and diversity > 0.3:
        label = "EXPLORING"
    else:
        label = "EXPLOITING"

    return ResearchState(
        label=label,
        trend_improving=trend_improving,
        experiment_diversity=diversity,
        crash_rate=crash_rate,
        best_bpb=best_bpb,
        last_improvement_at=last_improvement_at,
        total_experiments=total,
    )
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def should_explore(results_path: str, n: int = 10) -> bool:
    """Tell the research loop whether it should try bolder mutations.

    Runs ``diagnose`` with ``n`` as both the window and the stuck threshold.

    Args:
        results_path: Path to results.tsv.
        n: Look-back size passed to ``diagnose``.

    Returns:
        True when the diagnosed label is STUCK or BROKEN, otherwise False.
    """
    label = diagnose(results_path, window=n, stuck_threshold=n).label
    return label == "STUCK" or label == "BROKEN"
|
|
|
|
| 1 |
+
"""Search strategy for HYDRA's meta-evolution loop.
|
| 2 |
+
|
| 3 |
+
Reads results.tsv and diagnoses the current research state as one of:
|
| 4 |
+
EXPLORING -- active improvement trend with diverse experiments
|
| 5 |
+
EXPLOITING -- narrowing in on a local optimum (low diversity)
|
| 6 |
+
STUCK -- no improvement for >= stuck_threshold experiments
|
| 7 |
+
BROKEN -- crash rate exceeds crash_threshold
|
| 8 |
+
"""
|
| 9 |
+
import csv
|
| 10 |
+
import os
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class ResearchState:
    """Snapshot describing where the research loop currently stands.

    Attributes:
        label: State name; EXPLORING, EXPLOITING, STUCK or BROKEN.
        trend_improving: Set when the later half of the recent window has a
            lower (better) BPB than the earlier half.
        experiment_diversity: Approximate 0-1 score computed from the distinct
            description prefixes found in the recent window.
        crash_rate: Share of the recent experiments that crashed.
        best_bpb: Smallest val_bpb observed over every experiment.
        last_improvement_at: Ordinal of the experiment that produced best_bpb.
        total_experiments: Row count of results.tsv, header not included.
    """

    label: str
    trend_improving: bool
    experiment_diversity: float
    crash_rate: float
    best_bpb: float
    last_improvement_at: int
    total_experiments: int
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _empty_state() -> ResearchState:
    """Return the neutral EXPLORING state used when there are no results yet."""
    return ResearchState(
        label="EXPLORING",
        trend_improving=False,
        experiment_diversity=0.0,
        crash_rate=0.0,
        best_bpb=float("inf"),
        last_improvement_at=0,
        total_experiments=0,
    )


def _usable_bpb(row: dict) -> float:
    """Return the row's val_bpb as a positive finite float, or 0.0 if unusable."""
    try:
        bpb = float(row.get("val_bpb") or "0")
    except ValueError:
        return 0.0
    # NaN fails both comparisons, so it is rejected together with <= 0 and inf.
    return bpb if 0.0 < bpb < float("inf") else 0.0


def diagnose(
    results_path: str,
    window: int = 20,
    stuck_threshold: int = 10,
    crash_threshold: float = 0.5,
) -> ResearchState:
    """Diagnose current research state from results.tsv.

    Rows whose ``val_bpb`` is missing, unparseable, non-positive, NaN or
    infinite are ignored for the best-BPB and trend statistics, but still
    count towards the totals, crash rate and diversity.

    Args:
        results_path: Path to the tab-separated results file.
        window: Number of recent experiments to consider for crash rate,
            trend and diversity. Values <= 0 select an empty window.
        stuck_threshold: Experiments without improvement before labelling STUCK.
        crash_threshold: Crash fraction above which state becomes BROKEN.

    Returns:
        ResearchState with diagnosis label and supporting statistics. A
        missing or row-less file yields an EXPLORING state with zeroed
        statistics and ``best_bpb`` of infinity.
    """
    try:
        # newline="" lets the csv module do its own line-ending handling.
        with open(results_path, newline="") as fh:
            rows = list(csv.DictReader(fh, delimiter="\t"))
    except FileNotFoundError:
        # Opening directly avoids the exists-then-open race.
        return _empty_state()

    if not rows:
        return _empty_state()

    total = len(rows)
    # rows[-0:] would be the whole list, so guard non-positive windows.
    recent = rows[-window:] if window > 0 else []

    # Crash rate in the recent window.
    crashes = sum(1 for r in recent if r.get("status") == "crash")
    crash_rate = crashes / len(recent) if recent else 0.0

    # Best BPB overall and the 1-based ordinal of the row that achieved it.
    best_bpb = float("inf")
    last_improvement_at = 0
    for ordinal, row in enumerate(rows, start=1):
        bpb = _usable_bpb(row)
        if 0.0 < bpb < best_bpb:
            best_bpb = bpb
            last_improvement_at = ordinal

    # Trend: is the second half of the recent window better than the first?
    # Uses the same tolerant parser as above so one malformed row cannot raise.
    valid_bpbs = [bpb for bpb in map(_usable_bpb, recent) if bpb > 0.0]
    trend_improving = False
    if len(valid_bpbs) >= 4:
        mid = len(valid_bpbs) // 2
        first_half_mean = sum(valid_bpbs[:mid]) / mid
        second_half_mean = sum(valid_bpbs[mid:]) / (len(valid_bpbs) - mid)
        trend_improving = second_half_mean < first_half_mean

    # Diversity: fraction of unique description prefixes (first 20 chars).
    # DictReader fills missing trailing columns with None, hence the `or ""`.
    descriptions = {(r.get("description") or "")[:20] for r in recent}
    diversity = min(1.0, len(descriptions) / max(1, len(recent)))

    # Classify state; BROKEN takes precedence over STUCK.
    stale = total - last_improvement_at
    if crash_rate > crash_threshold:
        label = "BROKEN"
    elif stale >= stuck_threshold:
        label = "STUCK"
    elif trend_improving and diversity > 0.3:
        label = "EXPLORING"
    else:
        label = "EXPLOITING"

    return ResearchState(
        label=label,
        trend_improving=trend_improving,
        experiment_diversity=diversity,
        crash_rate=crash_rate,
        best_bpb=best_bpb,
        last_improvement_at=last_improvement_at,
        total_experiments=total,
    )
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def should_explore(results_path: str, n: int = 10) -> bool:
    """Tell the research loop whether it should try bolder mutations.

    Runs ``diagnose`` with ``n`` as both the window and the stuck threshold.

    Args:
        results_path: Path to results.tsv.
        n: Look-back size passed to ``diagnose``.

    Returns:
        True when the diagnosed label is STUCK or BROKEN, otherwise False.
    """
    label = diagnose(results_path, window=n, stuck_threshold=n).label
    return label == "STUCK" or label == "BROKEN"
|
overlay/htm_rust/Cargo.lock
CHANGED
|
@@ -1,383 +1,383 @@
|
|
| 1 |
-
# This file is automatically @generated by Cargo.
|
| 2 |
-
# It is not intended for manual editing.
|
| 3 |
-
version = 4
|
| 4 |
-
|
| 5 |
-
[[package]]
|
| 6 |
-
name = "autocfg"
|
| 7 |
-
version = "1.5.0"
|
| 8 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 9 |
-
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
| 10 |
-
|
| 11 |
-
[[package]]
|
| 12 |
-
name = "cfg-if"
|
| 13 |
-
version = "1.0.4"
|
| 14 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 15 |
-
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
| 16 |
-
|
| 17 |
-
[[package]]
|
| 18 |
-
name = "cudarc"
|
| 19 |
-
version = "0.12.1"
|
| 20 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 21 |
-
checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384"
|
| 22 |
-
dependencies = [
|
| 23 |
-
"libloading",
|
| 24 |
-
]
|
| 25 |
-
|
| 26 |
-
[[package]]
|
| 27 |
-
name = "getrandom"
|
| 28 |
-
version = "0.2.17"
|
| 29 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 30 |
-
checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0"
|
| 31 |
-
dependencies = [
|
| 32 |
-
"cfg-if",
|
| 33 |
-
"libc",
|
| 34 |
-
"wasi",
|
| 35 |
-
]
|
| 36 |
-
|
| 37 |
-
[[package]]
|
| 38 |
-
name = "heck"
|
| 39 |
-
version = "0.5.0"
|
| 40 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 41 |
-
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
| 42 |
-
|
| 43 |
-
[[package]]
|
| 44 |
-
name = "htm_rust"
|
| 45 |
-
version = "0.1.0"
|
| 46 |
-
dependencies = [
|
| 47 |
-
"cudarc",
|
| 48 |
-
"ndarray",
|
| 49 |
-
"numpy",
|
| 50 |
-
"pyo3",
|
| 51 |
-
"rand",
|
| 52 |
-
"rand_xoshiro",
|
| 53 |
-
]
|
| 54 |
-
|
| 55 |
-
[[package]]
|
| 56 |
-
name = "indoc"
|
| 57 |
-
version = "2.0.7"
|
| 58 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 59 |
-
checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
|
| 60 |
-
dependencies = [
|
| 61 |
-
"rustversion",
|
| 62 |
-
]
|
| 63 |
-
|
| 64 |
-
[[package]]
|
| 65 |
-
name = "libc"
|
| 66 |
-
version = "0.2.185"
|
| 67 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 68 |
-
checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f"
|
| 69 |
-
|
| 70 |
-
[[package]]
|
| 71 |
-
name = "libloading"
|
| 72 |
-
version = "0.8.9"
|
| 73 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 74 |
-
checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
|
| 75 |
-
dependencies = [
|
| 76 |
-
"cfg-if",
|
| 77 |
-
"windows-link",
|
| 78 |
-
]
|
| 79 |
-
|
| 80 |
-
[[package]]
|
| 81 |
-
name = "matrixmultiply"
|
| 82 |
-
version = "0.3.10"
|
| 83 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 84 |
-
checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
|
| 85 |
-
dependencies = [
|
| 86 |
-
"autocfg",
|
| 87 |
-
"rawpointer",
|
| 88 |
-
]
|
| 89 |
-
|
| 90 |
-
[[package]]
|
| 91 |
-
name = "memoffset"
|
| 92 |
-
version = "0.9.1"
|
| 93 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 94 |
-
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
|
| 95 |
-
dependencies = [
|
| 96 |
-
"autocfg",
|
| 97 |
-
]
|
| 98 |
-
|
| 99 |
-
[[package]]
|
| 100 |
-
name = "ndarray"
|
| 101 |
-
version = "0.16.1"
|
| 102 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 103 |
-
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
|
| 104 |
-
dependencies = [
|
| 105 |
-
"matrixmultiply",
|
| 106 |
-
"num-complex",
|
| 107 |
-
"num-integer",
|
| 108 |
-
"num-traits",
|
| 109 |
-
"portable-atomic",
|
| 110 |
-
"portable-atomic-util",
|
| 111 |
-
"rawpointer",
|
| 112 |
-
]
|
| 113 |
-
|
| 114 |
-
[[package]]
|
| 115 |
-
name = "num-complex"
|
| 116 |
-
version = "0.4.6"
|
| 117 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 118 |
-
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
| 119 |
-
dependencies = [
|
| 120 |
-
"num-traits",
|
| 121 |
-
]
|
| 122 |
-
|
| 123 |
-
[[package]]
|
| 124 |
-
name = "num-integer"
|
| 125 |
-
version = "0.1.46"
|
| 126 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 127 |
-
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
| 128 |
-
dependencies = [
|
| 129 |
-
"num-traits",
|
| 130 |
-
]
|
| 131 |
-
|
| 132 |
-
[[package]]
|
| 133 |
-
name = "num-traits"
|
| 134 |
-
version = "0.2.19"
|
| 135 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 136 |
-
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
| 137 |
-
dependencies = [
|
| 138 |
-
"autocfg",
|
| 139 |
-
]
|
| 140 |
-
|
| 141 |
-
[[package]]
|
| 142 |
-
name = "numpy"
|
| 143 |
-
version = "0.22.1"
|
| 144 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 145 |
-
checksum = "edb929bc0da91a4d85ed6c0a84deaa53d411abfb387fc271124f91bf6b89f14e"
|
| 146 |
-
dependencies = [
|
| 147 |
-
"libc",
|
| 148 |
-
"ndarray",
|
| 149 |
-
"num-complex",
|
| 150 |
-
"num-integer",
|
| 151 |
-
"num-traits",
|
| 152 |
-
"pyo3",
|
| 153 |
-
"rustc-hash",
|
| 154 |
-
]
|
| 155 |
-
|
| 156 |
-
[[package]]
|
| 157 |
-
name = "once_cell"
|
| 158 |
-
version = "1.21.4"
|
| 159 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 160 |
-
checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
|
| 161 |
-
|
| 162 |
-
[[package]]
|
| 163 |
-
name = "portable-atomic"
|
| 164 |
-
version = "1.13.1"
|
| 165 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 166 |
-
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
|
| 167 |
-
|
| 168 |
-
[[package]]
|
| 169 |
-
name = "portable-atomic-util"
|
| 170 |
-
version = "0.2.6"
|
| 171 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 172 |
-
checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3"
|
| 173 |
-
dependencies = [
|
| 174 |
-
"portable-atomic",
|
| 175 |
-
]
|
| 176 |
-
|
| 177 |
-
[[package]]
|
| 178 |
-
name = "ppv-lite86"
|
| 179 |
-
version = "0.2.21"
|
| 180 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 181 |
-
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
|
| 182 |
-
dependencies = [
|
| 183 |
-
"zerocopy",
|
| 184 |
-
]
|
| 185 |
-
|
| 186 |
-
[[package]]
|
| 187 |
-
name = "proc-macro2"
|
| 188 |
-
version = "1.0.106"
|
| 189 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 190 |
-
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
|
| 191 |
-
dependencies = [
|
| 192 |
-
"unicode-ident",
|
| 193 |
-
]
|
| 194 |
-
|
| 195 |
-
[[package]]
|
| 196 |
-
name = "pyo3"
|
| 197 |
-
version = "0.22.6"
|
| 198 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 199 |
-
checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
|
| 200 |
-
dependencies = [
|
| 201 |
-
"cfg-if",
|
| 202 |
-
"indoc",
|
| 203 |
-
"libc",
|
| 204 |
-
"memoffset",
|
| 205 |
-
"once_cell",
|
| 206 |
-
"portable-atomic",
|
| 207 |
-
"pyo3-build-config",
|
| 208 |
-
"pyo3-ffi",
|
| 209 |
-
"pyo3-macros",
|
| 210 |
-
"unindent",
|
| 211 |
-
]
|
| 212 |
-
|
| 213 |
-
[[package]]
|
| 214 |
-
name = "pyo3-build-config"
|
| 215 |
-
version = "0.22.6"
|
| 216 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 217 |
-
checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
|
| 218 |
-
dependencies = [
|
| 219 |
-
"once_cell",
|
| 220 |
-
"target-lexicon",
|
| 221 |
-
]
|
| 222 |
-
|
| 223 |
-
[[package]]
|
| 224 |
-
name = "pyo3-ffi"
|
| 225 |
-
version = "0.22.6"
|
| 226 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 227 |
-
checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
|
| 228 |
-
dependencies = [
|
| 229 |
-
"libc",
|
| 230 |
-
"pyo3-build-config",
|
| 231 |
-
]
|
| 232 |
-
|
| 233 |
-
[[package]]
|
| 234 |
-
name = "pyo3-macros"
|
| 235 |
-
version = "0.22.6"
|
| 236 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 237 |
-
checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
|
| 238 |
-
dependencies = [
|
| 239 |
-
"proc-macro2",
|
| 240 |
-
"pyo3-macros-backend",
|
| 241 |
-
"quote",
|
| 242 |
-
"syn",
|
| 243 |
-
]
|
| 244 |
-
|
| 245 |
-
[[package]]
|
| 246 |
-
name = "pyo3-macros-backend"
|
| 247 |
-
version = "0.22.6"
|
| 248 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 249 |
-
checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
|
| 250 |
-
dependencies = [
|
| 251 |
-
"heck",
|
| 252 |
-
"proc-macro2",
|
| 253 |
-
"pyo3-build-config",
|
| 254 |
-
"quote",
|
| 255 |
-
"syn",
|
| 256 |
-
]
|
| 257 |
-
|
| 258 |
-
[[package]]
|
| 259 |
-
name = "quote"
|
| 260 |
-
version = "1.0.45"
|
| 261 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 262 |
-
checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
|
| 263 |
-
dependencies = [
|
| 264 |
-
"proc-macro2",
|
| 265 |
-
]
|
| 266 |
-
|
| 267 |
-
[[package]]
|
| 268 |
-
name = "rand"
|
| 269 |
-
version = "0.8.5"
|
| 270 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 271 |
-
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
| 272 |
-
dependencies = [
|
| 273 |
-
"libc",
|
| 274 |
-
"rand_chacha",
|
| 275 |
-
"rand_core",
|
| 276 |
-
]
|
| 277 |
-
|
| 278 |
-
[[package]]
|
| 279 |
-
name = "rand_chacha"
|
| 280 |
-
version = "0.3.1"
|
| 281 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 282 |
-
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
| 283 |
-
dependencies = [
|
| 284 |
-
"ppv-lite86",
|
| 285 |
-
"rand_core",
|
| 286 |
-
]
|
| 287 |
-
|
| 288 |
-
[[package]]
|
| 289 |
-
name = "rand_core"
|
| 290 |
-
version = "0.6.4"
|
| 291 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 292 |
-
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
| 293 |
-
dependencies = [
|
| 294 |
-
"getrandom",
|
| 295 |
-
]
|
| 296 |
-
|
| 297 |
-
[[package]]
|
| 298 |
-
name = "rand_xoshiro"
|
| 299 |
-
version = "0.6.0"
|
| 300 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 301 |
-
checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
|
| 302 |
-
dependencies = [
|
| 303 |
-
"rand_core",
|
| 304 |
-
]
|
| 305 |
-
|
| 306 |
-
[[package]]
|
| 307 |
-
name = "rawpointer"
|
| 308 |
-
version = "0.2.1"
|
| 309 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 310 |
-
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
| 311 |
-
|
| 312 |
-
[[package]]
|
| 313 |
-
name = "rustc-hash"
|
| 314 |
-
version = "1.1.0"
|
| 315 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 316 |
-
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
| 317 |
-
|
| 318 |
-
[[package]]
|
| 319 |
-
name = "rustversion"
|
| 320 |
-
version = "1.0.22"
|
| 321 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 322 |
-
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
|
| 323 |
-
|
| 324 |
-
[[package]]
|
| 325 |
-
name = "syn"
|
| 326 |
-
version = "2.0.117"
|
| 327 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 328 |
-
checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
|
| 329 |
-
dependencies = [
|
| 330 |
-
"proc-macro2",
|
| 331 |
-
"quote",
|
| 332 |
-
"unicode-ident",
|
| 333 |
-
]
|
| 334 |
-
|
| 335 |
-
[[package]]
|
| 336 |
-
name = "target-lexicon"
|
| 337 |
-
version = "0.12.16"
|
| 338 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 339 |
-
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
| 340 |
-
|
| 341 |
-
[[package]]
|
| 342 |
-
name = "unicode-ident"
|
| 343 |
-
version = "1.0.24"
|
| 344 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 345 |
-
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
|
| 346 |
-
|
| 347 |
-
[[package]]
|
| 348 |
-
name = "unindent"
|
| 349 |
-
version = "0.2.4"
|
| 350 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 351 |
-
checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
|
| 352 |
-
|
| 353 |
-
[[package]]
|
| 354 |
-
name = "wasi"
|
| 355 |
-
version = "0.11.1+wasi-snapshot-preview1"
|
| 356 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 357 |
-
checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
|
| 358 |
-
|
| 359 |
-
[[package]]
|
| 360 |
-
name = "windows-link"
|
| 361 |
-
version = "0.2.1"
|
| 362 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 363 |
-
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
| 364 |
-
|
| 365 |
-
[[package]]
|
| 366 |
-
name = "zerocopy"
|
| 367 |
-
version = "0.8.48"
|
| 368 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 369 |
-
checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9"
|
| 370 |
-
dependencies = [
|
| 371 |
-
"zerocopy-derive",
|
| 372 |
-
]
|
| 373 |
-
|
| 374 |
-
[[package]]
|
| 375 |
-
name = "zerocopy-derive"
|
| 376 |
-
version = "0.8.48"
|
| 377 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 378 |
-
checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4"
|
| 379 |
-
dependencies = [
|
| 380 |
-
"proc-macro2",
|
| 381 |
-
"quote",
|
| 382 |
-
"syn",
|
| 383 |
-
]
|
|
|
|
| 1 |
+
# This file is automatically @generated by Cargo.
|
| 2 |
+
# It is not intended for manual editing.
|
| 3 |
+
version = 4
|
| 4 |
+
|
| 5 |
+
[[package]]
|
| 6 |
+
name = "autocfg"
|
| 7 |
+
version = "1.5.0"
|
| 8 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 9 |
+
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
| 10 |
+
|
| 11 |
+
[[package]]
|
| 12 |
+
name = "cfg-if"
|
| 13 |
+
version = "1.0.4"
|
| 14 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 15 |
+
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
| 16 |
+
|
| 17 |
+
[[package]]
|
| 18 |
+
name = "cudarc"
|
| 19 |
+
version = "0.12.1"
|
| 20 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 21 |
+
checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384"
|
| 22 |
+
dependencies = [
|
| 23 |
+
"libloading",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
[[package]]
|
| 27 |
+
name = "getrandom"
|
| 28 |
+
version = "0.2.17"
|
| 29 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 30 |
+
checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0"
|
| 31 |
+
dependencies = [
|
| 32 |
+
"cfg-if",
|
| 33 |
+
"libc",
|
| 34 |
+
"wasi",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[[package]]
|
| 38 |
+
name = "heck"
|
| 39 |
+
version = "0.5.0"
|
| 40 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 41 |
+
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
| 42 |
+
|
| 43 |
+
[[package]]
|
| 44 |
+
name = "htm_rust"
|
| 45 |
+
version = "0.1.0"
|
| 46 |
+
dependencies = [
|
| 47 |
+
"cudarc",
|
| 48 |
+
"ndarray",
|
| 49 |
+
"numpy",
|
| 50 |
+
"pyo3",
|
| 51 |
+
"rand",
|
| 52 |
+
"rand_xoshiro",
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
[[package]]
|
| 56 |
+
name = "indoc"
|
| 57 |
+
version = "2.0.7"
|
| 58 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 59 |
+
checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
|
| 60 |
+
dependencies = [
|
| 61 |
+
"rustversion",
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
+
[[package]]
|
| 65 |
+
name = "libc"
|
| 66 |
+
version = "0.2.185"
|
| 67 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 68 |
+
checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f"
|
| 69 |
+
|
| 70 |
+
[[package]]
|
| 71 |
+
name = "libloading"
|
| 72 |
+
version = "0.8.9"
|
| 73 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 74 |
+
checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
|
| 75 |
+
dependencies = [
|
| 76 |
+
"cfg-if",
|
| 77 |
+
"windows-link",
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
[[package]]
|
| 81 |
+
name = "matrixmultiply"
|
| 82 |
+
version = "0.3.10"
|
| 83 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 84 |
+
checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
|
| 85 |
+
dependencies = [
|
| 86 |
+
"autocfg",
|
| 87 |
+
"rawpointer",
|
| 88 |
+
]
|
| 89 |
+
|
| 90 |
+
[[package]]
|
| 91 |
+
name = "memoffset"
|
| 92 |
+
version = "0.9.1"
|
| 93 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 94 |
+
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
|
| 95 |
+
dependencies = [
|
| 96 |
+
"autocfg",
|
| 97 |
+
]
|
| 98 |
+
|
| 99 |
+
[[package]]
|
| 100 |
+
name = "ndarray"
|
| 101 |
+
version = "0.16.1"
|
| 102 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 103 |
+
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
|
| 104 |
+
dependencies = [
|
| 105 |
+
"matrixmultiply",
|
| 106 |
+
"num-complex",
|
| 107 |
+
"num-integer",
|
| 108 |
+
"num-traits",
|
| 109 |
+
"portable-atomic",
|
| 110 |
+
"portable-atomic-util",
|
| 111 |
+
"rawpointer",
|
| 112 |
+
]
|
| 113 |
+
|
| 114 |
+
[[package]]
|
| 115 |
+
name = "num-complex"
|
| 116 |
+
version = "0.4.6"
|
| 117 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 118 |
+
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
| 119 |
+
dependencies = [
|
| 120 |
+
"num-traits",
|
| 121 |
+
]
|
| 122 |
+
|
| 123 |
+
[[package]]
|
| 124 |
+
name = "num-integer"
|
| 125 |
+
version = "0.1.46"
|
| 126 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 127 |
+
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
| 128 |
+
dependencies = [
|
| 129 |
+
"num-traits",
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
[[package]]
|
| 133 |
+
name = "num-traits"
|
| 134 |
+
version = "0.2.19"
|
| 135 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 136 |
+
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
| 137 |
+
dependencies = [
|
| 138 |
+
"autocfg",
|
| 139 |
+
]
|
| 140 |
+
|
| 141 |
+
[[package]]
|
| 142 |
+
name = "numpy"
|
| 143 |
+
version = "0.22.1"
|
| 144 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 145 |
+
checksum = "edb929bc0da91a4d85ed6c0a84deaa53d411abfb387fc271124f91bf6b89f14e"
|
| 146 |
+
dependencies = [
|
| 147 |
+
"libc",
|
| 148 |
+
"ndarray",
|
| 149 |
+
"num-complex",
|
| 150 |
+
"num-integer",
|
| 151 |
+
"num-traits",
|
| 152 |
+
"pyo3",
|
| 153 |
+
"rustc-hash",
|
| 154 |
+
]
|
| 155 |
+
|
| 156 |
+
[[package]]
|
| 157 |
+
name = "once_cell"
|
| 158 |
+
version = "1.21.4"
|
| 159 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 160 |
+
checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
|
| 161 |
+
|
| 162 |
+
[[package]]
|
| 163 |
+
name = "portable-atomic"
|
| 164 |
+
version = "1.13.1"
|
| 165 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 166 |
+
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
|
| 167 |
+
|
| 168 |
+
[[package]]
|
| 169 |
+
name = "portable-atomic-util"
|
| 170 |
+
version = "0.2.6"
|
| 171 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 172 |
+
checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3"
|
| 173 |
+
dependencies = [
|
| 174 |
+
"portable-atomic",
|
| 175 |
+
]
|
| 176 |
+
|
| 177 |
+
[[package]]
|
| 178 |
+
name = "ppv-lite86"
|
| 179 |
+
version = "0.2.21"
|
| 180 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 181 |
+
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
|
| 182 |
+
dependencies = [
|
| 183 |
+
"zerocopy",
|
| 184 |
+
]
|
| 185 |
+
|
| 186 |
+
[[package]]
|
| 187 |
+
name = "proc-macro2"
|
| 188 |
+
version = "1.0.106"
|
| 189 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 190 |
+
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
|
| 191 |
+
dependencies = [
|
| 192 |
+
"unicode-ident",
|
| 193 |
+
]
|
| 194 |
+
|
| 195 |
+
[[package]]
|
| 196 |
+
name = "pyo3"
|
| 197 |
+
version = "0.22.6"
|
| 198 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 199 |
+
checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
|
| 200 |
+
dependencies = [
|
| 201 |
+
"cfg-if",
|
| 202 |
+
"indoc",
|
| 203 |
+
"libc",
|
| 204 |
+
"memoffset",
|
| 205 |
+
"once_cell",
|
| 206 |
+
"portable-atomic",
|
| 207 |
+
"pyo3-build-config",
|
| 208 |
+
"pyo3-ffi",
|
| 209 |
+
"pyo3-macros",
|
| 210 |
+
"unindent",
|
| 211 |
+
]
|
| 212 |
+
|
| 213 |
+
[[package]]
|
| 214 |
+
name = "pyo3-build-config"
|
| 215 |
+
version = "0.22.6"
|
| 216 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 217 |
+
checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
|
| 218 |
+
dependencies = [
|
| 219 |
+
"once_cell",
|
| 220 |
+
"target-lexicon",
|
| 221 |
+
]
|
| 222 |
+
|
| 223 |
+
[[package]]
|
| 224 |
+
name = "pyo3-ffi"
|
| 225 |
+
version = "0.22.6"
|
| 226 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 227 |
+
checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
|
| 228 |
+
dependencies = [
|
| 229 |
+
"libc",
|
| 230 |
+
"pyo3-build-config",
|
| 231 |
+
]
|
| 232 |
+
|
| 233 |
+
[[package]]
|
| 234 |
+
name = "pyo3-macros"
|
| 235 |
+
version = "0.22.6"
|
| 236 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 237 |
+
checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
|
| 238 |
+
dependencies = [
|
| 239 |
+
"proc-macro2",
|
| 240 |
+
"pyo3-macros-backend",
|
| 241 |
+
"quote",
|
| 242 |
+
"syn",
|
| 243 |
+
]
|
| 244 |
+
|
| 245 |
+
[[package]]
|
| 246 |
+
name = "pyo3-macros-backend"
|
| 247 |
+
version = "0.22.6"
|
| 248 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 249 |
+
checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
|
| 250 |
+
dependencies = [
|
| 251 |
+
"heck",
|
| 252 |
+
"proc-macro2",
|
| 253 |
+
"pyo3-build-config",
|
| 254 |
+
"quote",
|
| 255 |
+
"syn",
|
| 256 |
+
]
|
| 257 |
+
|
| 258 |
+
[[package]]
|
| 259 |
+
name = "quote"
|
| 260 |
+
version = "1.0.45"
|
| 261 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 262 |
+
checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
|
| 263 |
+
dependencies = [
|
| 264 |
+
"proc-macro2",
|
| 265 |
+
]
|
| 266 |
+
|
| 267 |
+
[[package]]
|
| 268 |
+
name = "rand"
|
| 269 |
+
version = "0.8.5"
|
| 270 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 271 |
+
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
| 272 |
+
dependencies = [
|
| 273 |
+
"libc",
|
| 274 |
+
"rand_chacha",
|
| 275 |
+
"rand_core",
|
| 276 |
+
]
|
| 277 |
+
|
| 278 |
+
[[package]]
|
| 279 |
+
name = "rand_chacha"
|
| 280 |
+
version = "0.3.1"
|
| 281 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 282 |
+
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
| 283 |
+
dependencies = [
|
| 284 |
+
"ppv-lite86",
|
| 285 |
+
"rand_core",
|
| 286 |
+
]
|
| 287 |
+
|
| 288 |
+
[[package]]
|
| 289 |
+
name = "rand_core"
|
| 290 |
+
version = "0.6.4"
|
| 291 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 292 |
+
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
| 293 |
+
dependencies = [
|
| 294 |
+
"getrandom",
|
| 295 |
+
]
|
| 296 |
+
|
| 297 |
+
[[package]]
|
| 298 |
+
name = "rand_xoshiro"
|
| 299 |
+
version = "0.6.0"
|
| 300 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 301 |
+
checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
|
| 302 |
+
dependencies = [
|
| 303 |
+
"rand_core",
|
| 304 |
+
]
|
| 305 |
+
|
| 306 |
+
[[package]]
|
| 307 |
+
name = "rawpointer"
|
| 308 |
+
version = "0.2.1"
|
| 309 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 310 |
+
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
| 311 |
+
|
| 312 |
+
[[package]]
|
| 313 |
+
name = "rustc-hash"
|
| 314 |
+
version = "1.1.0"
|
| 315 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 316 |
+
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
| 317 |
+
|
| 318 |
+
[[package]]
|
| 319 |
+
name = "rustversion"
|
| 320 |
+
version = "1.0.22"
|
| 321 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 322 |
+
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
|
| 323 |
+
|
| 324 |
+
[[package]]
|
| 325 |
+
name = "syn"
|
| 326 |
+
version = "2.0.117"
|
| 327 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 328 |
+
checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
|
| 329 |
+
dependencies = [
|
| 330 |
+
"proc-macro2",
|
| 331 |
+
"quote",
|
| 332 |
+
"unicode-ident",
|
| 333 |
+
]
|
| 334 |
+
|
| 335 |
+
[[package]]
|
| 336 |
+
name = "target-lexicon"
|
| 337 |
+
version = "0.12.16"
|
| 338 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 339 |
+
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
| 340 |
+
|
| 341 |
+
[[package]]
|
| 342 |
+
name = "unicode-ident"
|
| 343 |
+
version = "1.0.24"
|
| 344 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 345 |
+
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
|
| 346 |
+
|
| 347 |
+
[[package]]
|
| 348 |
+
name = "unindent"
|
| 349 |
+
version = "0.2.4"
|
| 350 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 351 |
+
checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
|
| 352 |
+
|
| 353 |
+
[[package]]
|
| 354 |
+
name = "wasi"
|
| 355 |
+
version = "0.11.1+wasi-snapshot-preview1"
|
| 356 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 357 |
+
checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
|
| 358 |
+
|
| 359 |
+
[[package]]
|
| 360 |
+
name = "windows-link"
|
| 361 |
+
version = "0.2.1"
|
| 362 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 363 |
+
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
| 364 |
+
|
| 365 |
+
[[package]]
|
| 366 |
+
name = "zerocopy"
|
| 367 |
+
version = "0.8.48"
|
| 368 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 369 |
+
checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9"
|
| 370 |
+
dependencies = [
|
| 371 |
+
"zerocopy-derive",
|
| 372 |
+
]
|
| 373 |
+
|
| 374 |
+
[[package]]
|
| 375 |
+
name = "zerocopy-derive"
|
| 376 |
+
version = "0.8.48"
|
| 377 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 378 |
+
checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4"
|
| 379 |
+
dependencies = [
|
| 380 |
+
"proc-macro2",
|
| 381 |
+
"quote",
|
| 382 |
+
"syn",
|
| 383 |
+
]
|
overlay/htm_rust/Cargo.toml
CHANGED
|
@@ -1,37 +1,37 @@
|
|
| 1 |
-
[package]
|
| 2 |
-
name = "htm_rust"
|
| 3 |
-
version = "0.1.0"
|
| 4 |
-
edition = "2021"
|
| 5 |
-
authors = ["Feather/HYDRA"]
|
| 6 |
-
description = "Numenta BAMI-spec Hierarchical Temporal Memory (Spatial Pooler + Temporal Memory) with pyo3 bindings"
|
| 7 |
-
license = "MIT"
|
| 8 |
-
|
| 9 |
-
[lib]
|
| 10 |
-
name = "htm_rust"
|
| 11 |
-
crate-type = ["cdylib", "rlib"]
|
| 12 |
-
|
| 13 |
-
[dependencies]
|
| 14 |
-
pyo3 = { version = "0.22", features = ["extension-module"] }
|
| 15 |
-
numpy = "0.22"
|
| 16 |
-
ndarray = "0.16"
|
| 17 |
-
rand = "0.8"
|
| 18 |
-
rand_xoshiro = "0.6"
|
| 19 |
-
# cudarc: CUDA Rust bindings with dynamic-loading (no link-time dep on libcuda).
|
| 20 |
-
# Kernels are embedded as PTX and JIT-compiled at runtime.
|
| 21 |
-
cudarc = { version = "0.12", default-features = false, features = ["dynamic-linking", "driver", "cuda-12010"], optional = true }
|
| 22 |
-
|
| 23 |
-
[build-dependencies]
|
| 24 |
-
# Only required when building with --features gpu. We shell to nvcc directly
|
| 25 |
-
# so we don't need cc's cuda support (which drags in extra deps).
|
| 26 |
-
|
| 27 |
-
[features]
|
| 28 |
-
default = []
|
| 29 |
-
# `gpu` adds the HTMRegionGPU class, compiles .cu kernels to PTX at build time,
|
| 30 |
-
# and links cudarc. Without this feature the crate is pure-CPU and has no
|
| 31 |
-
# CUDA dependency at build or run time.
|
| 32 |
-
gpu = ["cudarc"]
|
| 33 |
-
|
| 34 |
-
[profile.release]
|
| 35 |
-
opt-level = 3
|
| 36 |
-
lto = "thin"
|
| 37 |
-
codegen-units = 1
|
|
|
|
| 1 |
+
[package]
|
| 2 |
+
name = "htm_rust"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
edition = "2021"
|
| 5 |
+
authors = ["Feather/HYDRA"]
|
| 6 |
+
description = "Numenta BAMI-spec Hierarchical Temporal Memory (Spatial Pooler + Temporal Memory) with pyo3 bindings"
|
| 7 |
+
license = "MIT"
|
| 8 |
+
|
| 9 |
+
[lib]
|
| 10 |
+
name = "htm_rust"
|
| 11 |
+
crate-type = ["cdylib", "rlib"]
|
| 12 |
+
|
| 13 |
+
[dependencies]
|
| 14 |
+
pyo3 = { version = "0.22", features = ["extension-module"] }
|
| 15 |
+
numpy = "0.22"
|
| 16 |
+
ndarray = "0.16"
|
| 17 |
+
rand = "0.8"
|
| 18 |
+
rand_xoshiro = "0.6"
|
| 19 |
+
# cudarc: CUDA Rust bindings with dynamic-loading (no link-time dep on libcuda).
|
| 20 |
+
# Kernels are embedded as PTX and JIT-compiled at runtime.
|
| 21 |
+
cudarc = { version = "0.12", default-features = false, features = ["dynamic-linking", "driver", "cuda-12010"], optional = true }
|
| 22 |
+
|
| 23 |
+
[build-dependencies]
|
| 24 |
+
# Only required when building with --features gpu. We shell to nvcc directly
|
| 25 |
+
# so we don't need cc's cuda support (which drags in extra deps).
|
| 26 |
+
|
| 27 |
+
[features]
|
| 28 |
+
default = []
|
| 29 |
+
# `gpu` adds the HTMRegionGPU class, compiles .cu kernels to PTX at build time,
|
| 30 |
+
# and links cudarc. Without this feature the crate is pure-CPU and has no
|
| 31 |
+
# CUDA dependency at build or run time.
|
| 32 |
+
gpu = ["cudarc"]
|
| 33 |
+
|
| 34 |
+
[profile.release]
|
| 35 |
+
opt-level = 3
|
| 36 |
+
lto = "thin"
|
| 37 |
+
codegen-units = 1
|
overlay/htm_rust/build.rs
CHANGED
|
@@ -1,160 +1,168 @@
|
|
| 1 |
-
//! Build script: compiles `.cu` kernel files to PTX when the `gpu` feature
|
| 2 |
-
//! is enabled. PTX files are embedded into the final Rust binary via
|
| 3 |
-
//! `include_str!` / `OUT_DIR` constants and JIT-loaded at runtime by cudarc.
|
| 4 |
-
//!
|
| 5 |
-
//! No-op when `gpu` feature is off — CPU-only builds have zero CUDA
|
| 6 |
-
//! toolchain dependency.
|
| 7 |
-
//!
|
| 8 |
-
//! nvcc lookup order:
|
| 9 |
-
//! 1. $NVCC env var
|
| 10 |
-
//! 2. `nvcc` on PATH
|
| 11 |
-
//! 3. `/usr/local/cuda-12.1/bin/nvcc`
|
| 12 |
-
//! 4. `/usr/local/cuda/bin/nvcc`
|
| 13 |
-
//!
|
| 14 |
-
//!
|
| 15 |
-
|
| 16 |
-
use std::env;
|
| 17 |
-
use std::path::PathBuf;
|
| 18 |
-
use std::process::Command;
|
| 19 |
-
|
| 20 |
-
fn main() {
|
| 21 |
-
// Re-run whenever we edit the build script or any kernel source.
|
| 22 |
-
println!("cargo:rerun-if-changed=build.rs");
|
| 23 |
-
|
| 24 |
-
let gpu = env::var_os("CARGO_FEATURE_GPU").is_some();
|
| 25 |
-
if !gpu {
|
| 26 |
-
return;
|
| 27 |
-
}
|
| 28 |
-
|
| 29 |
-
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
|
| 30 |
-
let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "
|
| 31 |
-
|
| 32 |
-
// Base kernels — compile for any sm_80+ GPU. Each .cu file → one .ptx file.
|
| 33 |
-
let base_kernels: &[&str] = &[
|
| 34 |
-
"sp_overlap",
|
| 35 |
-
"sp_topk",
|
| 36 |
-
"sp_learn",
|
| 37 |
-
"sp_duty",
|
| 38 |
-
"sp_boost_fused",
|
| 39 |
-
"tm_predict",
|
| 40 |
-
"tm_activate",
|
| 41 |
-
"tm_learn",
|
| 42 |
-
"tm_punish",
|
| 43 |
-
"tm_grow",
|
| 44 |
-
"tm_anomaly",
|
| 45 |
-
"tm_reset",
|
| 46 |
-
];
|
| 47 |
-
|
| 48 |
-
// htm_fused_step now compiles for ALL architectures (sm_80+).
|
| 49 |
-
// On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
|
| 50 |
-
// On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
|
| 51 |
-
// with grid.sync() for cross-block synchronization (cooperative launch).
|
| 52 |
-
let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect();
|
| 53 |
-
|
| 54 |
-
let kernels_dir = PathBuf::from("src/gpu/kernels");
|
| 55 |
-
for k in &kernels {
|
| 56 |
-
let src = kernels_dir.join(format!("{k}.cu"));
|
| 57 |
-
println!("cargo:rerun-if-changed={}", src.display());
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
let nvcc = find_nvcc();
|
| 62 |
-
println!("cargo:warning=htm_rust: nvcc = {nvcc}");
|
| 63 |
-
println!("cargo:warning=htm_rust: target arch = {arch}");
|
| 64 |
-
|
| 65 |
-
// Prefer gcc-12 if present (CUDA 12.1 doesn't support gcc-13+ headers).
|
| 66 |
-
let host_compiler = env::var("HTM_CUDA_CCBIN")
|
| 67 |
-
.ok()
|
| 68 |
-
.or_else(|| {
|
| 69 |
-
for cand in ["/usr/bin/gcc-12", "/usr/bin/gcc-11"] {
|
| 70 |
-
if std::path::Path::new(cand).exists() {
|
| 71 |
-
return Some(cand.to_string());
|
| 72 |
-
}
|
| 73 |
-
}
|
| 74 |
-
None
|
| 75 |
-
});
|
| 76 |
-
|
| 77 |
-
// Optionally patch the emitted PTX `.version` header down to match an
|
| 78 |
-
// older driver. Useful when the system driver (e.g. on WSL2) is older
|
| 79 |
-
// than the nvcc toolchain. Set HTM_PTX_VERSION to e.g. "7.8" or "8.0".
|
| 80 |
-
let ptx_version_override = env::var("HTM_PTX_VERSION").ok();
|
| 81 |
-
|
| 82 |
-
for k in kernels {
|
| 83 |
-
let src = kernels_dir.join(format!("{k}.cu"));
|
| 84 |
-
let ptx = out_dir.join(format!("{k}.ptx"));
|
| 85 |
-
if !src.exists() {
|
| 86 |
-
panic!("missing kernel source: {}", src.display());
|
| 87 |
-
}
|
| 88 |
-
let mut cmd = Command::new(&nvcc);
|
| 89 |
-
// Note: `--use_fast_math` breaks bit-parity with host `expf`, which
|
| 90 |
-
// in turn flips boost tie-breaks in SP learning. We accept the tiny
|
| 91 |
-
// perf loss for correctness; the hot overlap kernel has no transcendentals.
|
| 92 |
-
cmd.args([
|
| 93 |
-
"--ptx",
|
| 94 |
-
"-O3",
|
| 95 |
-
"-rdc=true",
|
| 96 |
-
"-arch",
|
| 97 |
-
&arch,
|
| 98 |
-
]);
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
.
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
.
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
"nvcc
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! Build script: compiles `.cu` kernel files to PTX when the `gpu` feature
|
| 2 |
+
//! is enabled. PTX files are embedded into the final Rust binary via
|
| 3 |
+
//! `include_str!` / `OUT_DIR` constants and JIT-loaded at runtime by cudarc.
|
| 4 |
+
//!
|
| 5 |
+
//! No-op when `gpu` feature is off — CPU-only builds have zero CUDA
|
| 6 |
+
//! toolchain dependency.
|
| 7 |
+
//!
|
| 8 |
+
//! nvcc lookup order:
|
| 9 |
+
//! 1. $NVCC env var
|
| 10 |
+
//! 2. `nvcc` on PATH
|
| 11 |
+
//! 3. `/usr/local/cuda-12.1/bin/nvcc`
|
| 12 |
+
//! 4. `/usr/local/cuda/bin/nvcc`
|
| 13 |
+
//!
|
| 14 |
+
//! Default target: sm_86 (Ampere A10G / RTX 30xx). Override with $HTM_CUDA_ARCH (e.g. sm_90a for H200).
|
| 15 |
+
|
| 16 |
+
use std::env;
|
| 17 |
+
use std::path::PathBuf;
|
| 18 |
+
use std::process::Command;
|
| 19 |
+
|
| 20 |
+
fn main() {
|
| 21 |
+
// Re-run whenever we edit the build script or any kernel source.
|
| 22 |
+
println!("cargo:rerun-if-changed=build.rs");
|
| 23 |
+
|
| 24 |
+
let gpu = env::var_os("CARGO_FEATURE_GPU").is_some();
|
| 25 |
+
if !gpu {
|
| 26 |
+
return;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
|
| 30 |
+
let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into());
|
| 31 |
+
|
| 32 |
+
// Base kernels — compile for any sm_80+ GPU. Each .cu file → one .ptx file.
|
| 33 |
+
let base_kernels: &[&str] = &[
|
| 34 |
+
"sp_overlap",
|
| 35 |
+
"sp_topk",
|
| 36 |
+
"sp_learn",
|
| 37 |
+
"sp_duty",
|
| 38 |
+
"sp_boost_fused",
|
| 39 |
+
"tm_predict",
|
| 40 |
+
"tm_activate",
|
| 41 |
+
"tm_learn",
|
| 42 |
+
"tm_punish",
|
| 43 |
+
"tm_grow",
|
| 44 |
+
"tm_anomaly",
|
| 45 |
+
"tm_reset",
|
| 46 |
+
];
|
| 47 |
+
|
| 48 |
+
// htm_fused_step now compiles for ALL architectures (sm_80+).
|
| 49 |
+
// On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
|
| 50 |
+
// On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
|
| 51 |
+
// with grid.sync() for cross-block synchronization (cooperative launch).
|
| 52 |
+
let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect();
|
| 53 |
+
|
| 54 |
+
let kernels_dir = PathBuf::from("src/gpu/kernels");
|
| 55 |
+
for k in &kernels {
|
| 56 |
+
let src = kernels_dir.join(format!("{k}.cu"));
|
| 57 |
+
println!("cargo:rerun-if-changed={}", src.display());
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
let nvcc = find_nvcc();
|
| 62 |
+
println!("cargo:warning=htm_rust: nvcc = {nvcc}");
|
| 63 |
+
println!("cargo:warning=htm_rust: target arch = {arch}");
|
| 64 |
+
|
| 65 |
+
// Prefer gcc-12 if present (CUDA 12.1 doesn't support gcc-13+ headers).
|
| 66 |
+
let host_compiler = env::var("HTM_CUDA_CCBIN")
|
| 67 |
+
.ok()
|
| 68 |
+
.or_else(|| {
|
| 69 |
+
for cand in ["/usr/bin/gcc-12", "/usr/bin/gcc-11"] {
|
| 70 |
+
if std::path::Path::new(cand).exists() {
|
| 71 |
+
return Some(cand.to_string());
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
None
|
| 75 |
+
});
|
| 76 |
+
|
| 77 |
+
// Optionally patch the emitted PTX `.version` header down to match an
|
| 78 |
+
// older driver. Useful when the system driver (e.g. on WSL2) is older
|
| 79 |
+
// than the nvcc toolchain. Set HTM_PTX_VERSION to e.g. "7.8" or "8.0".
|
| 80 |
+
let ptx_version_override = env::var("HTM_PTX_VERSION").ok();
|
| 81 |
+
|
| 82 |
+
for k in kernels {
|
| 83 |
+
let src = kernels_dir.join(format!("{k}.cu"));
|
| 84 |
+
let ptx = out_dir.join(format!("{k}.ptx"));
|
| 85 |
+
if !src.exists() {
|
| 86 |
+
panic!("missing kernel source: {}", src.display());
|
| 87 |
+
}
|
| 88 |
+
let mut cmd = Command::new(&nvcc);
|
| 89 |
+
// Note: `--use_fast_math` breaks bit-parity with host `expf`, which
|
| 90 |
+
// in turn flips boost tie-breaks in SP learning. We accept the tiny
|
| 91 |
+
// perf loss for correctness; the hot overlap kernel has no transcendentals.
|
| 92 |
+
cmd.args([
|
| 93 |
+
"--ptx",
|
| 94 |
+
"-O3",
|
| 95 |
+
"-rdc=true",
|
| 96 |
+
"-arch",
|
| 97 |
+
&arch,
|
| 98 |
+
]);
|
| 99 |
+
// `cooperative_groups::this_cluster()` is not declared for Ampere
|
| 100 |
+
// device compiles in CUDA 12.x, even if guarded by __CUDA_ARCH__ in
|
| 101 |
+
// some nvcc front-end phases. Define an explicit build-time kill
|
| 102 |
+
// switch for all non-Hopper targets so sm_86/A10G only sees the
|
| 103 |
+
// cooperative-grid path.
|
| 104 |
+
if !arch.starts_with("sm_90") {
|
| 105 |
+
cmd.arg("-DHTM_DISABLE_CLUSTER=1");
|
| 106 |
+
}
|
| 107 |
+
if let Some(cc) = &host_compiler {
|
| 108 |
+
cmd.args(["-ccbin", cc]);
|
| 109 |
+
}
|
| 110 |
+
cmd.arg("-o").arg(&ptx).arg(&src);
|
| 111 |
+
let status = cmd
|
| 112 |
+
.status()
|
| 113 |
+
.unwrap_or_else(|e| panic!("failed to spawn nvcc: {e}"));
|
| 114 |
+
if !status.success() {
|
| 115 |
+
panic!("nvcc failed for {}", src.display());
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
if let Some(ver) = &ptx_version_override {
|
| 119 |
+
// Read, patch, write.
|
| 120 |
+
let text = std::fs::read_to_string(&ptx)
|
| 121 |
+
.unwrap_or_else(|e| panic!("read {} failed: {e}", ptx.display()));
|
| 122 |
+
// Match `.version X.Y` where X and Y are digits. Replace whole line.
|
| 123 |
+
let patched: String = text
|
| 124 |
+
.lines()
|
| 125 |
+
.map(|line| {
|
| 126 |
+
let t = line.trim_start();
|
| 127 |
+
if t.starts_with(".version ") {
|
| 128 |
+
format!(".version {ver}")
|
| 129 |
+
} else {
|
| 130 |
+
line.to_string()
|
| 131 |
+
}
|
| 132 |
+
})
|
| 133 |
+
.collect::<Vec<_>>()
|
| 134 |
+
.join("\n");
|
| 135 |
+
std::fs::write(&ptx, patched)
|
| 136 |
+
.unwrap_or_else(|e| panic!("write {} failed: {e}", ptx.display()));
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// Export OUT_DIR for include_str! in Rust.
|
| 141 |
+
println!(
|
| 142 |
+
"cargo:rustc-env=HTM_GPU_PTX_DIR={}",
|
| 143 |
+
out_dir.display()
|
| 144 |
+
);
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
fn find_nvcc() -> String {
|
| 148 |
+
if let Ok(n) = env::var("NVCC") {
|
| 149 |
+
return n;
|
| 150 |
+
}
|
| 151 |
+
// Try PATH.
|
| 152 |
+
if Command::new("nvcc").arg("--version").output().is_ok() {
|
| 153 |
+
return "nvcc".into();
|
| 154 |
+
}
|
| 155 |
+
for cand in [
|
| 156 |
+
"/usr/local/cuda-12.1/bin/nvcc",
|
| 157 |
+
"/usr/local/cuda/bin/nvcc",
|
| 158 |
+
"/usr/local/cuda-12/bin/nvcc",
|
| 159 |
+
] {
|
| 160 |
+
if std::path::Path::new(cand).exists() {
|
| 161 |
+
return cand.into();
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
panic!(
|
| 165 |
+
"nvcc not found. Set $NVCC or install CUDA toolkit. \
|
| 166 |
+
Tried PATH, /usr/local/cuda-12.1, /usr/local/cuda."
|
| 167 |
+
);
|
| 168 |
+
}
|
overlay/htm_rust/pyproject.toml
CHANGED
|
@@ -1,17 +1,17 @@
|
|
| 1 |
-
[build-system]
|
| 2 |
-
requires = ["maturin>=1.4,<2.0"]
|
| 3 |
-
build-backend = "maturin"
|
| 4 |
-
|
| 5 |
-
[project]
|
| 6 |
-
name = "htm_rust"
|
| 7 |
-
version = "0.1.0"
|
| 8 |
-
description = "Numenta BAMI-spec HTM (Spatial Pooler + Temporal Memory) in Rust with pyo3 bindings"
|
| 9 |
-
requires-python = ">=3.11"
|
| 10 |
-
classifiers = [
|
| 11 |
-
"Programming Language :: Rust",
|
| 12 |
-
"Programming Language :: Python :: Implementation :: CPython",
|
| 13 |
-
]
|
| 14 |
-
|
| 15 |
-
[tool.maturin]
|
| 16 |
-
features = ["pyo3/extension-module"]
|
| 17 |
-
module-name = "htm_rust"
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["maturin>=1.4,<2.0"]
|
| 3 |
+
build-backend = "maturin"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "htm_rust"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Numenta BAMI-spec HTM (Spatial Pooler + Temporal Memory) in Rust with pyo3 bindings"
|
| 9 |
+
requires-python = ">=3.11"
|
| 10 |
+
classifiers = [
|
| 11 |
+
"Programming Language :: Rust",
|
| 12 |
+
"Programming Language :: Python :: Implementation :: CPython",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
[tool.maturin]
|
| 16 |
+
features = ["pyo3/extension-module"]
|
| 17 |
+
module-name = "htm_rust"
|
overlay/htm_rust/src/gpu/fused.rs
CHANGED
|
@@ -1,663 +1,702 @@
|
|
| 1 |
-
//! Fused HTM megakernel launcher.
|
| 2 |
-
//!
|
| 3 |
-
//! Collapses the 12-kernel per-timestep pipeline (and the outer T-loop) into
|
| 4 |
-
//! a single kernel launch per forward. See `kernels/htm_fused_step.cu` for
|
| 5 |
-
//! the kernel design and the cross-block coherence strategy (grid barrier
|
| 6 |
-
//! via device counter with all blocks concurrently resident).
|
| 7 |
-
//!
|
| 8 |
-
//! Launch invariant: `grid_dim.x <= concurrent-block capacity`. Host code
|
| 9 |
-
//! probes the device SM count at construction and caps grid_dim.x
|
| 10 |
-
//! accordingly — otherwise the grid barrier deadlocks.
|
| 11 |
-
//!
|
| 12 |
-
//! Semantic change from the top-K pipeline: activation is per-column
|
| 13 |
-
//! threshold-based (local lateral inhibition) instead of global top-K.
|
| 14 |
-
//! A per-column `inhibition_threshold` is tracked and EMA-steered to hit
|
| 15 |
-
//! the sparsity target. This is a real architectural change and is
|
| 16 |
-
//! documented in `docs/GPU_HTM.md`.
|
| 17 |
-
|
| 18 |
-
#![cfg(feature = "gpu")]
|
| 19 |
-
|
| 20 |
-
use std::ffi::CString;
|
| 21 |
-
use std::sync::Arc;
|
| 22 |
-
|
| 23 |
-
use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError,
|
| 24 |
-
LaunchConfig};
|
| 25 |
-
use cudarc::nvrtc::Ptx;
|
| 26 |
-
|
| 27 |
-
use super::sp_gpu::SpatialPoolerGpu;
|
| 28 |
-
use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
|
| 29 |
-
|
| 30 |
-
const PTX_HTM_FUSED: &str =
|
| 31 |
-
include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
|
| 32 |
-
|
| 33 |
-
/// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
|
| 34 |
-
///
|
| 35 |
-
/// NOTE: `barrier_counters` is kept as an ABI-compat dummy (always 0). The
|
| 36 |
-
/// C-side `FusedPtrs` still has the field at the same byte offset; removing
|
| 37 |
-
/// it here would shift all subsequent fields and break the layout. Worker A
|
| 38 |
-
/// will eventually delete the field from both sides once the kernel is
|
| 39 |
-
/// updated; until then we zero it.
|
| 40 |
-
#[repr(C)]
|
| 41 |
-
#[derive(Clone, Copy)]
|
| 42 |
-
pub struct FusedPtrs {
|
| 43 |
-
pub syn_bit: u64,
|
| 44 |
-
pub syn_perm: u64,
|
| 45 |
-
pub boost: u64,
|
| 46 |
-
pub active_duty: u64,
|
| 47 |
-
pub inhibition_threshold: u64,
|
| 48 |
-
pub seg_cell_id: u64,
|
| 49 |
-
pub seg_syn_count: u64,
|
| 50 |
-
pub syn_presyn: u64,
|
| 51 |
-
pub tm_syn_perm: u64,
|
| 52 |
-
pub cell_seg_count: u64,
|
| 53 |
-
pub cell_active_a: u64,
|
| 54 |
-
pub cell_active_b: u64,
|
| 55 |
-
pub cell_winner_a: u64,
|
| 56 |
-
pub cell_winner_b: u64,
|
| 57 |
-
pub inputs: u64,
|
| 58 |
-
pub cols_out: u64,
|
| 59 |
-
pub anom_out: u64,
|
| 60 |
-
/// ABI-compat dummy — always 0. No device memory is allocated for this
|
| 61 |
-
/// field; the cluster barrier replaces the old software DLB barrier.
|
| 62 |
-
pub barrier_counters: u64,
|
| 63 |
-
pub step_scratch: u64,
|
| 64 |
-
}
|
| 65 |
-
|
| 66 |
-
unsafe impl DeviceRepr for FusedPtrs {}
|
| 67 |
-
|
| 68 |
-
/// Launch-time config — matches C-side `FusedConfig` 1:1.
|
| 69 |
-
#[repr(C)]
|
| 70 |
-
#[derive(Clone, Copy)]
|
| 71 |
-
pub struct FusedConfig {
|
| 72 |
-
pub input_bits: u32,
|
| 73 |
-
pub n_columns: u32,
|
| 74 |
-
pub synapses_per_col: u32,
|
| 75 |
-
pub conn_thr: f32,
|
| 76 |
-
pub sp_inc: f32,
|
| 77 |
-
pub sp_dec: f32,
|
| 78 |
-
pub sparsity_target: f32,
|
| 79 |
-
pub duty_alpha: f32,
|
| 80 |
-
pub thr_adapt_rate: f32,
|
| 81 |
-
pub cells_per_column: u32,
|
| 82 |
-
pub n_cells: u32,
|
| 83 |
-
pub bits_words: u32,
|
| 84 |
-
pub max_segments_per_cell: u32,
|
| 85 |
-
pub synapses_per_segment: u32,
|
| 86 |
-
pub activation_threshold: u32,
|
| 87 |
-
pub learning_threshold: u32,
|
| 88 |
-
pub max_new_synapses: u32,
|
| 89 |
-
pub conn_thr_i16: i32,
|
| 90 |
-
pub perm_inc_i16: i32,
|
| 91 |
-
pub perm_dec_i16: i32,
|
| 92 |
-
pub predicted_seg_dec_i16: i32,
|
| 93 |
-
pub initial_perm_i16: i32,
|
| 94 |
-
pub t: u32,
|
| 95 |
-
pub learn: u32,
|
| 96 |
-
pub iter_seed: u32,
|
| 97 |
-
pub cooperative_grid_sync: u32,
|
| 98 |
-
}
|
| 99 |
-
|
| 100 |
-
unsafe impl DeviceRepr for FusedConfig {}
|
| 101 |
-
|
| 102 |
-
/// Cluster launch parameters probed at construction time.
|
| 103 |
-
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
| 104 |
-
pub(crate) struct ClusterInfo {
|
| 105 |
-
/// Maximum cluster size supported by this device (0 = cluster unsupported).
|
| 106 |
-
pub max_cluster_size: u32,
|
| 107 |
-
}
|
| 108 |
-
|
| 109 |
-
// There is only ONE launch mode: non-cooperative launch with Hopper Thread
|
| 110 |
-
// Block Cluster attribute (`CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`). The old
|
| 111 |
-
// software DLB barrier and the cooperative-launch path are both removed.
|
| 112 |
-
// Cluster barriers replace both.
|
| 113 |
-
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
| 114 |
-
pub(crate) struct FusedLaunchPlan {
|
| 115 |
-
pub grid_dim_x: u32,
|
| 116 |
-
pub block_dim_x: u32,
|
| 117 |
-
pub cooperative_grid_limit: u32,
|
| 118 |
-
pub sm_count: u32,
|
| 119 |
-
}
|
| 120 |
-
|
| 121 |
-
fn fused_grid_cap_override() -> Option<u32> {
|
| 122 |
-
std::env::var("HTM_FUSED_GRID_CAP")
|
| 123 |
-
.ok()
|
| 124 |
-
.and_then(|s| s.parse::<u32>().ok())
|
| 125 |
-
.map(|v| v.max(1))
|
| 126 |
-
}
|
| 127 |
-
|
| 128 |
-
pub(crate) fn plan_fused_launch(
|
| 129 |
-
sm_count: u32,
|
| 130 |
-
cooperative_supported: bool,
|
| 131 |
-
cooperative_grid_limit: u32,
|
| 132 |
-
grid_cap_override: Option<u32>,
|
| 133 |
-
) -> Result<FusedLaunchPlan, String> {
|
| 134 |
-
let sm_count = sm_count.max(1);
|
| 135 |
-
// 1024 threads/block exceeds the register file on Ampere (sm_86: 65536
|
| 136 |
-
// regs/SM ÷ 1024 = 64 regs/thread; fused kernel needs ~80+). 256 gives
|
| 137 |
-
// 256 regs/thread which is ample. Compensate with more blocks via
|
| 138 |
-
// cooperative launch. On Hopper (228 KB smem, 255 regs/thread baseline),
|
| 139 |
-
// 1024 works fine, but 256 is safe everywhere.
|
| 140 |
-
let block_dim_x = 256u32;
|
| 141 |
-
|
| 142 |
-
// Cluster launch path: cooperative launch is not required. Keep the probe
|
| 143 |
-
// result for residency estimation only.
|
| 144 |
-
if !cooperative_supported {
|
| 145 |
-
eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
|
| 146 |
-
}
|
| 147 |
-
|
| 148 |
-
// Tested grid_cap: 4 blocks = 30ms (too serial), 16 blocks = 10.8ms (parallel wins).
|
| 149 |
-
// Parallelism in SP overlap + TM predict stages outweighs grid.sync() cost.
|
| 150 |
-
let default_grid_cap = 16u32;
|
| 151 |
-
let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
|
| 152 |
-
let resident_bound = if cooperative_grid_limit > 0 {
|
| 153 |
-
cooperative_grid_limit.max(sm_count * 2)
|
| 154 |
-
} else {
|
| 155 |
-
sm_count * 2
|
| 156 |
-
};
|
| 157 |
-
Ok(FusedLaunchPlan {
|
| 158 |
-
grid_dim_x: resident_bound.min(grid_cap).max(1),
|
| 159 |
-
block_dim_x,
|
| 160 |
-
cooperative_grid_limit: resident_bound,
|
| 161 |
-
sm_count,
|
| 162 |
-
})
|
| 163 |
-
}
|
| 164 |
-
|
| 165 |
-
pub(
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
pub
|
| 193 |
-
pub
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
let
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
let
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
let
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
//
|
| 496 |
-
//
|
| 497 |
-
|
| 498 |
-
//
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
//
|
| 598 |
-
let
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
)
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! Fused HTM megakernel launcher.
|
| 2 |
+
//!
|
| 3 |
+
//! Collapses the 12-kernel per-timestep pipeline (and the outer T-loop) into
|
| 4 |
+
//! a single kernel launch per forward. See `kernels/htm_fused_step.cu` for
|
| 5 |
+
//! the kernel design and the cross-block coherence strategy (grid barrier
|
| 6 |
+
//! via device counter with all blocks concurrently resident).
|
| 7 |
+
//!
|
| 8 |
+
//! Launch invariant: `grid_dim.x <= concurrent-block capacity`. Host code
|
| 9 |
+
//! probes the device SM count at construction and caps grid_dim.x
|
| 10 |
+
//! accordingly — otherwise the grid barrier deadlocks.
|
| 11 |
+
//!
|
| 12 |
+
//! Semantic change from the top-K pipeline: activation is per-column
|
| 13 |
+
//! threshold-based (local lateral inhibition) instead of global top-K.
|
| 14 |
+
//! A per-column `inhibition_threshold` is tracked and EMA-steered to hit
|
| 15 |
+
//! the sparsity target. This is a real architectural change and is
|
| 16 |
+
//! documented in `docs/GPU_HTM.md`.
|
| 17 |
+
|
| 18 |
+
#![cfg(feature = "gpu")]
|
| 19 |
+
|
| 20 |
+
use std::ffi::CString;
|
| 21 |
+
use std::sync::Arc;
|
| 22 |
+
|
| 23 |
+
use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError,
|
| 24 |
+
LaunchConfig};
|
| 25 |
+
use cudarc::nvrtc::Ptx;
|
| 26 |
+
|
| 27 |
+
use super::sp_gpu::SpatialPoolerGpu;
|
| 28 |
+
use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
|
| 29 |
+
|
| 30 |
+
const PTX_HTM_FUSED: &str =
|
| 31 |
+
include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
|
| 32 |
+
|
| 33 |
+
/// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
|
| 34 |
+
///
|
| 35 |
+
/// NOTE: `barrier_counters` is kept as an ABI-compat dummy (always 0). The
|
| 36 |
+
/// C-side `FusedPtrs` still has the field at the same byte offset; removing
|
| 37 |
+
/// it here would shift all subsequent fields and break the layout. Worker A
|
| 38 |
+
/// will eventually delete the field from both sides once the kernel is
|
| 39 |
+
/// updated; until then we zero it.
|
| 40 |
+
#[repr(C)]
|
| 41 |
+
#[derive(Clone, Copy)]
|
| 42 |
+
pub struct FusedPtrs {
|
| 43 |
+
pub syn_bit: u64,
|
| 44 |
+
pub syn_perm: u64,
|
| 45 |
+
pub boost: u64,
|
| 46 |
+
pub active_duty: u64,
|
| 47 |
+
pub inhibition_threshold: u64,
|
| 48 |
+
pub seg_cell_id: u64,
|
| 49 |
+
pub seg_syn_count: u64,
|
| 50 |
+
pub syn_presyn: u64,
|
| 51 |
+
pub tm_syn_perm: u64,
|
| 52 |
+
pub cell_seg_count: u64,
|
| 53 |
+
pub cell_active_a: u64,
|
| 54 |
+
pub cell_active_b: u64,
|
| 55 |
+
pub cell_winner_a: u64,
|
| 56 |
+
pub cell_winner_b: u64,
|
| 57 |
+
pub inputs: u64,
|
| 58 |
+
pub cols_out: u64,
|
| 59 |
+
pub anom_out: u64,
|
| 60 |
+
/// ABI-compat dummy — always 0. No device memory is allocated for this
|
| 61 |
+
/// field; the cluster barrier replaces the old software DLB barrier.
|
| 62 |
+
pub barrier_counters: u64,
|
| 63 |
+
pub step_scratch: u64,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
unsafe impl DeviceRepr for FusedPtrs {}
|
| 67 |
+
|
| 68 |
+
/// Launch-time config — matches C-side `FusedConfig` 1:1.
|
| 69 |
+
#[repr(C)]
|
| 70 |
+
#[derive(Clone, Copy)]
|
| 71 |
+
pub struct FusedConfig {
|
| 72 |
+
pub input_bits: u32,
|
| 73 |
+
pub n_columns: u32,
|
| 74 |
+
pub synapses_per_col: u32,
|
| 75 |
+
pub conn_thr: f32,
|
| 76 |
+
pub sp_inc: f32,
|
| 77 |
+
pub sp_dec: f32,
|
| 78 |
+
pub sparsity_target: f32,
|
| 79 |
+
pub duty_alpha: f32,
|
| 80 |
+
pub thr_adapt_rate: f32,
|
| 81 |
+
pub cells_per_column: u32,
|
| 82 |
+
pub n_cells: u32,
|
| 83 |
+
pub bits_words: u32,
|
| 84 |
+
pub max_segments_per_cell: u32,
|
| 85 |
+
pub synapses_per_segment: u32,
|
| 86 |
+
pub activation_threshold: u32,
|
| 87 |
+
pub learning_threshold: u32,
|
| 88 |
+
pub max_new_synapses: u32,
|
| 89 |
+
pub conn_thr_i16: i32,
|
| 90 |
+
pub perm_inc_i16: i32,
|
| 91 |
+
pub perm_dec_i16: i32,
|
| 92 |
+
pub predicted_seg_dec_i16: i32,
|
| 93 |
+
pub initial_perm_i16: i32,
|
| 94 |
+
pub t: u32,
|
| 95 |
+
pub learn: u32,
|
| 96 |
+
pub iter_seed: u32,
|
| 97 |
+
pub cooperative_grid_sync: u32,
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
unsafe impl DeviceRepr for FusedConfig {}
|
| 101 |
+
|
| 102 |
+
/// Thread-block-cluster capability recorded when `FusedState` is built.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct ClusterInfo {
    /// Largest cluster size assumed for this device; 0 means cluster launch
    /// is not supported.
    pub max_cluster_size: u32,
}
|
| 108 |
+
|
| 109 |
+
// Launch modes: on devices that report cluster-launch support the fused kernel
// is launched non-cooperatively with the Hopper Thread Block Cluster attribute
// (`CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`); cluster barriers replace the old
// software DLB barrier. Devices without cluster support are launched through
// `cuLaunchCooperativeKernel` instead (see `launch_fused`).

/// Launch geometry produced by [`plan_fused_launch`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct FusedLaunchPlan {
    /// Blocks along x for a single-region launch (always >= 1).
    pub grid_dim_x: u32,
    /// Threads per block along x.
    pub block_dim_x: u32,
    /// Residency bound the grid was clamped against:
    /// `max(probed cooperative limit, 2 * sm_count)`.
    pub cooperative_grid_limit: u32,
    /// SM count the plan was derived from (clamped to >= 1).
    pub sm_count: u32,
}

/// Parse one `HTM_FUSED_GRID_CAP` value.
///
/// Surrounding whitespace is ignored and the result is clamped to >= 1.
/// Returns `None` when the text is not a valid `u32`.
fn parse_grid_cap(raw: &str) -> Option<u32> {
    raw.trim().parse::<u32>().ok().map(|cap| cap.max(1))
}

/// Read the optional grid-cap override from `HTM_FUSED_GRID_CAP`.
///
/// Returns `None` when the variable is unset, blank, or not a valid `u32`.
/// An invalid (non-blank) value is reported on stderr instead of being
/// dropped silently, so a typo does not quietly fall back to the default.
fn fused_grid_cap_override() -> Option<u32> {
    let raw = std::env::var("HTM_FUSED_GRID_CAP").ok()?;
    if raw.trim().is_empty() {
        return None;
    }
    let parsed = parse_grid_cap(&raw);
    if parsed.is_none() {
        eprintln!("[htm_rust] WARN: ignoring invalid HTM_FUSED_GRID_CAP={raw:?} (expected a u32)");
    }
    parsed
}

/// Derive the fused-kernel launch geometry from probed device properties.
///
/// * `sm_count` – multiprocessor count; values below 1 are treated as 1.
/// * `cooperative_supported` – only used to emit an informational message.
/// * `cooperative_grid_limit` – probed co-resident block capacity, 0 if unknown.
/// * `grid_cap_override` – replaces the default cap of 16 blocks when `Some`.
///
/// Currently always returns `Ok`; the `Result` is kept for callers that
/// already propagate the error.
pub(crate) fn plan_fused_launch(
    sm_count: u32,
    cooperative_supported: bool,
    cooperative_grid_limit: u32,
    grid_cap_override: Option<u32>,
) -> Result<FusedLaunchPlan, String> {
    let sm_count = sm_count.max(1);

    // 1024 threads/block exceeds the register file on Ampere (sm_86: 65536
    // regs/SM ÷ 1024 = 64 regs/thread; the fused kernel needs ~80+). At 256
    // threads/block the budget is 65536 ÷ 256 = 256 regs/thread, which is
    // ample. 1024 works on Hopper, but 256 is safe everywhere.
    let block_dim_x = 256u32;

    // Cooperative launch is not required on the cluster path; the probe
    // result is kept for residency estimation only.
    if !cooperative_supported {
        eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
    }

    // Tested grid_cap: 4 blocks = 30ms (too serial), 16 blocks = 10.8ms.
    // Parallelism in the SP overlap + TM predict stages outweighs the cost
    // of grid-wide synchronisation.
    let default_grid_cap = 16u32;
    let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);

    // Saturating: a plain `sm_count * 2` overflows (and panics in debug
    // builds) for sm_count >= 2^31.
    let two_per_sm = sm_count.saturating_mul(2);
    // `two_per_sm` is >= 2, so this also covers an unknown (0) probed limit.
    // NOTE(review): taking the max can report a bound above the probed
    // cooperative capacity (e.g. 1 block/SM occupancy) — confirm this is
    // intended for the cooperative fallback path.
    let resident_bound = cooperative_grid_limit.max(two_per_sm);

    Ok(FusedLaunchPlan {
        grid_dim_x: resident_bound.min(grid_cap).max(1),
        block_dim_x,
        cooperative_grid_limit: resident_bound,
        sm_count,
    })
}
|
| 164 |
+
|
| 165 |
+
/// Pick `grid_dim.x` for a batched launch covering `batch_regions` regions.
///
/// * With `use_cluster` the requested `grid_dim_x` is returned unchanged
///   (clamped to >= 1); the cooperative limit is not consulted.
/// * Otherwise the per-region grid is capped at
///   `cooperative_grid_limit / batch_regions`, so the total block count stays
///   within the cooperative limit.
///
/// `batch_regions == 0` is treated as 1.
///
/// # Errors
/// Returns a `COOPERATIVE_LAUNCH_TOO_LARGE: ...` message when the limit is 0
/// or when there are more regions than the limit allows blocks.
pub(crate) fn plan_batched_grid_dim(
    grid_dim_x: u32,
    cooperative_grid_limit: u32,
    batch_regions: usize,
    use_cluster: bool,
) -> Result<u32, String> {
    if use_cluster {
        return Ok(grid_dim_x.max(1));
    }

    if cooperative_grid_limit == 0 {
        return Err("COOPERATIVE_LAUNCH_TOO_LARGE: cooperative launch limit unavailable".into());
    }

    let batch_regions = batch_regions.max(1);
    // Divide in u64: the previous `as u32` cast truncated large counts, and a
    // multiple of 2^32 became 0, which panicked on the division below.
    let max_grid_x = (u64::from(cooperative_grid_limit) / batch_regions as u64) as u32;
    if max_grid_x == 0 {
        return Err(format!(
            "COOPERATIVE_LAUNCH_TOO_LARGE: batch_regions={batch_regions} exceeds cooperative_grid_limit={cooperative_grid_limit}"
        ));
    }

    Ok(grid_dim_x.min(max_grid_x).max(1))
}
|
| 189 |
+
|
| 190 |
+
pub(super) struct RawFusedKernel {
|
| 191 |
+
module: sys::CUmodule,
|
| 192 |
+
pub(super) function: sys::CUfunction,
|
| 193 |
+
pub(super) function_batched: sys::CUfunction,
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
unsafe impl Send for RawFusedKernel {}
|
| 197 |
+
unsafe impl Sync for RawFusedKernel {}
|
| 198 |
+
|
| 199 |
+
impl Drop for RawFusedKernel {
|
| 200 |
+
fn drop(&mut self) {
|
| 201 |
+
unsafe {
|
| 202 |
+
let _ = result::module::unload(self.module);
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
/// Owns fused-path-only device state:
|
| 208 |
+
/// - per-column inhibition threshold (replaces global top-K)
|
| 209 |
+
/// - ping-pong cell_active/cell_winner bitsets
|
| 210 |
+
/// - step_scratch (n_active, n_unpred per timestep)
|
| 211 |
+
/// - cluster launch capability info
|
| 212 |
+
pub struct FusedState {
|
| 213 |
+
dev: Arc<CudaDevice>,
|
| 214 |
+
pub(super) raw_kernel: RawFusedKernel,
|
| 215 |
+
|
| 216 |
+
pub inhibition_threshold: CudaSlice<f32>,
|
| 217 |
+
pub cell_active_bits_a: CudaSlice<u32>,
|
| 218 |
+
pub cell_active_bits_b: CudaSlice<u32>,
|
| 219 |
+
pub cell_winner_bits_a: CudaSlice<u32>,
|
| 220 |
+
pub cell_winner_bits_b: CudaSlice<u32>,
|
| 221 |
+
pub step_scratch: CudaSlice<u32>, // length 6
|
| 222 |
+
|
| 223 |
+
pub grid_dim_x: u32,
|
| 224 |
+
pub block_dim_x: u32,
|
| 225 |
+
pub cooperative_grid_limit: u32,
|
| 226 |
+
pub iter_counter: u32,
|
| 227 |
+
|
| 228 |
+
/// Hopper cluster launch capability (0 = unsupported).
|
| 229 |
+
pub cluster_info: ClusterInfo,
|
| 230 |
+
|
| 231 |
+
// Config mirror (read-only after init).
|
| 232 |
+
#[allow(dead_code)]
|
| 233 |
+
pub initial_threshold: f32,
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
impl FusedState {
|
| 237 |
+
pub fn new(
|
| 238 |
+
dev: Arc<CudaDevice>,
|
| 239 |
+
n_columns: usize,
|
| 240 |
+
cells_per_column: usize,
|
| 241 |
+
initial_threshold: f32,
|
| 242 |
+
) -> Result<Self, DriverError> {
|
| 243 |
+
let n_cells = n_columns * cells_per_column;
|
| 244 |
+
assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
|
| 245 |
+
let bits_words = n_cells / 32;
|
| 246 |
+
|
| 247 |
+
let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
|
| 248 |
+
let init_vec = vec![initial_threshold; n_columns];
|
| 249 |
+
dev.htod_sync_copy_into(&init_vec, &mut inhibition_threshold)?;
|
| 250 |
+
|
| 251 |
+
let cell_active_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
|
| 252 |
+
let cell_active_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
|
| 253 |
+
let cell_winner_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
|
| 254 |
+
let cell_winner_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
|
| 255 |
+
let step_scratch = dev.alloc_zeros::<u32>(6)?;
|
| 256 |
+
|
| 257 |
+
unsafe {
|
| 258 |
+
result::ctx::set_current(*dev.cu_primary_ctx())?;
|
| 259 |
+
}
|
| 260 |
+
if dev.get_func("htm_fused", "htm_fused_step").is_none() {
|
| 261 |
+
dev.load_ptx(
|
| 262 |
+
Ptx::from_src(PTX_HTM_FUSED),
|
| 263 |
+
"htm_fused",
|
| 264 |
+
&["htm_fused_step", "htm_fused_step_batched"],
|
| 265 |
+
)?;
|
| 266 |
+
}
|
| 267 |
+
let ptx = CString::new(PTX_HTM_FUSED).expect("PTX contains no interior nul bytes");
|
| 268 |
+
let module = unsafe { result::module::load_data(ptx.as_ptr().cast()) }?;
|
| 269 |
+
let function = unsafe {
|
| 270 |
+
result::module::get_function(module, CString::new("htm_fused_step").unwrap())
|
| 271 |
+
}?;
|
| 272 |
+
let function_batched = unsafe {
|
| 273 |
+
result::module::get_function(module, CString::new("htm_fused_step_batched").unwrap())
|
| 274 |
+
}?;
|
| 275 |
+
|
| 276 |
+
// Cluster size 16 on Hopper is "non-portable" (> 8 requires opt-in).
|
| 277 |
+
// Must set CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED=1 on
|
| 278 |
+
// every launched kernel function, otherwise cuLaunchKernelEx rejects
|
| 279 |
+
// the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
|
| 280 |
+
unsafe {
|
| 281 |
+
let attr = sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
|
| 282 |
+
// Ignore errors: older CUDA may lack the attribute, in which case
|
| 283 |
+
// only portable sizes (<= 8) work — plan_fused_launch caps at 8.
|
| 284 |
+
let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
|
| 285 |
+
let _ = sys::lib().cuFuncSetAttribute(function_batched, attr, 1);
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
// Probe SM count.
|
| 289 |
+
let sm_count = match dev.attribute(
|
| 290 |
+
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
|
| 291 |
+
) {
|
| 292 |
+
Ok(v) => v as u32,
|
| 293 |
+
Err(_) => 16u32,
|
| 294 |
+
};
|
| 295 |
+
|
| 296 |
+
// T1: Probe Hopper cluster launch capability.
|
| 297 |
+
let max_cluster_size = match dev.attribute(
|
| 298 |
+
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH,
|
| 299 |
+
) {
|
| 300 |
+
Ok(v) if v > 0 => {
|
| 301 |
+
// H200/sm_90a supports up to 16 blocks per cluster.
|
| 302 |
+
// There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
|
| 303 |
+
// Hopper maximum which is 16 (8 SMs × 2 blocks/SM = 16 blocks/cluster).
|
| 304 |
+
16u32
|
| 305 |
+
}
|
| 306 |
+
_ => 0u32,
|
| 307 |
+
};
|
| 308 |
+
eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
|
| 309 |
+
let cluster_info = ClusterInfo { max_cluster_size };
|
| 310 |
+
|
| 311 |
+
let cooperative_supported = matches!(
|
| 312 |
+
dev.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
|
| 313 |
+
Ok(v) if v > 0
|
| 314 |
+
);
|
| 315 |
+
let cooperative_grid_limit = if cooperative_supported {
|
| 316 |
+
let blocks_per_sm = unsafe {
|
| 317 |
+
// Must match plan_fused_launch(): the A10G/Ampere-safe fused
|
| 318 |
+
// kernel launch uses 256 threads/block, not the historical
|
| 319 |
+
// 1024-thread Hopper occupancy probe.
|
| 320 |
+
result::occupancy::max_active_block_per_multiprocessor(function, 256, 0)
|
| 321 |
+
}
|
| 322 |
+
.ok()
|
| 323 |
+
.map(|v| v.max(0) as u32)
|
| 324 |
+
.unwrap_or(0);
|
| 325 |
+
sm_count.saturating_mul(blocks_per_sm)
|
| 326 |
+
} else {
|
| 327 |
+
0
|
| 328 |
+
};
|
| 329 |
+
let launch_plan = plan_fused_launch(
|
| 330 |
+
sm_count,
|
| 331 |
+
cooperative_supported,
|
| 332 |
+
cooperative_grid_limit,
|
| 333 |
+
fused_grid_cap_override(),
|
| 334 |
+
)
|
| 335 |
+
.map_err(|msg| {
|
| 336 |
+
// Surface as a CUDA-ish error so callers can propagate.
|
| 337 |
+
eprintln!("[htm_rust] FATAL: {msg}");
|
| 338 |
+
DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
|
| 339 |
+
})?;
|
| 340 |
+
|
| 341 |
+
eprintln!(
|
| 342 |
+
"[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}",
|
| 343 |
+
launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit,
|
| 344 |
+
cluster_info.max_cluster_size,
|
| 345 |
+
);
|
| 346 |
+
|
| 347 |
+
Ok(Self {
|
| 348 |
+
dev,
|
| 349 |
+
raw_kernel: RawFusedKernel { module, function, function_batched },
|
| 350 |
+
inhibition_threshold,
|
| 351 |
+
cell_active_bits_a,
|
| 352 |
+
cell_active_bits_b,
|
| 353 |
+
cell_winner_bits_a,
|
| 354 |
+
cell_winner_bits_b,
|
| 355 |
+
step_scratch,
|
| 356 |
+
grid_dim_x: launch_plan.grid_dim_x,
|
| 357 |
+
block_dim_x: launch_plan.block_dim_x,
|
| 358 |
+
cooperative_grid_limit: launch_plan.cooperative_grid_limit,
|
| 359 |
+
iter_counter: 0,
|
| 360 |
+
cluster_info,
|
| 361 |
+
initial_threshold,
|
| 362 |
+
})
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
/// Reset fused state. Called at region.reset().
|
| 366 |
+
pub fn reset(&mut self) -> Result<(), DriverError> {
|
| 367 |
+
self.dev.memset_zeros(&mut self.cell_active_bits_a)?;
|
| 368 |
+
self.dev.memset_zeros(&mut self.cell_active_bits_b)?;
|
| 369 |
+
self.dev.memset_zeros(&mut self.cell_winner_bits_a)?;
|
| 370 |
+
self.dev.memset_zeros(&mut self.cell_winner_bits_b)?;
|
| 371 |
+
self.dev.memset_zeros(&mut self.step_scratch)?;
|
| 372 |
+
// Do NOT reset inhibition_threshold — it's learned state. A hard
|
| 373 |
+
// reset of TM state should NOT forget the sparsity calibration.
|
| 374 |
+
Ok(())
|
| 375 |
+
}
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
/// Launch the fused megakernel. Processes all T timesteps in one kernel.
|
| 379 |
+
///
|
| 380 |
+
/// Uses `cuLaunchKernelEx` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION=(16,1,1)`
|
| 381 |
+
/// when the device supports cluster launch, otherwise falls back to a plain
|
| 382 |
+
/// `launch_kernel`. For single-region launches, grid_dim_x <= 16 ensures the
|
| 383 |
+
/// entire grid fits in one cluster.
|
| 384 |
+
#[allow(clippy::too_many_arguments)]
|
| 385 |
+
pub fn launch_fused(
|
| 386 |
+
sp: &mut SpatialPoolerGpu,
|
| 387 |
+
tm: &mut TemporalMemoryGpu,
|
| 388 |
+
fused: &mut FusedState,
|
| 389 |
+
inputs_flat: &CudaSlice<u8>,
|
| 390 |
+
cols_out: &mut CudaSlice<u8>,
|
| 391 |
+
anom_out: &mut CudaSlice<f32>,
|
| 392 |
+
t: usize,
|
| 393 |
+
input_bits: usize,
|
| 394 |
+
learn: bool,
|
| 395 |
+
) -> Result<(), DriverError> {
|
| 396 |
+
// Reset step_scratch before each launch (safe re-entry).
|
| 397 |
+
sp.dev_ref().memset_zeros(&mut fused.step_scratch)?;
|
| 398 |
+
|
| 399 |
+
fused.iter_counter = fused.iter_counter.wrapping_add(1);
|
| 400 |
+
|
| 401 |
+
let cfg = FusedConfig {
|
| 402 |
+
input_bits: input_bits as u32,
|
| 403 |
+
n_columns: sp.n_columns_accessor() as u32,
|
| 404 |
+
synapses_per_col: sp.synapses_per_col_accessor() as u32,
|
| 405 |
+
conn_thr: sp.conn_thr_accessor(),
|
| 406 |
+
sp_inc: sp.inc_accessor(),
|
| 407 |
+
sp_dec: sp.dec_accessor(),
|
| 408 |
+
sparsity_target: sp.sparsity_accessor(),
|
| 409 |
+
duty_alpha: 1.0f32 / sp.duty_period_accessor().max(1.0),
|
| 410 |
+
thr_adapt_rate: 0.001f32,
|
| 411 |
+
cells_per_column: tm.cells_per_column as u32,
|
| 412 |
+
n_cells: tm.n_cells as u32,
|
| 413 |
+
bits_words: tm.bits_words as u32,
|
| 414 |
+
max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
|
| 415 |
+
synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
|
| 416 |
+
activation_threshold: tm.activation_threshold,
|
| 417 |
+
learning_threshold: tm.learning_threshold,
|
| 418 |
+
max_new_synapses: tm.max_new_synapse_count,
|
| 419 |
+
conn_thr_i16: tm.conn_thr_i16 as i32,
|
| 420 |
+
perm_inc_i16: tm.perm_inc_i16 as i32,
|
| 421 |
+
perm_dec_i16: tm.perm_dec_i16 as i32,
|
| 422 |
+
predicted_seg_dec_i16: tm.predicted_seg_dec_i16 as i32,
|
| 423 |
+
initial_perm_i16: tm.initial_perm_i16 as i32,
|
| 424 |
+
t: t as u32,
|
| 425 |
+
learn: if learn { 1 } else { 0 },
|
| 426 |
+
iter_seed: fused.iter_counter,
|
| 427 |
+
cooperative_grid_sync: 1,
|
| 428 |
+
};
|
| 429 |
+
|
| 430 |
+
let ptrs = FusedPtrs {
|
| 431 |
+
syn_bit: *sp.syn_bit_accessor().device_ptr(),
|
| 432 |
+
syn_perm: *sp.syn_perm_accessor().device_ptr(),
|
| 433 |
+
boost: *sp.boost_accessor().device_ptr(),
|
| 434 |
+
active_duty: *sp.active_duty_accessor().device_ptr(),
|
| 435 |
+
inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
|
| 436 |
+
seg_cell_id: *tm.seg_cell_id_accessor().device_ptr(),
|
| 437 |
+
seg_syn_count: *tm.seg_syn_count_accessor().device_ptr(),
|
| 438 |
+
syn_presyn: *tm.syn_presyn_accessor().device_ptr(),
|
| 439 |
+
tm_syn_perm: *tm.syn_perm_accessor().device_ptr(),
|
| 440 |
+
cell_seg_count: *tm.cell_seg_count_accessor().device_ptr(),
|
| 441 |
+
cell_active_a: *fused.cell_active_bits_a.device_ptr(),
|
| 442 |
+
cell_active_b: *fused.cell_active_bits_b.device_ptr(),
|
| 443 |
+
cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
|
| 444 |
+
cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
|
| 445 |
+
inputs: *inputs_flat.device_ptr(),
|
| 446 |
+
cols_out: *cols_out.device_ptr(),
|
| 447 |
+
anom_out: *anom_out.device_ptr(),
|
| 448 |
+
barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
|
| 449 |
+
step_scratch: *fused.step_scratch.device_ptr(),
|
| 450 |
+
};
|
| 451 |
+
|
| 452 |
+
let grid_x = fused.grid_dim_x;
|
| 453 |
+
let block_x = fused.block_dim_x;
|
| 454 |
+
let cu_stream = *sp.dev_ref().cu_stream();
|
| 455 |
+
let use_cluster = fused.cluster_info.max_cluster_size > 0;
|
| 456 |
+
|
| 457 |
+
unsafe {
|
| 458 |
+
result::ctx::set_current(*sp.dev_ref().cu_primary_ctx())?;
|
| 459 |
+
let mut kernel_params: [*mut std::ffi::c_void; 2] = [
|
| 460 |
+
(&ptrs as *const FusedPtrs).cast_mut().cast(),
|
| 461 |
+
(&cfg as *const FusedConfig).cast_mut().cast(),
|
| 462 |
+
];
|
| 463 |
+
|
| 464 |
+
if use_cluster {
|
| 465 |
+
// T10: Hopper cluster launch with CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION.
|
| 466 |
+
// cluster_dim=(16,1,1) maps the entire single-region grid into one cluster.
|
| 467 |
+
let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
|
| 468 |
+
attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
| 469 |
+
attr.value.clusterDim.x = 16;
|
| 470 |
+
attr.value.clusterDim.y = 1;
|
| 471 |
+
attr.value.clusterDim.z = 1;
|
| 472 |
+
|
| 473 |
+
let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
|
| 474 |
+
launch_cfg.gridDimX = grid_x;
|
| 475 |
+
launch_cfg.gridDimY = 1;
|
| 476 |
+
launch_cfg.gridDimZ = 1;
|
| 477 |
+
launch_cfg.blockDimX = block_x;
|
| 478 |
+
launch_cfg.blockDimY = 1;
|
| 479 |
+
launch_cfg.blockDimZ = 1;
|
| 480 |
+
launch_cfg.sharedMemBytes = 0;
|
| 481 |
+
launch_cfg.hStream = cu_stream;
|
| 482 |
+
launch_cfg.numAttrs = 1;
|
| 483 |
+
launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
|
| 484 |
+
|
| 485 |
+
let ret = sys::lib().cuLaunchKernelEx(
|
| 486 |
+
&launch_cfg as *const sys::CUlaunchConfig,
|
| 487 |
+
fused.raw_kernel.function,
|
| 488 |
+
kernel_params.as_mut_ptr(),
|
| 489 |
+
std::ptr::null_mut(),
|
| 490 |
+
);
|
| 491 |
+
if ret != sys::CUresult::CUDA_SUCCESS {
|
| 492 |
+
return Err(DriverError(ret));
|
| 493 |
+
}
|
| 494 |
+
} else {
|
| 495 |
+
// Pre-Hopper: cooperative kernel launch. The fused kernel uses
|
| 496 |
+
// grid.sync() for cross-block synchronization which REQUIRES
|
| 497 |
+
// cuLaunchCooperativeKernel (normal launch silently crashes on
|
| 498 |
+
// the first grid.sync() call).
|
| 499 |
+
let ret = sys::lib().cuLaunchCooperativeKernel(
|
| 500 |
+
fused.raw_kernel.function,
|
| 501 |
+
grid_x, 1, 1,
|
| 502 |
+
block_x, 1, 1,
|
| 503 |
+
0, // sharedMemBytes
|
| 504 |
+
cu_stream,
|
| 505 |
+
kernel_params.as_mut_ptr(),
|
| 506 |
+
);
|
| 507 |
+
if ret != sys::CUresult::CUDA_SUCCESS {
|
| 508 |
+
return Err(DriverError(ret));
|
| 509 |
+
}
|
| 510 |
+
}
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
Ok(())
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
/// Single batched non-cooperative launch for B regions with DLB sync. Uses the same kernel
|
| 517 |
+
/// body; each block reads its region's FusedPtrs from a device-side array
|
| 518 |
+
/// indexed by blockIdx.y. All regions share the same config (same
|
| 519 |
+
/// input_bits/n_columns/etc.) so we pass one FusedConfig.
|
| 520 |
+
///
|
| 521 |
+
/// This breaks through the CUDA cooperative-kernel device-level
|
| 522 |
+
/// serialization: multiple cooperative launches are serialized regardless
|
| 523 |
+
/// of stream, but one cooperative launch with grid.y=B processes all
|
| 524 |
+
/// regions in a single invocation — ~B× speedup vs B sequential launches.
|
| 525 |
+
#[allow(clippy::too_many_arguments)]
|
| 526 |
+
/// Low-level raw-pointer entry, called by PyO3 binding which holds the
|
| 527 |
+
/// mutable borrows. Safety: each `*mut HTMRegionGpu` must point to a live,
|
| 528 |
+
/// uniquely-borrowed region. All regions must be distinct.
|
| 529 |
+
pub(super) fn launch_fused_batched_raw(
|
| 530 |
+
region_ptrs: &[*mut super::HTMRegionGpu],
|
| 531 |
+
inputs_per_region: &[u64],
|
| 532 |
+
cols_per_region: &[u64],
|
| 533 |
+
anom_per_region: &[u64],
|
| 534 |
+
t: usize,
|
| 535 |
+
input_bits: usize,
|
| 536 |
+
learn: bool,
|
| 537 |
+
) -> Result<(), DriverError> {
|
| 538 |
+
let b = region_ptrs.len();
|
| 539 |
+
assert_eq!(inputs_per_region.len(), b);
|
| 540 |
+
assert_eq!(cols_per_region.len(), b);
|
| 541 |
+
assert_eq!(anom_per_region.len(), b);
|
| 542 |
+
assert!(b >= 1, "need at least one region");
|
| 543 |
+
|
| 544 |
+
// Reset per-region step_scratch before each launch.
|
| 545 |
+
for &rp in region_ptrs.iter() {
|
| 546 |
+
let r = unsafe { &mut *rp };
|
| 547 |
+
let dev = r.sp_gpu.dev_ref().clone();
|
| 548 |
+
dev.memset_zeros(&mut r.fused_state.step_scratch)?;
|
| 549 |
+
r.fused_state.iter_counter = r.fused_state.iter_counter.wrapping_add(1);
|
| 550 |
+
}
|
| 551 |
+
|
| 552 |
+
// Shared config — all regions use identical sp/tm parameters.
|
| 553 |
+
let (grid_x, block_x, cooperative_grid_limit, function_batched, cu_stream, cu_ctx) = {
|
| 554 |
+
let r0 = unsafe { &*region_ptrs[0] };
|
| 555 |
+
(
|
| 556 |
+
r0.fused_state.grid_dim_x,
|
| 557 |
+
r0.fused_state.block_dim_x,
|
| 558 |
+
r0.fused_state.cooperative_grid_limit,
|
| 559 |
+
r0.fused_state.raw_kernel.function_batched,
|
| 560 |
+
*r0.sp_gpu.dev_ref().cu_stream(),
|
| 561 |
+
*r0.sp_gpu.dev_ref().cu_primary_ctx(),
|
| 562 |
+
)
|
| 563 |
+
};
|
| 564 |
+
|
| 565 |
+
let cfg = {
|
| 566 |
+
let r = unsafe { &*region_ptrs[0] };
|
| 567 |
+
FusedConfig {
|
| 568 |
+
input_bits: input_bits as u32,
|
| 569 |
+
n_columns: r.sp_gpu.n_columns_accessor() as u32,
|
| 570 |
+
synapses_per_col: r.sp_gpu.synapses_per_col_accessor() as u32,
|
| 571 |
+
conn_thr: r.sp_gpu.conn_thr_accessor(),
|
| 572 |
+
sp_inc: r.sp_gpu.inc_accessor(),
|
| 573 |
+
sp_dec: r.sp_gpu.dec_accessor(),
|
| 574 |
+
sparsity_target: r.sp_gpu.sparsity_accessor(),
|
| 575 |
+
duty_alpha: 1.0f32 / r.sp_gpu.duty_period_accessor().max(1.0),
|
| 576 |
+
thr_adapt_rate: 0.001f32,
|
| 577 |
+
cells_per_column: r.tm_gpu.cells_per_column as u32,
|
| 578 |
+
n_cells: r.tm_gpu.n_cells as u32,
|
| 579 |
+
bits_words: r.tm_gpu.bits_words as u32,
|
| 580 |
+
max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
|
| 581 |
+
synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
|
| 582 |
+
activation_threshold: r.tm_gpu.activation_threshold,
|
| 583 |
+
learning_threshold: r.tm_gpu.learning_threshold,
|
| 584 |
+
max_new_synapses: r.tm_gpu.max_new_synapse_count,
|
| 585 |
+
conn_thr_i16: r.tm_gpu.conn_thr_i16 as i32,
|
| 586 |
+
perm_inc_i16: r.tm_gpu.perm_inc_i16 as i32,
|
| 587 |
+
perm_dec_i16: r.tm_gpu.perm_dec_i16 as i32,
|
| 588 |
+
predicted_seg_dec_i16: r.tm_gpu.predicted_seg_dec_i16 as i32,
|
| 589 |
+
initial_perm_i16: r.tm_gpu.initial_perm_i16 as i32,
|
| 590 |
+
t: t as u32,
|
| 591 |
+
learn: if learn { 1 } else { 0 },
|
| 592 |
+
iter_seed: r.fused_state.iter_counter,
|
| 593 |
+
cooperative_grid_sync: 1,
|
| 594 |
+
}
|
| 595 |
+
};
|
| 596 |
+
|
| 597 |
+
// Build B FusedPtrs per-region.
|
| 598 |
+
let ptrs_vec: Vec<FusedPtrs> = (0..b)
|
| 599 |
+
.map(|i| {
|
| 600 |
+
let r = unsafe { &*region_ptrs[i] };
|
| 601 |
+
FusedPtrs {
|
| 602 |
+
syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
|
| 603 |
+
syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
|
| 604 |
+
boost: *r.sp_gpu.boost_accessor().device_ptr(),
|
| 605 |
+
active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
|
| 606 |
+
inhibition_threshold: *r.fused_state.inhibition_threshold.device_ptr(),
|
| 607 |
+
seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
|
| 608 |
+
seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
|
| 609 |
+
syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
|
| 610 |
+
tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
|
| 611 |
+
cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
|
| 612 |
+
cell_active_a: *r.fused_state.cell_active_bits_a.device_ptr(),
|
| 613 |
+
cell_active_b: *r.fused_state.cell_active_bits_b.device_ptr(),
|
| 614 |
+
cell_winner_a: *r.fused_state.cell_winner_bits_a.device_ptr(),
|
| 615 |
+
cell_winner_b: *r.fused_state.cell_winner_bits_b.device_ptr(),
|
| 616 |
+
inputs: inputs_per_region[i],
|
| 617 |
+
cols_out: cols_per_region[i],
|
| 618 |
+
anom_out: anom_per_region[i],
|
| 619 |
+
barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
|
| 620 |
+
step_scratch: *r.fused_state.step_scratch.device_ptr(),
|
| 621 |
+
}
|
| 622 |
+
})
|
| 623 |
+
.collect();
|
| 624 |
+
|
| 625 |
+
// Upload FusedPtrs array to device (B * sizeof(FusedPtrs) bytes).
|
| 626 |
+
// FusedPtrs is repr(C) + DeviceRepr so htod_sync_copy handles it.
|
| 627 |
+
let dev = unsafe { &*region_ptrs[0] }.sp_gpu.dev_ref().clone();
|
| 628 |
+
let ptrs_dev: CudaSlice<FusedPtrs> = dev.htod_sync_copy(&ptrs_vec)?;
|
| 629 |
+
let ptrs_dev_ptr: u64 = *ptrs_dev.device_ptr();
|
| 630 |
+
|
| 631 |
+
// T10: Cluster launch for batched regions.
|
| 632 |
+
// Grid = (grid_x, B, 1) with cluster_dim=(16,1,1): each region (Y slice)
|
| 633 |
+
// occupies exactly one cluster of 16 blocks. All 8 clusters run concurrently
|
| 634 |
+
// on the H200's 132 SMs (8 × 16 = 128 blocks ≤ 132 SMs).
|
| 635 |
+
let use_cluster = {
|
| 636 |
+
let r0 = unsafe { &*region_ptrs[0] };
|
| 637 |
+
r0.fused_state.cluster_info.max_cluster_size > 0
|
| 638 |
+
};
|
| 639 |
+
let grid_x = plan_batched_grid_dim(grid_x, cooperative_grid_limit, b, use_cluster)
|
| 640 |
+
.map_err(|msg| {
|
| 641 |
+
eprintln!("[htm_rust] FATAL: {msg}");
|
| 642 |
+
DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE)
|
| 643 |
+
})?;
|
| 644 |
+
|
| 645 |
+
unsafe {
|
| 646 |
+
result::ctx::set_current(cu_ctx)?;
|
| 647 |
+
let mut kernel_params: [*mut std::ffi::c_void; 2] = [
|
| 648 |
+
(&ptrs_dev_ptr as *const u64).cast_mut().cast(),
|
| 649 |
+
(&cfg as *const FusedConfig).cast_mut().cast(),
|
| 650 |
+
];
|
| 651 |
+
|
| 652 |
+
if use_cluster {
|
| 653 |
+
let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
|
| 654 |
+
attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
| 655 |
+
attr.value.clusterDim.x = 16;
|
| 656 |
+
attr.value.clusterDim.y = 1;
|
| 657 |
+
attr.value.clusterDim.z = 1;
|
| 658 |
+
|
| 659 |
+
let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
|
| 660 |
+
launch_cfg.gridDimX = grid_x;
|
| 661 |
+
launch_cfg.gridDimY = b as u32;
|
| 662 |
+
launch_cfg.gridDimZ = 1;
|
| 663 |
+
launch_cfg.blockDimX = block_x;
|
| 664 |
+
launch_cfg.blockDimY = 1;
|
| 665 |
+
launch_cfg.blockDimZ = 1;
|
| 666 |
+
launch_cfg.sharedMemBytes = 0;
|
| 667 |
+
launch_cfg.hStream = cu_stream;
|
| 668 |
+
launch_cfg.numAttrs = 1;
|
| 669 |
+
launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
|
| 670 |
+
|
| 671 |
+
let ret = sys::lib().cuLaunchKernelEx(
|
| 672 |
+
&launch_cfg as *const sys::CUlaunchConfig,
|
| 673 |
+
function_batched,
|
| 674 |
+
kernel_params.as_mut_ptr(),
|
| 675 |
+
std::ptr::null_mut(),
|
| 676 |
+
);
|
| 677 |
+
if ret != sys::CUresult::CUDA_SUCCESS {
|
| 678 |
+
return Err(DriverError(ret));
|
| 679 |
+
}
|
| 680 |
+
} else {
|
| 681 |
+
// Pre-Hopper: cooperative kernel launch (grid.sync() requires it).
|
| 682 |
+
let ret = sys::lib().cuLaunchCooperativeKernel(
|
| 683 |
+
function_batched,
|
| 684 |
+
grid_x, b as u32, 1,
|
| 685 |
+
block_x, 1, 1,
|
| 686 |
+
0, // sharedMemBytes
|
| 687 |
+
cu_stream,
|
| 688 |
+
kernel_params.as_mut_ptr(),
|
| 689 |
+
);
|
| 690 |
+
if ret != sys::CUresult::CUDA_SUCCESS {
|
| 691 |
+
return Err(DriverError(ret));
|
| 692 |
+
}
|
| 693 |
+
}
|
| 694 |
+
}
|
| 695 |
+
|
| 696 |
+
// `ptrs_dev` is a per-call device array consumed by the async kernel.
|
| 697 |
+
// Keep it alive until the kernel has read it; otherwise dropping/freeing
|
| 698 |
+
// it immediately after launch can surface as a later unrelated CUDA error.
|
| 699 |
+
dev.synchronize()?;
|
| 700 |
+
|
| 701 |
+
Ok(())
|
| 702 |
+
}
|
overlay/htm_rust/src/gpu/kernels/htm_fused_step.cu
CHANGED
|
@@ -1,677 +1,677 @@
|
|
| 1 |
-
// Fused HTM megakernel — SP + TM, all T timesteps in a single launch.
|
| 2 |
-
//
|
| 3 |
-
// Design rationale:
|
| 4 |
-
// - Global top-K column selection requires cross-block synchronization at
|
| 5 |
-
// every timestep (grid.sync is unreliable on WSL2/sm_86 without rdc=true).
|
| 6 |
-
// - Replace with per-column threshold activation using local lateral
|
| 7 |
-
// inhibition: column c activates if overlap[c]*boost[c] > threshold[c].
|
| 8 |
-
// Threshold is a per-column running-EMA learned scalar that steers the
|
| 9 |
-
// column's long-run activation rate toward the global sparsity target.
|
| 10 |
-
// - This is biologically grounded (GABAergic local inhibition) and supported
|
| 11 |
-
// by HTM theory (duty-cycle boost already drives this loop; we just
|
| 12 |
-
// change which lever the EMA pulls).
|
| 13 |
-
//
|
| 14 |
-
// Launch shape:
|
| 15 |
-
// grid = min(device SM count, 16) // hard cap — see below
|
| 16 |
-
// block = 1024 threads = 32 warps
|
| 17 |
-
// Each warp of 32 owns a contiguous column slice (n_columns / total_warps).
|
| 18 |
-
//
|
| 19 |
-
// Cross-block coherence:
|
| 20 |
-
// - Ping-pong buffers for cell_active/cell_winner: write _a at even t,
|
| 21 |
-
// read _b; reversed at odd t.
|
| 22 |
-
// - Preferred path: cooperative launch + hardware whole-grid sync.
|
| 23 |
-
// - Fallback path (historical): a software 3-slot rotating grid barrier was
|
| 24 |
-
// used where cooperative launch was unavailable; fused_grid_barrier now always uses a hardware sync.
|
| 25 |
-
//
|
| 26 |
-
// 2026-04-16: grid_dim reduced from 28 to 16 after deadlock RCA. The previous
|
| 27 |
-
// cap of 28 relied on all blocks being concurrently resident on a 30-SM RTX
|
| 28 |
-
// 3060 Laptop. Under thermal throttling effective residency dropped to ~20-24,
|
| 29 |
-
// leaving scheduled blocks spinning on the software grid barrier waiting for
|
| 30 |
-
// peer blocks that would never run. 16 blocks is below any realistic residency
|
| 31 |
-
// floor and preserves enough warp parallelism (16*32 = 512 warps) to saturate
|
| 32 |
-
// memory bandwidth on the spatial-pooler stage.
|
| 33 |
-
//
|
| 34 |
-
// Kernel signature uses struct-by-value for pointers and config to stay
|
| 35 |
-
// inside cudarc's launch-arg count limit.
|
| 36 |
-
|
| 37 |
-
#include <cooperative_groups.h>
|
| 38 |
-
#include <cooperative_groups/memcpy_async.h>
|
| 39 |
-
|
| 40 |
-
namespace cg = cooperative_groups;
|
| 41 |
-
|
| 42 |
-
// Maximum columns owned per cluster-block in DSMEM.
|
| 43 |
-
// Supports n_columns up to COLS_PER_CLUSTER_BLOCK_MAX * cluster_size.
|
| 44 |
-
// At cluster_size=16: supports up to 256*16=4096 columns.
|
| 45 |
-
// Each array costs 256*4 = 1024 bytes; three arrays = 3072 bytes per SM —
|
| 46 |
-
// well under the 228 KB H200 shared-memory cap.
|
| 47 |
-
// NOTE(review): the kernel sizes its three __shared__ float arrays with this
// constant and indexes them by (column - slice start) without a bounds check,
// so ceil(n_columns / cluster_size) must not exceed it — confirm the host
// validates n_columns before launching.
#define COLS_PER_CLUSTER_BLOCK_MAX 256u
|
| 48 |
-
|
| 49 |
-
// Maximum input_bits supported by the TMA-multicast staging tile.
|
| 50 |
-
// At 32 KB this covers the production SDR width (16384 bits) with 2× headroom.
|
| 51 |
-
// Total shared per SM: 32768 (tile) + 3072 (DSMEM float arrays) = ~35 KB —
|
| 52 |
-
// well under the 228 KB H200 limit.
|
| 53 |
-
//
|
| 54 |
-
// Expected speedup from TMA multicast input staging (T9/T11):
|
| 55 |
-
// - Without staging: 16 SMs × T × (input_bits GMEM reads per timestep)
|
| 56 |
-
// - With staging: 1 TMA DMA per timestep, shared reads from L1 thereafter
|
| 57 |
-
// - Theoretical DRAM bandwidth reduction: ~16× on input reads
|
| 58 |
-
// - Wall-clock reduction estimate: -20 to -40 ms from reduced input fetch latency
|
| 59 |
-
// The tile is an unsigned char array of this length, i.e. one byte per input
// bit. Larger inputs skip the tile and are read straight from global memory.
// NOTE(review): the staging code in the kernel body issues cg::memcpy_async
// on cg::this_thread_block(), so every block copies its own tile — confirm
// whether the multicast / ~16× figures above still describe this code.
#define INPUT_BITS_MAX 32768u
|
| 60 |
-
|
| 61 |
-
extern "C" {
|
| 62 |
-
|
| 63 |
-
// Per-region table of device buffers handed to the fused kernel.
// Every field is a device address carried as a 64-bit integer; the kernel
// body casts each one to a typed pointer before use:
//   syn_bit                          -> const unsigned int*
//   syn_perm, boost, active_duty,
//   inhibition_threshold, anom_out   -> float*
//   tm_syn_perm                      -> short*
//   inputs                           -> const unsigned char*
//   cols_out                         -> unsigned char*
//   everything else                  -> unsigned int*
// The Rust launcher fills a #[repr(C)] struct of the same name field by field,
// so field order and width here must stay in sync with it.
struct FusedPtrs {
|
| 64 |
-
unsigned long long syn_bit;
|
| 65 |
-
unsigned long long syn_perm;
|
| 66 |
-
unsigned long long boost;
|
| 67 |
-
unsigned long long active_duty;
|
| 68 |
-
unsigned long long inhibition_threshold;
|
| 69 |
-
unsigned long long seg_cell_id;
|
| 70 |
-
unsigned long long seg_syn_count;
|
| 71 |
-
unsigned long long syn_presyn;
|
| 72 |
-
unsigned long long tm_syn_perm;
|
| 73 |
-
unsigned long long cell_seg_count;
|
| 74 |
-
// cell_active_{a,b} and cell_winner_{a,b} are ping-pong pairs: the kernel
// treats _a as "current" on even timesteps and _b on odd ones.
unsigned long long cell_active_a;
|
| 75 |
-
unsigned long long cell_active_b;
|
| 76 |
-
unsigned long long cell_winner_a;
|
| 77 |
-
unsigned long long cell_winner_b;
|
| 78 |
-
unsigned long long inputs;
|
| 79 |
-
unsigned long long cols_out;
|
| 80 |
-
unsigned long long anom_out;
|
| 81 |
-
// Only ever forwarded to fused_grid_barrier, which ignores it; kept so the
// struct layout does not change.
unsigned long long barrier_counters;
|
| 82 |
-
unsigned long long step_scratch;
|
| 83 |
-
};
|
| 84 |
-
|
| 85 |
-
// Scalar parameters for one fused launch. The Rust launcher builds a single
// instance (from region 0 in the batched path) and passes it to the kernel,
// so the same values apply to every region in the grid. Field order and types
// must match the host-side struct of the same name.
struct FusedConfig {
|
| 86 |
-
// SP constants
|
| 87 |
-
unsigned int input_bits;
|
| 88 |
-
unsigned int n_columns;
|
| 89 |
-
unsigned int synapses_per_col;
|
| 90 |
-
float conn_thr;
|
| 91 |
-
float sp_inc;
|
| 92 |
-
float sp_dec;
|
| 93 |
-
float sparsity_target;
|
| 94 |
-
// EMA weight: active_duty = (1 - duty_alpha) * active_duty + duty_alpha * sample.
float duty_alpha;
|
| 95 |
-
// Step size for the per-column inhibition threshold update.
float thr_adapt_rate;
|
| 96 |
-
// TM constants
|
| 97 |
-
unsigned int cells_per_column;
|
| 98 |
-
unsigned int n_cells;
|
| 99 |
-
unsigned int bits_words;
|
| 100 |
-
unsigned int max_segments_per_cell;
|
| 101 |
-
unsigned int synapses_per_segment;
|
| 102 |
-
unsigned int activation_threshold;
|
| 103 |
-
unsigned int learning_threshold;
|
| 104 |
-
unsigned int max_new_synapses;
|
| 105 |
-
// The *_i16 values are carried as int but applied to the short (16-bit)
// TM permanences, which the kernel clamps to [0, 32767].
int conn_thr_i16;
|
| 106 |
-
int perm_inc_i16;
|
| 107 |
-
int perm_dec_i16;
|
| 108 |
-
int predicted_seg_dec_i16;
|
| 109 |
-
int initial_perm_i16;
|
| 110 |
-
// Loop constants
|
| 111 |
-
// Number of timesteps processed by one launch.
unsigned int T;
|
| 112 |
-
// Non-zero enables the SP and TM permanence updates.
unsigned int learn;
|
| 113 |
-
unsigned int iter_seed;
|
| 114 |
-
// Forwarded to fused_grid_barrier, which no longer reads it.
unsigned int cooperative_grid_sync;
|
| 115 |
-
};
|
| 116 |
-
|
| 117 |
-
// Hardware cluster barrier using Hopper sm_90a cooperative_groups::this_cluster().sync().
|
| 118 |
-
// Replaces the former software Decoupled Look-Back (DLB) atomic-spin barrier.
|
| 119 |
-
//
|
| 120 |
-
// cluster::sync() is a single PTX instruction (barrier.cluster) that resolves
|
| 121 |
-
// in ~10-40 ns inside the cluster, with no device-level serialization.
|
| 122 |
-
// Multiple clusters (one per HTM region) run fully concurrently — bounded
|
| 123 |
-
// only by SM count (8 clusters × 16 SMs = 128 ≤ 132 on H200).
|
| 124 |
-
//
|
| 125 |
-
// The flags / expected / phase / cooperative_grid_sync parameters are kept
|
| 126 |
-
// in the signature for call-site compatibility but are unused.
|
| 127 |
-
__device__ static inline void fused_grid_barrier(cg::grid_group grid,
|
| 128 |
-
unsigned int * /* flags — unused */,
|
| 129 |
-
unsigned int /* expected — unused */,
|
| 130 |
-
unsigned int /* phase — unused */,
|
| 131 |
-
unsigned int /* cooperative_grid_sync — unused */) {
|
| 132 |
-
#if __CUDA_ARCH__ >= 900
|
| 133 |
-
// Hopper+ : hardware cluster barrier (~10-40 ns)
|
| 134 |
-
auto cluster = cg::this_cluster();
|
| 135 |
-
cluster.sync();
|
| 136 |
-
#else
|
| 137 |
-
// Pre-Hopper (sm_80, sm_86, sm_89): grid-level cooperative sync.
|
| 138 |
-
// Requires cooperative kernel launch. ~us-ms range, adequate for HTM
|
| 139 |
-
// workload (kernel launch frequency is low).
|
| 140 |
-
grid.sync();
|
| 141 |
-
#endif
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
// Adds `v` across the 32 lanes of a warp with a shuffle-down tree
// (offsets 16, 8, 4, 2, 1). All 32 lanes must call this together, since the
// shuffle uses the full 0xffffffff mask. Lane 0 returns the warp-wide total;
// the other lanes return partial sums, so callers that need the total on
// every lane must broadcast lane 0's result afterwards.
__device__ static inline unsigned int warp_sum_u32(unsigned int v) {
    unsigned int acc = v;
    int delta = 16;
    while (delta > 0) {
        acc += __shfl_down_sync(0xffffffffu, acc, delta);
        delta >>= 1;
    }
    return acc;
}
|
| 150 |
-
|
| 151 |
-
// Core kernel body — works for both single-region and batched launches.
|
| 152 |
-
// Single-region: caller passes the one FusedPtrs struct.
|
| 153 |
-
// Batched: each block reads its region's FusedPtrs via blockIdx.y before
|
| 154 |
-
// calling this. State is independent per region (each region owns its own
|
| 155 |
-
// GPU buffers); grid.sync() is the only cross-block primitive and it
|
| 156 |
-
// spans ALL blocks in the grid (harmless over-sync across regions).
|
| 157 |
-
__device__ static inline
|
| 158 |
-
void htm_fused_step_body(const FusedPtrs& P, const FusedConfig& cfg) {
|
| 159 |
-
cg::grid_group grid = cg::this_grid();
|
| 160 |
-
// Cast pointers.
|
| 161 |
-
const unsigned int * __restrict__ syn_bit = (const unsigned int*)P.syn_bit;
|
| 162 |
-
float * __restrict__ syn_perm = (float*)P.syn_perm;
|
| 163 |
-
float * __restrict__ boost = (float*)P.boost;
|
| 164 |
-
float * __restrict__ active_duty = (float*)P.active_duty;
|
| 165 |
-
float * __restrict__ inhibition_threshold = (float*)P.inhibition_threshold;
|
| 166 |
-
unsigned int * __restrict__ seg_cell_id = (unsigned int*)P.seg_cell_id;
|
| 167 |
-
unsigned int * __restrict__ seg_syn_count = (unsigned int*)P.seg_syn_count;
|
| 168 |
-
unsigned int * __restrict__ syn_presyn = (unsigned int*)P.syn_presyn;
|
| 169 |
-
short * __restrict__ tm_syn_perm = (short*)P.tm_syn_perm;
|
| 170 |
-
unsigned int * __restrict__ cell_seg_count = (unsigned int*)P.cell_seg_count;
|
| 171 |
-
unsigned int * __restrict__ cell_active_a = (unsigned int*)P.cell_active_a;
|
| 172 |
-
unsigned int * __restrict__ cell_active_b = (unsigned int*)P.cell_active_b;
|
| 173 |
-
unsigned int * __restrict__ cell_winner_a = (unsigned int*)P.cell_winner_a;
|
| 174 |
-
unsigned int * __restrict__ cell_winner_b = (unsigned int*)P.cell_winner_b;
|
| 175 |
-
const unsigned char * __restrict__ inputs = (const unsigned char*)P.inputs;
|
| 176 |
-
unsigned char * __restrict__ cols_out = (unsigned char*)P.cols_out;
|
| 177 |
-
float * __restrict__ anom_out = (float*)P.anom_out;
|
| 178 |
-
unsigned int * __restrict__ barrier_counters = (unsigned int*)P.barrier_counters;
|
| 179 |
-
unsigned int * __restrict__ step_scratch = (unsigned int*)P.step_scratch;
|
| 180 |
-
|
| 181 |
-
const unsigned int tid = threadIdx.x;
|
| 182 |
-
const unsigned int lane = tid & 31u;
|
| 183 |
-
const unsigned int warp = tid >> 5;
|
| 184 |
-
const unsigned int warps_per_block = blockDim.x >> 5;
|
| 185 |
-
const unsigned int gwarp = blockIdx.x * warps_per_block + warp;
|
| 186 |
-
const unsigned int n_warps = gridDim.x * warps_per_block;
|
| 187 |
-
|
| 188 |
-
const unsigned int n_cols = cfg.n_columns;
|
| 189 |
-
const unsigned int col_lo = (gwarp * n_cols) / n_warps;
|
| 190 |
-
const unsigned int col_hi = ((gwarp + 1) * n_cols) / n_warps;
|
| 191 |
-
|
| 192 |
-
unsigned int phase = 0u;
|
| 193 |
-
|
| 194 |
-
// =========================================================
|
| 195 |
-
// DSMEM: Cluster-distributed shared memory for hot per-column
|
| 196 |
-
// state (inhibition_threshold, boost, active_duty).
|
| 197 |
-
//
|
| 198 |
-
// On Hopper (sm_90+): Each block in the cluster owns a contiguous
|
| 199 |
-
// slice of columns in its own __shared__ arrays. Any block can
|
| 200 |
-
// peer-read another block's slice via cluster.map_shared_rank().
|
| 201 |
-
//
|
| 202 |
-
// On Ampere (sm_86) and other pre-Hopper: No cluster support.
|
| 203 |
-
// Read/write directly from/to global memory (inhibition_threshold,
|
| 204 |
-
// boost, active_duty device pointers). Slightly higher latency but
|
| 205 |
-
// functionally correct.
|
| 206 |
-
// =========================================================
|
| 207 |
-
|
| 208 |
-
#if __CUDA_ARCH__ >= 900
|
| 209 |
-
// Hopper+ cluster path
|
| 210 |
-
auto cluster = cg::this_cluster();
|
| 211 |
-
const unsigned int cluster_block_rank = cluster.block_rank(); // 0..cluster_size-1
|
| 212 |
-
const unsigned int cluster_sz = cluster.num_blocks(); // == gridDim.x (≤16)
|
| 213 |
-
#else
|
| 214 |
-
// Pre-Hopper: no cluster, each block is independent.
|
| 215 |
-
const unsigned int cluster_block_rank = blockIdx.x;
|
| 216 |
-
const unsigned int cluster_sz = gridDim.x;
|
| 217 |
-
#endif
|
| 218 |
-
|
| 219 |
-
// Partition n_cols evenly across cluster blocks.
|
| 220 |
-
// Each block owns cols_per_block columns starting at my_col_start.
|
| 221 |
-
const unsigned int cols_per_block =
|
| 222 |
-
(n_cols + cluster_sz - 1u) / cluster_sz; // ceil div
|
| 223 |
-
const unsigned int my_col_start =
|
| 224 |
-
cluster_block_rank * cols_per_block;
|
| 225 |
-
const unsigned int my_col_end =
|
| 226 |
-
(my_col_start + cols_per_block < n_cols)
|
| 227 |
-
? (my_col_start + cols_per_block) : n_cols; // clamp
|
| 228 |
-
|
| 229 |
-
#if __CUDA_ARCH__ >= 900
|
| 230 |
-
// Cluster-distributed shared memory arrays.
|
| 231 |
-
// Each block holds at most COLS_PER_CLUSTER_BLOCK_MAX floats per array.
|
| 232 |
-
// Peer blocks address into each other's smem via map_shared_rank.
|
| 233 |
-
__shared__ float s_inhib_thr [COLS_PER_CLUSTER_BLOCK_MAX];
|
| 234 |
-
__shared__ float s_boost [COLS_PER_CLUSTER_BLOCK_MAX];
|
| 235 |
-
__shared__ float s_active_duty[COLS_PER_CLUSTER_BLOCK_MAX];
|
| 236 |
-
#endif
|
| 237 |
-
|
| 238 |
-
// TMA multicast input staging tile (T9) — HOPPER ONLY.
|
| 239 |
-
//
|
| 240 |
-
// On Hopper: cg::memcpy_async with cluster scope multicasts input to all
|
| 241 |
-
// 16 SMs, reducing DRAM traffic by ~16×.
|
| 242 |
-
// On Ampere: 32 KB smem allocation exceeds per-block budget when
|
| 243 |
-
// cooperatively launched (48 KB total, registers eat the rest). Skip the
|
| 244 |
-
// tile entirely — Stage A reads from GMEM directly (original path).
|
| 245 |
-
#if __CUDA_ARCH__ >= 900
|
| 246 |
-
__shared__ __align__(16) unsigned char s_input_tile[INPUT_BITS_MAX];
|
| 247 |
-
#endif
|
| 248 |
-
|
| 249 |
-
#if __CUDA_ARCH__ >= 900
|
| 250 |
-
// Initial GMEM → smem load (reads state from previous forward call).
|
| 251 |
-
// Each block loads only its own slice; tid strides across the slice.
|
| 252 |
-
for (unsigned int c = my_col_start + tid; c < my_col_end; c += blockDim.x) {
|
| 253 |
-
const unsigned int off = c - my_col_start;
|
| 254 |
-
s_inhib_thr [off] = inhibition_threshold[c];
|
| 255 |
-
s_boost [off] = boost[c];
|
| 256 |
-
s_active_duty[off] = active_duty[c];
|
| 257 |
-
}
|
| 258 |
-
|
| 259 |
-
// All blocks in the cluster must finish loading before any block
|
| 260 |
-
// starts reading peer smem inside the T-loop.
|
| 261 |
-
cluster.sync();
|
| 262 |
-
#else
|
| 263 |
-
// Pre-Hopper: no smem caching needed — reads go directly to GMEM.
|
| 264 |
-
// Grid sync ensures all blocks have completed Phase 0 init before T-loop.
|
| 265 |
-
grid.sync();
|
| 266 |
-
#endif
|
| 267 |
-
|
| 268 |
-
const unsigned int S = cfg.synapses_per_col;
|
| 269 |
-
const unsigned int cpc = cfg.cells_per_column;
|
| 270 |
-
const unsigned int SPS = cfg.synapses_per_segment;
|
| 271 |
-
const unsigned int MSC = cfg.max_segments_per_cell;
|
| 272 |
-
|
| 273 |
-
// Main timestep loop.
|
| 274 |
-
for (unsigned int t = 0u; t < cfg.T; t++) {
|
| 275 |
-
const unsigned int inp_off = t * cfg.input_bits;
|
| 276 |
-
const unsigned int col_base_out = t * n_cols;
|
| 277 |
-
|
| 278 |
-
unsigned int * curr_active = (t & 1u) ? cell_active_b : cell_active_a;
|
| 279 |
-
unsigned int * prev_active = (t & 1u) ? cell_active_a : cell_active_b;
|
| 280 |
-
unsigned int * curr_winner = (t & 1u) ? cell_winner_b : cell_winner_a;
|
| 281 |
-
unsigned int * prev_winner = (t & 1u) ? cell_winner_a : cell_winner_b;
|
| 282 |
-
|
| 283 |
-
// ---- Phase 0: clear curr bitsets for my cell range ----
|
| 284 |
-
const unsigned int my_cell_lo = col_lo * cpc;
|
| 285 |
-
const unsigned int my_cell_hi = col_hi * cpc;
|
| 286 |
-
if (cpc == 32u) {
|
| 287 |
-
// Fast path: one word per column.
|
| 288 |
-
for (unsigned int c = col_lo + lane; c < col_hi; c += 32u) {
|
| 289 |
-
curr_active[c] = 0u;
|
| 290 |
-
curr_winner[c] = 0u;
|
| 291 |
-
}
|
| 292 |
-
} else {
|
| 293 |
-
for (unsigned int cell = my_cell_lo + lane; cell < my_cell_hi; cell += 32u) {
|
| 294 |
-
unsigned int w = cell >> 5;
|
| 295 |
-
unsigned int m = 1u << (cell & 31u);
|
| 296 |
-
atomicAnd(&curr_active[w], ~m);
|
| 297 |
-
atomicAnd(&curr_winner[w], ~m);
|
| 298 |
-
}
|
| 299 |
-
}
|
| 300 |
-
|
| 301 |
-
// Block 0, lane 0, warp 0 resets step-scratch counters.
|
| 302 |
-
if (blockIdx.x == 0u && tid == 0u) {
|
| 303 |
-
step_scratch[0] = 0u;
|
| 304 |
-
step_scratch[1] = 0u;
|
| 305 |
-
}
|
| 306 |
-
|
| 307 |
-
// ---- BARRIER 1 ----
|
| 308 |
-
// Fence: make the above clear-bitsets + scratch writes globally
|
| 309 |
-
// visible before peer blocks observe "barrier arrived".
|
| 310 |
-
__threadfence();
|
| 311 |
-
fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
|
| 312 |
-
|
| 313 |
-
// =========================================================
|
| 314 |
-
// T9: TMA MULTICAST INPUT STAGING
|
| 315 |
-
//
|
| 316 |
-
// Issue a single cluster-scope async DMA to broadcast this
|
| 317 |
-
// timestep's input slice into s_input_tile across all 16 SMs
|
| 318 |
-
// in the cluster simultaneously. On Hopper sm_90a,
|
| 319 |
-
// cg::memcpy_async with cluster scope maps to the TMA
|
| 320 |
-
// hardware unit (cp.async.bulk.tensor multicast), reducing
|
| 321 |
-
// DRAM input traffic by ~16× vs each block fetching its own
|
| 322 |
-
// copy from GMEM.
|
| 323 |
-
//
|
| 324 |
-
// The staging is gated on cfg.input_bits <= INPUT_BITS_MAX.
|
| 325 |
-
// If the tile is too small (custom large input_bits), we fall
|
| 326 |
-
// back to per-thread GMEM reads in Stage A (identical to the
|
| 327 |
-
// original path; use_input_tile==false).
|
| 328 |
-
//
|
| 329 |
-
// Ordering: BARRIER 1 completes before we issue the DMA.
|
| 330 |
-
// The DMA completes before Stage A reads s_input_tile.
|
| 331 |
-
// =========================================================
|
| 332 |
-
#if __CUDA_ARCH__ >= 900
|
| 333 |
-
const bool use_input_tile = (cfg.input_bits <= INPUT_BITS_MAX);
|
| 334 |
-
if (use_input_tile) {
|
| 335 |
-
auto tb = cg::this_thread_block();
|
| 336 |
-
cg::memcpy_async(tb, s_input_tile,
|
| 337 |
-
inputs + inp_off,
|
| 338 |
-
cfg.input_bits);
|
| 339 |
-
cg::wait(tb);
|
| 340 |
-
cluster.sync();
|
| 341 |
-
}
|
| 342 |
-
#else
|
| 343 |
-
const bool use_input_tile = false;
|
| 344 |
-
#endif
|
| 345 |
-
|
| 346 |
-
// =========================================================
|
| 347 |
-
// STAGE A: Spatial Pooler
|
| 348 |
-
//
|
| 349 |
-
// Hot per-column state (boost, inhibition_threshold,
|
| 350 |
-
// active_duty) is served from cluster DSMEM rather than
|
| 351 |
-
// GMEM for each of the T timesteps. GMEM is written on
|
| 352 |
-
// update so state persists across forward calls.
|
| 353 |
-
// =========================================================
|
| 354 |
-
for (unsigned int c = col_lo; c < col_hi; c++) {
|
| 355 |
-
unsigned int base = c * S;
|
| 356 |
-
unsigned int local = 0u;
|
| 357 |
-
for (unsigned int s = lane; s < S; s += 32u) {
|
| 358 |
-
unsigned int b = syn_bit[base + s];
|
| 359 |
-
float p = syn_perm[base + s];
|
| 360 |
-
// T9: read from cluster-broadcast tile when available;
|
| 361 |
-
// fall back to direct GMEM when input_bits > INPUT_BITS_MAX.
|
| 362 |
-
#if __CUDA_ARCH__ >= 900
|
| 363 |
-
unsigned int inp_byte = use_input_tile
|
| 364 |
-
? (unsigned int)s_input_tile[b]
|
| 365 |
-
: (unsigned int)inputs[inp_off + b];
|
| 366 |
-
#else
|
| 367 |
-
unsigned int inp_byte = (unsigned int)inputs[inp_off + b];
|
| 368 |
-
#endif
|
| 369 |
-
unsigned int hit = ((inp_byte != 0u) && (p >= cfg.conn_thr)) ? 1u : 0u;
|
| 370 |
-
local += hit;
|
| 371 |
-
}
|
| 372 |
-
unsigned int overlap = warp_sum_u32(local);
|
| 373 |
-
overlap = __shfl_sync(0xffffffffu, overlap, 0);
|
| 374 |
-
|
| 375 |
-
// Read boost + threshold for column c.
|
| 376 |
-
#if __CUDA_ARCH__ >= 900
|
| 377 |
-
// Hopper: read from cluster-distributed shared memory.
|
| 378 |
-
const unsigned int owner_block = c / cols_per_block;
|
| 379 |
-
const unsigned int owner_offset = c - owner_block * cols_per_block;
|
| 380 |
-
float boost_val = cluster.map_shared_rank(s_boost, owner_block)[owner_offset];
|
| 381 |
-
float thr = cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset];
|
| 382 |
-
#else
|
| 383 |
-
// Pre-Hopper: read directly from global memory.
|
| 384 |
-
float boost_val = boost[c];
|
| 385 |
-
float thr = inhibition_threshold[c];
|
| 386 |
-
#endif
|
| 387 |
-
|
| 388 |
-
float boosted = (float)overlap * boost_val;
|
| 389 |
-
unsigned int is_active = (boosted > thr) ? 1u : 0u;
|
| 390 |
-
|
| 391 |
-
if (lane == 0) {
|
| 392 |
-
cols_out[col_base_out + c] = (unsigned char)is_active;
|
| 393 |
-
if (is_active) {
|
| 394 |
-
atomicAdd(&step_scratch[0], 1u);
|
| 395 |
-
}
|
| 396 |
-
}
|
| 397 |
-
|
| 398 |
-
// SP learn (Hebbian) on active columns.
|
| 399 |
-
// T9: use tile for input reads here too.
|
| 400 |
-
if (cfg.learn && is_active) {
|
| 401 |
-
for (unsigned int s = lane; s < S; s += 32u) {
|
| 402 |
-
unsigned int b = syn_bit[base + s];
|
| 403 |
-
float p = syn_perm[base + s];
|
| 404 |
-
#if __CUDA_ARCH__ >= 900
|
| 405 |
-
unsigned int inp_byte = use_input_tile
|
| 406 |
-
? (unsigned int)s_input_tile[b]
|
| 407 |
-
: (unsigned int)inputs[inp_off + b];
|
| 408 |
-
#else
|
| 409 |
-
unsigned int inp_byte = (unsigned int)inputs[inp_off + b];
|
| 410 |
-
#endif
|
| 411 |
-
if (inp_byte != 0u) {
|
| 412 |
-
p += cfg.sp_inc;
|
| 413 |
-
if (p > 1.0f) p = 1.0f;
|
| 414 |
-
} else {
|
| 415 |
-
p -= cfg.sp_dec;
|
| 416 |
-
if (p < 0.0f) p = 0.0f;
|
| 417 |
-
}
|
| 418 |
-
syn_perm[base + s] = p;
|
| 419 |
-
}
|
| 420 |
-
}
|
| 421 |
-
|
| 422 |
-
// active_duty EMA + threshold adaptation.
|
| 423 |
-
// Writes go to both DSMEM (hot path, Hopper only) and GMEM (persistence).
|
| 424 |
-
if (lane == 0) {
|
| 425 |
-
#if __CUDA_ARCH__ >= 900
|
| 426 |
-
float ad = cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset];
|
| 427 |
-
#else
|
| 428 |
-
float ad = active_duty[c];
|
| 429 |
-
#endif
|
| 430 |
-
float sample = is_active ? 1.0f : 0.0f;
|
| 431 |
-
ad = (1.0f - cfg.duty_alpha) * ad + cfg.duty_alpha * sample;
|
| 432 |
-
|
| 433 |
-
#if __CUDA_ARCH__ >= 900
|
| 434 |
-
// Writeback: peer smem (for next timestep read) + GMEM (persistence).
|
| 435 |
-
cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] = ad;
|
| 436 |
-
#endif
|
| 437 |
-
active_duty[c] = ad;
|
| 438 |
-
|
| 439 |
-
// Threshold steers toward target sparsity.
|
| 440 |
-
float err = ad - cfg.sparsity_target;
|
| 441 |
-
float new_thr = thr + cfg.thr_adapt_rate * err * 100.0f;
|
| 442 |
-
if (new_thr < 0.1f) new_thr = 0.1f;
|
| 443 |
-
if (new_thr > 1000.0f) new_thr = 1000.0f;
|
| 444 |
-
|
| 445 |
-
#if __CUDA_ARCH__ >= 900
|
| 446 |
-
// Writeback: peer smem (for next timestep read) + GMEM (persistence).
|
| 447 |
-
cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] = new_thr;
|
| 448 |
-
#endif
|
| 449 |
-
inhibition_threshold[c] = new_thr;
|
| 450 |
-
}
|
| 451 |
-
}
|
| 452 |
-
|
| 453 |
-
// ---- DSMEM WRITEBACK SYNC: peer-smem writes must be visible cluster-wide ----
|
| 454 |
-
//
|
| 455 |
-
// On Hopper: cluster.sync() ensures all peer smem writes from this
|
| 456 |
-
// timestep are visible to all blocks before Stage B / next t.
|
| 457 |
-
// On pre-Hopper: no smem peer writes occur (all state in GMEM),
|
| 458 |
-
// so no extra sync needed here — the grid barrier below suffices.
|
| 459 |
-
#if __CUDA_ARCH__ >= 900
|
| 460 |
-
cluster.sync();
|
| 461 |
-
#endif
|
| 462 |
-
|
| 463 |
-
// ---- BARRIER 2: SP active_mask must be visible before TM reads ----
|
| 464 |
-
// Fence: flush cols_out + active_duty + inhibition_threshold + step_scratch
|
| 465 |
-
// writes to global memory before peers advance past this barrier.
|
| 466 |
-
__threadfence();
|
| 467 |
-
fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
|
| 468 |
-
|
| 469 |
-
// =========================================================
|
| 470 |
-
// STAGE B: Temporal Memory
|
| 471 |
-
// =========================================================
|
| 472 |
-
for (unsigned int c = col_lo; c < col_hi; c++) {
|
| 473 |
-
unsigned int col_active = cols_out[col_base_out + c];
|
| 474 |
-
if (col_active == 0u) continue;
|
| 475 |
-
|
| 476 |
-
unsigned int base_cell = c * cpc;
|
| 477 |
-
unsigned int any_predicted = 0u;
|
| 478 |
-
unsigned int best_seg_id_for_grow = 0xFFFFFFFFu;
|
| 479 |
-
unsigned int best_pot_count = 0u;
|
| 480 |
-
|
| 481 |
-
for (unsigned int k = 0u; k < cpc; k++) {
|
| 482 |
-
unsigned int cell = base_cell + k;
|
| 483 |
-
unsigned int n_segs_here = cell_seg_count[cell];
|
| 484 |
-
if (n_segs_here > MSC) n_segs_here = MSC;
|
| 485 |
-
if (n_segs_here == 0u) continue;
|
| 486 |
-
|
| 487 |
-
unsigned int seg_base_id = cell * MSC;
|
| 488 |
-
unsigned int cell_is_predictive = 0u;
|
| 489 |
-
|
| 490 |
-
for (unsigned int ls = 0u; ls < n_segs_here; ls++) {
|
| 491 |
-
unsigned int seg = seg_base_id + ls;
|
| 492 |
-
unsigned int n_syn = seg_syn_count[seg];
|
| 493 |
-
if (n_syn == 0u) continue;
|
| 494 |
-
unsigned int syn_base = seg * SPS;
|
| 495 |
-
|
| 496 |
-
unsigned int l_conn = 0u;
|
| 497 |
-
unsigned int l_pot = 0u;
|
| 498 |
-
for (unsigned int s = lane; s < n_syn; s += 32u) {
|
| 499 |
-
unsigned int presyn = syn_presyn[syn_base + s];
|
| 500 |
-
unsigned int w = prev_active[presyn >> 5];
|
| 501 |
-
unsigned int bit = (w >> (presyn & 31u)) & 1u;
|
| 502 |
-
if (bit) {
|
| 503 |
-
l_pot += 1u;
|
| 504 |
-
int p = (int)tm_syn_perm[syn_base + s];
|
| 505 |
-
if (p >= cfg.conn_thr_i16) l_conn += 1u;
|
| 506 |
-
}
|
| 507 |
-
}
|
| 508 |
-
unsigned int tot_conn = warp_sum_u32(l_conn);
|
| 509 |
-
unsigned int tot_pot = warp_sum_u32(l_pot);
|
| 510 |
-
tot_conn = __shfl_sync(0xffffffffu, tot_conn, 0);
|
| 511 |
-
tot_pot = __shfl_sync(0xffffffffu, tot_pot, 0);
|
| 512 |
-
|
| 513 |
-
if (tot_conn >= cfg.activation_threshold) cell_is_predictive = 1u;
|
| 514 |
-
if (tot_pot >= cfg.learning_threshold && tot_pot > best_pot_count) {
|
| 515 |
-
best_pot_count = tot_pot;
|
| 516 |
-
best_seg_id_for_grow = seg;
|
| 517 |
-
}
|
| 518 |
-
|
| 519 |
-
// Reinforce predicted-and-correct segment.
|
| 520 |
-
if (cfg.learn && tot_conn >= cfg.activation_threshold) {
|
| 521 |
-
for (unsigned int s = lane; s < n_syn; s += 32u) {
|
| 522 |
-
unsigned int presyn = syn_presyn[syn_base + s];
|
| 523 |
-
unsigned int w = prev_active[presyn >> 5];
|
| 524 |
-
unsigned int bit = (w >> (presyn & 31u)) & 1u;
|
| 525 |
-
int p = (int)tm_syn_perm[syn_base + s];
|
| 526 |
-
if (bit) {
|
| 527 |
-
int np = p + cfg.perm_inc_i16;
|
| 528 |
-
if (np > 32767) np = 32767;
|
| 529 |
-
tm_syn_perm[syn_base + s] = (short)np;
|
| 530 |
-
} else {
|
| 531 |
-
int np = p - cfg.perm_dec_i16;
|
| 532 |
-
if (np < 0) np = 0;
|
| 533 |
-
tm_syn_perm[syn_base + s] = (short)np;
|
| 534 |
-
}
|
| 535 |
-
}
|
| 536 |
-
}
|
| 537 |
-
}
|
| 538 |
-
|
| 539 |
-
if (cell_is_predictive) {
|
| 540 |
-
any_predicted = 1u;
|
| 541 |
-
if (lane == 0) {
|
| 542 |
-
unsigned int w = cell >> 5;
|
| 543 |
-
unsigned int m = 1u << (cell & 31u);
|
| 544 |
-
atomicOr(&curr_active[w], m);
|
| 545 |
-
atomicOr(&curr_winner[w], m);
|
| 546 |
-
}
|
| 547 |
-
}
|
| 548 |
-
}
|
| 549 |
-
|
| 550 |
-
// BURST if no predicted.
|
| 551 |
-
if (!any_predicted) {
|
| 552 |
-
if (lane == 0) {
|
| 553 |
-
for (unsigned int k = 0u; k < cpc; k++) {
|
| 554 |
-
unsigned int cell = base_cell + k;
|
| 555 |
-
unsigned int w = cell >> 5;
|
| 556 |
-
unsigned int m = 1u << (cell & 31u);
|
| 557 |
-
atomicOr(&curr_active[w], m);
|
| 558 |
-
}
|
| 559 |
-
unsigned int win = base_cell;
|
| 560 |
-
unsigned int ww = win >> 5;
|
| 561 |
-
unsigned int wm = 1u << (win & 31u);
|
| 562 |
-
atomicOr(&curr_winner[ww], wm);
|
| 563 |
-
atomicAdd(&step_scratch[1], 1u);
|
| 564 |
-
}
|
| 565 |
-
|
| 566 |
-
if (cfg.learn) {
|
| 567 |
-
unsigned int target_seg;
|
| 568 |
-
unsigned int existing_syn;
|
| 569 |
-
if (best_seg_id_for_grow != 0xFFFFFFFFu) {
|
| 570 |
-
// Reuse best matching segment.
|
| 571 |
-
target_seg = best_seg_id_for_grow;
|
| 572 |
-
existing_syn = seg_syn_count[target_seg];
|
| 573 |
-
target_seg = __shfl_sync(0xffffffffu, target_seg, 0);
|
| 574 |
-
existing_syn = __shfl_sync(0xffffffffu, existing_syn, 0);
|
| 575 |
-
|
| 576 |
-
// Reinforce its existing synapses.
|
| 577 |
-
unsigned int syn_base = target_seg * SPS;
|
| 578 |
-
for (unsigned int s = lane; s < existing_syn; s += 32u) {
|
| 579 |
-
unsigned int presyn = syn_presyn[syn_base + s];
|
| 580 |
-
unsigned int w = prev_active[presyn >> 5];
|
| 581 |
-
unsigned int bit = (w >> (presyn & 31u)) & 1u;
|
| 582 |
-
int p = (int)tm_syn_perm[syn_base + s];
|
| 583 |
-
if (bit) {
|
| 584 |
-
int np = p + cfg.perm_inc_i16;
|
| 585 |
-
if (np > 32767) np = 32767;
|
| 586 |
-
tm_syn_perm[syn_base + s] = (short)np;
|
| 587 |
-
} else {
|
| 588 |
-
int np = p - cfg.perm_dec_i16;
|
| 589 |
-
if (np < 0) np = 0;
|
| 590 |
-
tm_syn_perm[syn_base + s] = (short)np;
|
| 591 |
-
}
|
| 592 |
-
}
|
| 593 |
-
} else {
|
| 594 |
-
// Allocate new segment on winner cell (cell 0 of col).
|
| 595 |
-
unsigned int new_seg = 0u;
|
| 596 |
-
if (lane == 0) {
|
| 597 |
-
unsigned int winner_cell = base_cell;
|
| 598 |
-
unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
|
| 599 |
-
if (slot >= MSC) slot = slot % MSC;
|
| 600 |
-
new_seg = winner_cell * MSC + slot;
|
| 601 |
-
seg_cell_id[new_seg] = winner_cell;
|
| 602 |
-
seg_syn_count[new_seg] = 0u;
|
| 603 |
-
}
|
| 604 |
-
target_seg = __shfl_sync(0xffffffffu, new_seg, 0);
|
| 605 |
-
existing_syn = 0u;
|
| 606 |
-
}
|
| 607 |
-
|
| 608 |
-
// Grow synapses to prev_winner cells — lane 0 serialized.
|
| 609 |
-
unsigned int room = (SPS > existing_syn) ? (SPS - existing_syn) : 0u;
|
| 610 |
-
unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;
|
| 611 |
-
if (lane == 0 && max_grow > 0u) {
|
| 612 |
-
unsigned int syn_base = target_seg * SPS;
|
| 613 |
-
unsigned int grown = 0u;
|
| 614 |
-
unsigned int start_off = (c * 2654435761u + cfg.iter_seed + t) % cfg.bits_words;
|
| 615 |
-
for (unsigned int w_off = 0u;
|
| 616 |
-
w_off < cfg.bits_words && grown < max_grow;
|
| 617 |
-
w_off++) {
|
| 618 |
-
unsigned int widx = (start_off + w_off) % cfg.bits_words;
|
| 619 |
-
unsigned int word = prev_winner[widx];
|
| 620 |
-
while (word != 0u && grown < max_grow) {
|
| 621 |
-
unsigned int bit_pos = __ffs(word) - 1u;
|
| 622 |
-
word &= ~(1u << bit_pos);
|
| 623 |
-
unsigned int cell_id = widx * 32u + bit_pos;
|
| 624 |
-
if (cell_id >= cfg.n_cells) continue;
|
| 625 |
-
bool exists = false;
|
| 626 |
-
for (unsigned int es = 0u; es < existing_syn + grown; es++) {
|
| 627 |
-
if (syn_presyn[syn_base + es] == cell_id) { exists = true; break; }
|
| 628 |
-
}
|
| 629 |
-
if (exists) continue;
|
| 630 |
-
unsigned int write_idx = existing_syn + grown;
|
| 631 |
-
if (write_idx >= SPS) break;
|
| 632 |
-
syn_presyn[syn_base + write_idx] = cell_id;
|
| 633 |
-
tm_syn_perm[syn_base + write_idx] = (short)cfg.initial_perm_i16;
|
| 634 |
-
grown++;
|
| 635 |
-
}
|
| 636 |
-
}
|
| 637 |
-
if (grown > 0u) {
|
| 638 |
-
seg_syn_count[target_seg] = existing_syn + grown;
|
| 639 |
-
}
|
| 640 |
-
}
|
| 641 |
-
}
|
| 642 |
-
}
|
| 643 |
-
}
|
| 644 |
-
|
| 645 |
-
// ---- BARRIER 3: TM writes complete before anomaly + next-step read ----
|
| 646 |
-
// Fence: flush curr_active/curr_winner bitsets + tm_syn_perm +
|
| 647 |
-
// seg_syn_count + syn_presyn before peers advance and consume them as
|
| 648 |
-
// prev_active/prev_winner at t+1.
|
| 649 |
-
__threadfence();
|
| 650 |
-
fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
|
| 651 |
-
|
| 652 |
-
// Write anomaly for step t.
|
| 653 |
-
if (blockIdx.x == 0u && tid == 0u) {
|
| 654 |
-
unsigned int total = step_scratch[0];
|
| 655 |
-
unsigned int bad = step_scratch[1];
|
| 656 |
-
float anom = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
|
| 657 |
-
anom_out[t] = anom;
|
| 658 |
-
}
|
| 659 |
-
}
|
| 660 |
-
}
|
| 661 |
-
|
| 662 |
-
// Single-region kernel (legacy call site): takes one FusedPtrs / FusedConfig
// pair by value and forwards both unchanged to htm_fused_step_body.
|
| 663 |
-
__global__ __launch_bounds__(256, 2)  // at most 256 threads per block; asks for >= 2 resident blocks per SM
|
| 664 |
-
void htm_fused_step(FusedPtrs P, FusedConfig cfg) {
|
| 665 |
-
htm_fused_step_body(P, cfg);  // all SP + TM work for every timestep happens in the shared body
|
| 666 |
-
}
|
| 667 |
-
|
| 668 |
-
// Batched kernel: one cooperative launch for B regions. grid.y = B,
|
| 669 |
-
// grid.x = per-region block count. Each block reads its region's
|
| 670 |
-
// FusedPtrs from the device array via blockIdx.y. A single FusedConfig is
// passed by value, so every region in the batch runs with the same settings.
|
| 671 |
-
__global__ __launch_bounds__(256, 2)  // at most 256 threads per block; asks for >= 2 resident blocks per SM
|
| 672 |
-
void htm_fused_step_batched(const FusedPtrs* __restrict__ P_arr, FusedConfig cfg) {
|
| 673 |
-
const FusedPtrs P = P_arr[blockIdx.y];  // local copy of this region's pointer table
|
| 674 |
-
htm_fused_step_body(P, cfg);  // same body as the single-region kernel
|
| 675 |
-
}
|
| 676 |
-
|
| 677 |
-
} // extern "C"
|
|
|
|
| 1 |
+
// Fused HTM megakernel — SP + TM, all T timesteps in a single launch.
|
| 2 |
+
//
|
| 3 |
+
// Design rationale:
|
| 4 |
+
// - Global top-K column selection requires cross-block synchronization at
|
| 5 |
+
// every timestep (grid.sync is unreliable on WSL2/sm_86 without rdc=true).
|
| 6 |
+
// - Replace with per-column threshold activation using local lateral
|
| 7 |
+
// inhibition: column c activates if overlap[c]*boost[c] > threshold[c].
|
| 8 |
+
// Threshold is a per-column running-EMA learned scalar that steers the
|
| 9 |
+
// column's long-run activation rate toward the global sparsity target.
|
| 10 |
+
// - This is biologically grounded (GABAergic local inhibition) and supported
|
| 11 |
+
// by HTM theory (duty-cycle boost already drives this loop; we just
|
| 12 |
+
// change which lever the EMA pulls).
|
| 13 |
+
//
|
| 14 |
+
// Launch shape:
|
| 15 |
+
// grid = min(device SM count, 16) // hard cap — see below
|
| 16 |
+
// block = 1024 threads = 32 warps
|
| 17 |
+
// Each warp of 32 owns a contiguous column slice (n_columns / total_warps).
|
| 18 |
+
//
|
| 19 |
+
// Cross-block coherence:
|
| 20 |
+
// - Ping-pong buffers for cell_active/cell_winner: write _a at even t,
|
| 21 |
+
// read _b; reversed at odd t.
|
| 22 |
+
// - sm_90+ (unless HTM_DISABLE_CLUSTER is defined): hardware cluster barrier, cluster.sync().
|
| 23 |
+
// - Otherwise: cooperative launch + hardware whole-grid sync, grid.sync(). The
|
| 24 |
+
// former software spin barrier is no longer compiled in, so cooperative launch is required.
|
| 25 |
+
//
|
| 26 |
+
// 2026-04-16: grid_dim reduced from 28 to 16 after deadlock RCA. The previous
|
| 27 |
+
// cap of 28 relied on all blocks being concurrently resident on a 30-SM RTX
|
| 28 |
+
// 3060 Laptop. Under thermal throttling effective residency dropped to ~20-24,
|
| 29 |
+
// leaving scheduled blocks spinning on the software grid barrier waiting for
|
| 30 |
+
// peer blocks that would never run. 16 blocks is below any realistic residency
|
| 31 |
+
// floor and preserves enough warp parallelism (16*32 = 512 warps) to saturate
|
| 32 |
+
// memory bandwidth on the spatial-pooler stage.
|
| 33 |
+
//
|
| 34 |
+
// Kernel signature uses struct-by-value for pointers and config to stay
|
| 35 |
+
// inside cudarc's launch-arg count limit.
|
| 36 |
+
|
| 37 |
+
#include <cooperative_groups.h>
|
| 38 |
+
#include <cooperative_groups/memcpy_async.h>
|
| 39 |
+
|
| 40 |
+
namespace cg = cooperative_groups;
|
| 41 |
+
|
| 42 |
+
// Maximum columns owned per cluster-block in DSMEM.
|
| 43 |
+
// Supports n_columns up to COLS_PER_CLUSTER_BLOCK_MAX * cluster_size.
|
| 44 |
+
// At cluster_size=16: supports up to 256*16=4096 columns.
|
| 45 |
+
// Each array costs 256*4 = 1024 bytes; three arrays = 3072 bytes per SM —
|
| 46 |
+
// well under the 228 KB H200 shared-memory cap.
|
| 47 |
+
#define COLS_PER_CLUSTER_BLOCK_MAX 256u
|
| 48 |
+
|
| 49 |
+
// Maximum input_bits supported by the TMA-multicast staging tile.
|
| 50 |
+
// At 32 KB this covers the production SDR width (16384 bits) with 2× headroom.
|
| 51 |
+
// Total shared per SM: 32768 (tile) + 3072 (DSMEM float arrays) = ~35 KB —
|
| 52 |
+
// well under the 228 KB H200 limit.
|
| 53 |
+
//
|
| 54 |
+
// Expected speedup from TMA multicast input staging (T9/T11):
|
| 55 |
+
// - Without staging: 16 SMs × T × (input_bits GMEM reads per timestep)
|
| 56 |
+
// - With staging: 1 TMA DMA per timestep, shared reads from L1 thereafter
|
| 57 |
+
// - Theoretical DRAM bandwidth reduction: ~16× on input reads
|
| 58 |
+
// - Wall-clock reduction estimate: -20 to -40 ms from reduced input fetch latency
|
| 59 |
+
#define INPUT_BITS_MAX 32768u
|
| 60 |
+
|
| 61 |
+
extern "C" {
|
| 62 |
+
|
| 63 |
+
// Device buffer addresses for one HTM region, carried as 64-bit integers and
// cast back to typed pointers at the top of htm_fused_step_body. The pointee
// type named on each field is the type used by that cast.
struct FusedPtrs {
|
| 64 |
+
unsigned long long syn_bit;  // const unsigned int*: SP synapse -> input index, read at [col * synapses_per_col + s]
|
| 65 |
+
unsigned long long syn_perm;  // float*: SP synapse permanence, same indexing as syn_bit
|
| 66 |
+
unsigned long long boost;  // float*: per-column factor multiplied into the overlap count
|
| 67 |
+
unsigned long long active_duty;  // float*: per-column running average of activation
|
| 68 |
+
unsigned long long inhibition_threshold;  // float*: per-column threshold the boosted overlap must exceed
|
| 69 |
+
unsigned long long seg_cell_id;  // unsigned int*: per-segment owning cell id
|
| 70 |
+
unsigned long long seg_syn_count;  // unsigned int*: synapse count per segment, at [cell * max_segments_per_cell + ls]
|
| 71 |
+
unsigned long long syn_presyn;  // unsigned int*: presynaptic cell id, at [seg * synapses_per_segment + s]
|
| 72 |
+
unsigned long long tm_syn_perm;  // short*: TM synapse permanence (16-bit), same indexing as syn_presyn
|
| 73 |
+
unsigned long long cell_seg_count;  // unsigned int*: segments per cell; clamped to max_segments_per_cell on read
|
| 74 |
+
unsigned long long cell_active_a;  // unsigned int*: active-cell bitset (32 cells per word); written at even t, read at odd t
|
| 75 |
+
unsigned long long cell_active_b;  // unsigned int*: active-cell bitset; written at odd t, read at even t
|
| 76 |
+
unsigned long long cell_winner_a;  // unsigned int*: winner-cell bitset; written at even t, read at odd t
|
| 77 |
+
unsigned long long cell_winner_b;  // unsigned int*: winner-cell bitset; written at odd t, read at even t
|
| 78 |
+
unsigned long long inputs;  // const unsigned char*: one byte per input bit at [t * input_bits + bit]; nonzero = on
|
| 79 |
+
unsigned long long cols_out;  // unsigned char*: column activity (0/1) at [t * n_columns + col]
|
| 80 |
+
unsigned long long anom_out;  // float*: one anomaly value per timestep
|
| 81 |
+
unsigned long long barrier_counters;  // unsigned int*: forwarded to fused_grid_barrier, which ignores it
|
| 82 |
+
unsigned long long step_scratch;  // unsigned int*: [0] active-column count, [1] second per-step counter; both zeroed each timestep
|
| 83 |
+
};
|
| 84 |
+
|
| 85 |
+
// Scalar settings for one launch, passed to the kernel by value. Field order
// is part of the launch ABI; do not reorder.
struct FusedConfig {
|
| 86 |
+
// SP constants
|
| 87 |
+
unsigned int input_bits;  // bytes of `inputs` per timestep (row stride t * input_bits)
|
| 88 |
+
unsigned int n_columns;  // number of SP columns; also the row stride of cols_out
|
| 89 |
+
unsigned int synapses_per_col;  // SP synapses per column (row stride of syn_bit / syn_perm)
|
| 90 |
+
float conn_thr;  // a synapse counts toward overlap when its permanence >= conn_thr
|
| 91 |
+
float sp_inc;  // permanence added when the input bit is on; result capped at 1.0
|
| 92 |
+
float sp_dec;  // permanence subtracted when the input bit is off; result floored at 0.0
|
| 93 |
+
float sparsity_target;  // value that active_duty is steered toward
|
| 94 |
+
float duty_alpha;  // EMA weight: ad = (1 - duty_alpha) * ad + duty_alpha * sample
|
| 95 |
+
float thr_adapt_rate;  // threshold step: thr += thr_adapt_rate * (ad - sparsity_target) * 100, clamped to [0.1, 1000]
|
| 96 |
+
// TM constants
|
| 97 |
+
unsigned int cells_per_column;  // cells per column; the value 32 takes the one-word-per-column clear path
|
| 98 |
+
unsigned int n_cells;  // upper bound on valid cell ids -- presumably n_columns * cells_per_column, TODO confirm
|
| 99 |
+
unsigned int bits_words;  // number of 32-bit words in a cell bitset
|
| 100 |
+
unsigned int max_segments_per_cell;  // segment slots per cell (stride of the per-segment arrays)
|
| 101 |
+
unsigned int synapses_per_segment;  // synapse slots per segment (stride of syn_presyn / tm_syn_perm)
|
| 102 |
+
unsigned int activation_threshold;  // connected active synapses needed for a segment to make its cell predictive
|
| 103 |
+
unsigned int learning_threshold;  // active potential synapses needed for a segment to be a best-match candidate
|
| 104 |
+
unsigned int max_new_synapses;  // cap on synapses grown in one learning step
|
| 105 |
+
int conn_thr_i16;  // TM connected threshold, compared against tm_syn_perm widened to int
|
| 106 |
+
int perm_inc_i16;  // TM permanence increment; result capped at 32767
|
| 107 |
+
int perm_dec_i16;  // TM permanence decrement; result floored at 0
|
| 108 |
+
int predicted_seg_dec_i16;  // not referenced in the kernel code visible in this section -- TODO confirm use
|
| 109 |
+
int initial_perm_i16;  // permanence given to newly grown TM synapses
|
| 110 |
+
// Loop constants
|
| 111 |
+
unsigned int T;  // number of timesteps processed inside one launch
|
| 112 |
+
unsigned int learn;  // nonzero enables SP and TM learning updates
|
| 113 |
+
unsigned int iter_seed;  // mixed into the start offset used when growing synapses
|
| 114 |
+
unsigned int cooperative_grid_sync;  // forwarded to fused_grid_barrier, which ignores it
|
| 115 |
+
};
|
| 116 |
+
|
| 117 |
+
// Cross-block barrier used between kernel stages. On sm_90+ (unless
// HTM_DISABLE_CLUSTER is defined) it is the hardware cluster barrier,
// cooperative_groups::this_cluster().sync(); otherwise it is grid.sync().
|
| 118 |
+
// Replaces the former software Decoupled Look-Back (DLB) atomic-spin barrier.
|
| 119 |
+
//
|
| 120 |
+
// cluster::sync() is a single PTX instruction (barrier.cluster) that resolves
|
| 121 |
+
// in ~10-40 ns inside the cluster, with no device-level serialization.
|
| 122 |
+
// Multiple clusters (one per HTM region) run fully concurrently — bounded
|
| 123 |
+
// only by SM count (8 clusters × 16 SMs = 128 ≤ 132 on H200).
// NOTE(review): cluster.sync() only synchronizes blocks of the caller's own
// cluster, so the cluster path is a full barrier only if every block sharing
// a region's buffers is launched in one cluster — confirm at the launch site.
|
| 124 |
+
//
|
| 125 |
+
// The flags / expected / phase / cooperative_grid_sync parameters are kept
|
| 126 |
+
// in the signature for call-site compatibility but are unused. `grid` is
// used only by the non-cluster branch.
|
| 127 |
+
__device__ static inline void fused_grid_barrier(cg::grid_group grid,
|
| 128 |
+
unsigned int * /* flags — unused */,
|
| 129 |
+
unsigned int /* expected — unused */,
|
| 130 |
+
unsigned int /* phase — unused */,
|
| 131 |
+
unsigned int /* cooperative_grid_sync — unused */) {
|
| 132 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 133 |
+
// Hopper+ : hardware cluster barrier (~10-40 ns)
|
| 134 |
+
auto cluster = cg::this_cluster();
|
| 135 |
+
cluster.sync();
|
| 136 |
+
#else
|
| 137 |
+
// Pre-Hopper (sm_80, sm_86, sm_89), or any build with HTM_DISABLE_CLUSTER: grid-level cooperative sync.
|
| 138 |
+
// Requires cooperative kernel launch. ~us-ms range, adequate for HTM
|
| 139 |
+
// workload (kernel launch frequency is low).
|
| 140 |
+
grid.sync();
|
| 141 |
+
#endif
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
// Sum `v` across the 32 lanes of a warp with a shuffle-down reduction.
// The full mask 0xffffffff names every lane, so all 32 lanes of the warp
// must execute this call together. Only lane 0 is guaranteed to return the
// complete total; callers in this file broadcast it with __shfl_sync(.., 0).
__device__ static inline unsigned int warp_sum_u32(unsigned int v) {
|
| 145 |
+
for (int off = 16; off > 0; off >>= 1) {  // offsets 16, 8, 4, 2, 1
|
| 146 |
+
v += __shfl_down_sync(0xffffffffu, v, off);  // add the partial sum held by lane (this lane + off)
|
| 147 |
+
}
|
| 148 |
+
return v;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
// Core kernel body — works for both single-region and batched launches.
|
| 152 |
+
// Single-region: caller passes the one FusedPtrs struct.
|
| 153 |
+
// Batched: each block reads its region's FusedPtrs via blockIdx.y before
|
| 154 |
+
// calling this. State is independent per region (each region owns its own
|
| 155 |
+
// GPU buffers). The stage barrier is grid.sync() on the pre-Hopper path, which
|
| 156 |
+
// spans ALL blocks in the grid (harmless over-sync across regions); on sm_90+ it is cluster.sync().
|
| 157 |
+
__device__ static inline
|
| 158 |
+
void htm_fused_step_body(const FusedPtrs& P, const FusedConfig& cfg) {
|
| 159 |
+
cg::grid_group grid = cg::this_grid();
|
| 160 |
+
// Cast pointers.
|
| 161 |
+
const unsigned int * __restrict__ syn_bit = (const unsigned int*)P.syn_bit;
|
| 162 |
+
float * __restrict__ syn_perm = (float*)P.syn_perm;
|
| 163 |
+
float * __restrict__ boost = (float*)P.boost;
|
| 164 |
+
float * __restrict__ active_duty = (float*)P.active_duty;
|
| 165 |
+
float * __restrict__ inhibition_threshold = (float*)P.inhibition_threshold;
|
| 166 |
+
unsigned int * __restrict__ seg_cell_id = (unsigned int*)P.seg_cell_id;
|
| 167 |
+
unsigned int * __restrict__ seg_syn_count = (unsigned int*)P.seg_syn_count;
|
| 168 |
+
unsigned int * __restrict__ syn_presyn = (unsigned int*)P.syn_presyn;
|
| 169 |
+
short * __restrict__ tm_syn_perm = (short*)P.tm_syn_perm;
|
| 170 |
+
unsigned int * __restrict__ cell_seg_count = (unsigned int*)P.cell_seg_count;
|
| 171 |
+
unsigned int * __restrict__ cell_active_a = (unsigned int*)P.cell_active_a;
|
| 172 |
+
unsigned int * __restrict__ cell_active_b = (unsigned int*)P.cell_active_b;
|
| 173 |
+
unsigned int * __restrict__ cell_winner_a = (unsigned int*)P.cell_winner_a;
|
| 174 |
+
unsigned int * __restrict__ cell_winner_b = (unsigned int*)P.cell_winner_b;
|
| 175 |
+
const unsigned char * __restrict__ inputs = (const unsigned char*)P.inputs;
|
| 176 |
+
unsigned char * __restrict__ cols_out = (unsigned char*)P.cols_out;
|
| 177 |
+
float * __restrict__ anom_out = (float*)P.anom_out;
|
| 178 |
+
unsigned int * __restrict__ barrier_counters = (unsigned int*)P.barrier_counters;
|
| 179 |
+
unsigned int * __restrict__ step_scratch = (unsigned int*)P.step_scratch;
|
| 180 |
+
|
| 181 |
+
const unsigned int tid = threadIdx.x;
|
| 182 |
+
const unsigned int lane = tid & 31u;
|
| 183 |
+
const unsigned int warp = tid >> 5;
|
| 184 |
+
const unsigned int warps_per_block = blockDim.x >> 5;
|
| 185 |
+
const unsigned int gwarp = blockIdx.x * warps_per_block + warp;
|
| 186 |
+
const unsigned int n_warps = gridDim.x * warps_per_block;
|
| 187 |
+
|
| 188 |
+
const unsigned int n_cols = cfg.n_columns;
|
| 189 |
+
const unsigned int col_lo = (gwarp * n_cols) / n_warps;
|
| 190 |
+
const unsigned int col_hi = ((gwarp + 1) * n_cols) / n_warps;
|
| 191 |
+
|
| 192 |
+
unsigned int phase = 0u;
|
| 193 |
+
|
| 194 |
+
// =========================================================
|
| 195 |
+
// DSMEM: Cluster-distributed shared memory for hot per-column
|
| 196 |
+
// state (inhibition_threshold, boost, active_duty).
|
| 197 |
+
//
|
| 198 |
+
// On Hopper (sm_90+): Each block in the cluster owns a contiguous
|
| 199 |
+
// slice of columns in its own __shared__ arrays. Any block can
|
| 200 |
+
// peer-read another block's slice via cluster.map_shared_rank().
|
| 201 |
+
//
|
| 202 |
+
// On Ampere (sm_86) and other pre-Hopper: No cluster support.
|
| 203 |
+
// Read/write directly from/to global memory (inhibition_threshold,
|
| 204 |
+
// boost, active_duty device pointers). Slightly higher latency but
|
| 205 |
+
// functionally correct.
|
| 206 |
+
// =========================================================
|
| 207 |
+
|
| 208 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 209 |
+
// Hopper+ cluster path
|
| 210 |
+
auto cluster = cg::this_cluster();
|
| 211 |
+
const unsigned int cluster_block_rank = cluster.block_rank(); // 0..cluster_size-1
|
| 212 |
+
const unsigned int cluster_sz = cluster.num_blocks(); // == gridDim.x (≤16)
|
| 213 |
+
#else
|
| 214 |
+
// Pre-Hopper: no cluster, each block is independent.
|
| 215 |
+
const unsigned int cluster_block_rank = blockIdx.x;
|
| 216 |
+
const unsigned int cluster_sz = gridDim.x;
|
| 217 |
+
#endif
|
| 218 |
+
|
| 219 |
+
// Partition n_cols evenly across cluster blocks.
|
| 220 |
+
// Each block owns cols_per_block columns starting at my_col_start.
|
| 221 |
+
const unsigned int cols_per_block =
|
| 222 |
+
(n_cols + cluster_sz - 1u) / cluster_sz; // ceil div
|
| 223 |
+
const unsigned int my_col_start =
|
| 224 |
+
cluster_block_rank * cols_per_block;
|
| 225 |
+
const unsigned int my_col_end =
|
| 226 |
+
(my_col_start + cols_per_block < n_cols)
|
| 227 |
+
? (my_col_start + cols_per_block) : n_cols; // clamp
|
| 228 |
+
|
| 229 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 230 |
+
// Cluster-distributed shared memory arrays.
|
| 231 |
+
// Each block holds at most COLS_PER_CLUSTER_BLOCK_MAX floats per array.
|
| 232 |
+
// Peer blocks address into each other's smem via map_shared_rank.
|
| 233 |
+
__shared__ float s_inhib_thr [COLS_PER_CLUSTER_BLOCK_MAX];
|
| 234 |
+
__shared__ float s_boost [COLS_PER_CLUSTER_BLOCK_MAX];
|
| 235 |
+
__shared__ float s_active_duty[COLS_PER_CLUSTER_BLOCK_MAX];
|
| 236 |
+
#endif
|
| 237 |
+
|
| 238 |
+
// TMA multicast input staging tile (T9) — HOPPER ONLY.
|
| 239 |
+
//
|
| 240 |
+
// On Hopper: cg::memcpy_async with cluster scope multicasts input to all
|
| 241 |
+
// 16 SMs, reducing DRAM traffic by ~16×.
|
| 242 |
+
// On Ampere: 32 KB smem allocation exceeds per-block budget when
|
| 243 |
+
// cooperatively launched (48 KB total, registers eat the rest). Skip the
|
| 244 |
+
// tile entirely — Stage A reads from GMEM directly (original path).
|
| 245 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 246 |
+
__shared__ __align__(16) unsigned char s_input_tile[INPUT_BITS_MAX];
|
| 247 |
+
#endif
|
| 248 |
+
|
| 249 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 250 |
+
// Initial GMEM → smem load (reads state from previous forward call).
|
| 251 |
+
// Each block loads only its own slice; tid strides across the slice.
|
| 252 |
+
for (unsigned int c = my_col_start + tid; c < my_col_end; c += blockDim.x) {
|
| 253 |
+
const unsigned int off = c - my_col_start;
|
| 254 |
+
s_inhib_thr [off] = inhibition_threshold[c];
|
| 255 |
+
s_boost [off] = boost[c];
|
| 256 |
+
s_active_duty[off] = active_duty[c];
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
// All blocks in the cluster must finish loading before any block
|
| 260 |
+
// starts reading peer smem inside the T-loop.
|
| 261 |
+
cluster.sync();
|
| 262 |
+
#else
|
| 263 |
+
// Pre-Hopper: no smem caching needed — reads go directly to GMEM.
|
| 264 |
+
// Grid sync ensures all blocks have completed Phase 0 init before T-loop.
|
| 265 |
+
grid.sync();
|
| 266 |
+
#endif
|
| 267 |
+
|
| 268 |
+
const unsigned int S = cfg.synapses_per_col;
|
| 269 |
+
const unsigned int cpc = cfg.cells_per_column;
|
| 270 |
+
const unsigned int SPS = cfg.synapses_per_segment;
|
| 271 |
+
const unsigned int MSC = cfg.max_segments_per_cell;
|
| 272 |
+
|
| 273 |
+
// Main timestep loop.
|
| 274 |
+
for (unsigned int t = 0u; t < cfg.T; t++) {
|
| 275 |
+
const unsigned int inp_off = t * cfg.input_bits;
|
| 276 |
+
const unsigned int col_base_out = t * n_cols;
|
| 277 |
+
|
| 278 |
+
unsigned int * curr_active = (t & 1u) ? cell_active_b : cell_active_a;
|
| 279 |
+
unsigned int * prev_active = (t & 1u) ? cell_active_a : cell_active_b;
|
| 280 |
+
unsigned int * curr_winner = (t & 1u) ? cell_winner_b : cell_winner_a;
|
| 281 |
+
unsigned int * prev_winner = (t & 1u) ? cell_winner_a : cell_winner_b;
|
| 282 |
+
|
| 283 |
+
// ---- Phase 0: clear curr bitsets for my cell range ----
|
| 284 |
+
const unsigned int my_cell_lo = col_lo * cpc;
|
| 285 |
+
const unsigned int my_cell_hi = col_hi * cpc;
|
| 286 |
+
if (cpc == 32u) {
|
| 287 |
+
// Fast path: one word per column.
|
| 288 |
+
for (unsigned int c = col_lo + lane; c < col_hi; c += 32u) {
|
| 289 |
+
curr_active[c] = 0u;
|
| 290 |
+
curr_winner[c] = 0u;
|
| 291 |
+
}
|
| 292 |
+
} else {
|
| 293 |
+
for (unsigned int cell = my_cell_lo + lane; cell < my_cell_hi; cell += 32u) {
|
| 294 |
+
unsigned int w = cell >> 5;
|
| 295 |
+
unsigned int m = 1u << (cell & 31u);
|
| 296 |
+
atomicAnd(&curr_active[w], ~m);
|
| 297 |
+
atomicAnd(&curr_winner[w], ~m);
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
// Block 0, lane 0, warp 0 resets step-scratch counters.
|
| 302 |
+
if (blockIdx.x == 0u && tid == 0u) {
|
| 303 |
+
step_scratch[0] = 0u;
|
| 304 |
+
step_scratch[1] = 0u;
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
// ---- BARRIER 1 ----
|
| 308 |
+
// Fence: make the above clear-bitsets + scratch writes globally
|
| 309 |
+
// visible before peer blocks observe "barrier arrived".
|
| 310 |
+
__threadfence();
|
| 311 |
+
fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
|
| 312 |
+
|
| 313 |
+
// =========================================================
|
| 314 |
+
// T9: TMA MULTICAST INPUT STAGING
|
| 315 |
+
//
|
| 316 |
+
// Stage this timestep's input slice into s_input_tile before Stage A.
|
| 317 |
+
// NOTE(review): the call below passes cg::this_thread_block(), so every
|
| 318 |
+
// block in the cluster issues its own block-scope async copy into its own
|
| 319 |
+
// tile. The intent recorded here was a single cluster-scope TMA
|
| 320 |
+
// multicast on Hopper sm_90a (cp.async.bulk.tensor multicast), reducing
|
| 321 |
+
// DRAM input traffic by ~16× vs each block fetching its own
|
| 322 |
+
// copy from GMEM; confirm that this thread_block-scope form achieves that.
|
| 323 |
+
//
|
| 324 |
+
// The staging is gated on cfg.input_bits <= INPUT_BITS_MAX.
|
| 325 |
+
// If the tile is too small (custom large input_bits), we fall
|
| 326 |
+
// back to per-thread GMEM reads in Stage A (identical to the
|
| 327 |
+
// original path; use_input_tile==false).
|
| 328 |
+
//
|
| 329 |
+
// Ordering: BARRIER 1 completes before we issue the DMA.
|
| 330 |
+
// The DMA completes before Stage A reads s_input_tile.
|
| 331 |
+
// =========================================================
|
| 332 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 333 |
+
const bool use_input_tile = (cfg.input_bits <= INPUT_BITS_MAX);
|
| 334 |
+
if (use_input_tile) {
|
| 335 |
+
auto tb = cg::this_thread_block();
|
| 336 |
+
cg::memcpy_async(tb, s_input_tile,
|
| 337 |
+
inputs + inp_off,
|
| 338 |
+
cfg.input_bits);
|
| 339 |
+
cg::wait(tb);
|
| 340 |
+
cluster.sync();
|
| 341 |
+
}
|
| 342 |
+
#else
|
| 343 |
+
const bool use_input_tile = false;
|
| 344 |
+
#endif
|
| 345 |
+
|
| 346 |
+
// =========================================================
|
| 347 |
+
// STAGE A: Spatial Pooler
|
| 348 |
+
//
|
| 349 |
+
// Hot per-column state (boost, inhibition_threshold,
|
| 350 |
+
// active_duty) is served from cluster DSMEM rather than
|
| 351 |
+
// GMEM for each of the T timesteps. GMEM is written on
|
| 352 |
+
// update so state persists across forward calls.
|
| 353 |
+
// =========================================================
|
| 354 |
+
for (unsigned int c = col_lo; c < col_hi; c++) {
|
| 355 |
+
unsigned int base = c * S;
|
| 356 |
+
unsigned int local = 0u;
|
| 357 |
+
for (unsigned int s = lane; s < S; s += 32u) {
|
| 358 |
+
unsigned int b = syn_bit[base + s];
|
| 359 |
+
float p = syn_perm[base + s];
|
| 360 |
+
// T9: read from cluster-broadcast tile when available;
|
| 361 |
+
// fall back to direct GMEM when input_bits > INPUT_BITS_MAX.
|
| 362 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 363 |
+
unsigned int inp_byte = use_input_tile
|
| 364 |
+
? (unsigned int)s_input_tile[b]
|
| 365 |
+
: (unsigned int)inputs[inp_off + b];
|
| 366 |
+
#else
|
| 367 |
+
unsigned int inp_byte = (unsigned int)inputs[inp_off + b];
|
| 368 |
+
#endif
|
| 369 |
+
unsigned int hit = ((inp_byte != 0u) && (p >= cfg.conn_thr)) ? 1u : 0u;
|
| 370 |
+
local += hit;
|
| 371 |
+
}
|
| 372 |
+
unsigned int overlap = warp_sum_u32(local);
|
| 373 |
+
overlap = __shfl_sync(0xffffffffu, overlap, 0);
|
| 374 |
+
|
| 375 |
+
// Read boost + threshold for column c.
|
| 376 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 377 |
+
// Hopper: read from cluster-distributed shared memory.
|
| 378 |
+
const unsigned int owner_block = c / cols_per_block;
|
| 379 |
+
const unsigned int owner_offset = c - owner_block * cols_per_block;
|
| 380 |
+
float boost_val = cluster.map_shared_rank(s_boost, owner_block)[owner_offset];
|
| 381 |
+
float thr = cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset];
|
| 382 |
+
#else
|
| 383 |
+
// Pre-Hopper: read directly from global memory.
|
| 384 |
+
float boost_val = boost[c];
|
| 385 |
+
float thr = inhibition_threshold[c];
|
| 386 |
+
#endif
|
| 387 |
+
|
| 388 |
+
float boosted = (float)overlap * boost_val;
|
| 389 |
+
unsigned int is_active = (boosted > thr) ? 1u : 0u;
|
| 390 |
+
|
| 391 |
+
if (lane == 0) {
|
| 392 |
+
cols_out[col_base_out + c] = (unsigned char)is_active;
|
| 393 |
+
if (is_active) {
|
| 394 |
+
atomicAdd(&step_scratch[0], 1u);
|
| 395 |
+
}
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
// SP learn (Hebbian) on active columns.
|
| 399 |
+
// T9: use tile for input reads here too.
|
| 400 |
+
if (cfg.learn && is_active) {
|
| 401 |
+
for (unsigned int s = lane; s < S; s += 32u) {
|
| 402 |
+
unsigned int b = syn_bit[base + s];
|
| 403 |
+
float p = syn_perm[base + s];
|
| 404 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 405 |
+
unsigned int inp_byte = use_input_tile
|
| 406 |
+
? (unsigned int)s_input_tile[b]
|
| 407 |
+
: (unsigned int)inputs[inp_off + b];
|
| 408 |
+
#else
|
| 409 |
+
unsigned int inp_byte = (unsigned int)inputs[inp_off + b];
|
| 410 |
+
#endif
|
| 411 |
+
if (inp_byte != 0u) {
|
| 412 |
+
p += cfg.sp_inc;
|
| 413 |
+
if (p > 1.0f) p = 1.0f;
|
| 414 |
+
} else {
|
| 415 |
+
p -= cfg.sp_dec;
|
| 416 |
+
if (p < 0.0f) p = 0.0f;
|
| 417 |
+
}
|
| 418 |
+
syn_perm[base + s] = p;
|
| 419 |
+
}
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
// active_duty EMA + threshold adaptation.
|
| 423 |
+
// Writes go to both DSMEM (hot path, Hopper only) and GMEM (persistence).
|
| 424 |
+
if (lane == 0) {
|
| 425 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 426 |
+
float ad = cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset];
|
| 427 |
+
#else
|
| 428 |
+
float ad = active_duty[c];
|
| 429 |
+
#endif
|
| 430 |
+
float sample = is_active ? 1.0f : 0.0f;
|
| 431 |
+
ad = (1.0f - cfg.duty_alpha) * ad + cfg.duty_alpha * sample;
|
| 432 |
+
|
| 433 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 434 |
+
// Writeback: peer smem (for next timestep read) + GMEM (persistence).
|
| 435 |
+
cluster.map_shared_rank(s_active_duty, owner_block)[owner_offset] = ad;
|
| 436 |
+
#endif
|
| 437 |
+
active_duty[c] = ad;
|
| 438 |
+
|
| 439 |
+
// Threshold steers toward target sparsity.
|
| 440 |
+
float err = ad - cfg.sparsity_target;
|
| 441 |
+
float new_thr = thr + cfg.thr_adapt_rate * err * 100.0f;
|
| 442 |
+
if (new_thr < 0.1f) new_thr = 0.1f;
|
| 443 |
+
if (new_thr > 1000.0f) new_thr = 1000.0f;
|
| 444 |
+
|
| 445 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 446 |
+
// Writeback: peer smem (for next timestep read) + GMEM (persistence).
|
| 447 |
+
cluster.map_shared_rank(s_inhib_thr, owner_block)[owner_offset] = new_thr;
|
| 448 |
+
#endif
|
| 449 |
+
inhibition_threshold[c] = new_thr;
|
| 450 |
+
}
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
// ---- DSMEM WRITEBACK SYNC: peer-smem writes must be visible cluster-wide ----
|
| 454 |
+
//
|
| 455 |
+
// On Hopper: cluster.sync() ensures all peer smem writes from this
|
| 456 |
+
// timestep are visible to all blocks before Stage B / next t.
|
| 457 |
+
// On pre-Hopper: no smem peer writes occur (all state in GMEM),
|
| 458 |
+
// so no extra sync needed here — the grid barrier below suffices.
|
| 459 |
+
#if !defined(HTM_DISABLE_CLUSTER) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
| 460 |
+
cluster.sync();
|
| 461 |
+
#endif
|
| 462 |
+
|
| 463 |
+
// ---- BARRIER 2: SP active_mask must be visible before TM reads ----
|
| 464 |
+
// Fence: flush cols_out + active_duty + inhibition_threshold + step_scratch
|
| 465 |
+
// writes to global memory before peers advance past this barrier.
|
| 466 |
+
__threadfence();
|
| 467 |
+
fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
|
| 468 |
+
|
| 469 |
+
// =========================================================
|
| 470 |
+
// STAGE B: Temporal Memory
|
| 471 |
+
// =========================================================
|
| 472 |
+
for (unsigned int c = col_lo; c < col_hi; c++) {
|
| 473 |
+
unsigned int col_active = cols_out[col_base_out + c];
|
| 474 |
+
if (col_active == 0u) continue;
|
| 475 |
+
|
| 476 |
+
unsigned int base_cell = c * cpc;
|
| 477 |
+
unsigned int any_predicted = 0u;
|
| 478 |
+
unsigned int best_seg_id_for_grow = 0xFFFFFFFFu;
|
| 479 |
+
unsigned int best_pot_count = 0u;
|
| 480 |
+
|
| 481 |
+
for (unsigned int k = 0u; k < cpc; k++) {
|
| 482 |
+
unsigned int cell = base_cell + k;
|
| 483 |
+
unsigned int n_segs_here = cell_seg_count[cell];
|
| 484 |
+
if (n_segs_here > MSC) n_segs_here = MSC;
|
| 485 |
+
if (n_segs_here == 0u) continue;
|
| 486 |
+
|
| 487 |
+
unsigned int seg_base_id = cell * MSC;
|
| 488 |
+
unsigned int cell_is_predictive = 0u;
|
| 489 |
+
|
| 490 |
+
for (unsigned int ls = 0u; ls < n_segs_here; ls++) {
|
| 491 |
+
unsigned int seg = seg_base_id + ls;
|
| 492 |
+
unsigned int n_syn = seg_syn_count[seg];
|
| 493 |
+
if (n_syn == 0u) continue;
|
| 494 |
+
unsigned int syn_base = seg * SPS;
|
| 495 |
+
|
| 496 |
+
unsigned int l_conn = 0u;
|
| 497 |
+
unsigned int l_pot = 0u;
|
| 498 |
+
for (unsigned int s = lane; s < n_syn; s += 32u) {
|
| 499 |
+
unsigned int presyn = syn_presyn[syn_base + s];
|
| 500 |
+
unsigned int w = prev_active[presyn >> 5];
|
| 501 |
+
unsigned int bit = (w >> (presyn & 31u)) & 1u;
|
| 502 |
+
if (bit) {
|
| 503 |
+
l_pot += 1u;
|
| 504 |
+
int p = (int)tm_syn_perm[syn_base + s];
|
| 505 |
+
if (p >= cfg.conn_thr_i16) l_conn += 1u;
|
| 506 |
+
}
|
| 507 |
+
}
|
| 508 |
+
unsigned int tot_conn = warp_sum_u32(l_conn);
|
| 509 |
+
unsigned int tot_pot = warp_sum_u32(l_pot);
|
| 510 |
+
tot_conn = __shfl_sync(0xffffffffu, tot_conn, 0);
|
| 511 |
+
tot_pot = __shfl_sync(0xffffffffu, tot_pot, 0);
|
| 512 |
+
|
| 513 |
+
if (tot_conn >= cfg.activation_threshold) cell_is_predictive = 1u;
|
| 514 |
+
if (tot_pot >= cfg.learning_threshold && tot_pot > best_pot_count) {
|
| 515 |
+
best_pot_count = tot_pot;
|
| 516 |
+
best_seg_id_for_grow = seg;
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
// Reinforce predicted-and-correct segment.
|
| 520 |
+
if (cfg.learn && tot_conn >= cfg.activation_threshold) {
|
| 521 |
+
for (unsigned int s = lane; s < n_syn; s += 32u) {
|
| 522 |
+
unsigned int presyn = syn_presyn[syn_base + s];
|
| 523 |
+
unsigned int w = prev_active[presyn >> 5];
|
| 524 |
+
unsigned int bit = (w >> (presyn & 31u)) & 1u;
|
| 525 |
+
int p = (int)tm_syn_perm[syn_base + s];
|
| 526 |
+
if (bit) {
|
| 527 |
+
int np = p + cfg.perm_inc_i16;
|
| 528 |
+
if (np > 32767) np = 32767;
|
| 529 |
+
tm_syn_perm[syn_base + s] = (short)np;
|
| 530 |
+
} else {
|
| 531 |
+
int np = p - cfg.perm_dec_i16;
|
| 532 |
+
if (np < 0) np = 0;
|
| 533 |
+
tm_syn_perm[syn_base + s] = (short)np;
|
| 534 |
+
}
|
| 535 |
+
}
|
| 536 |
+
}
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
if (cell_is_predictive) {
|
| 540 |
+
any_predicted = 1u;
|
| 541 |
+
if (lane == 0) {
|
| 542 |
+
unsigned int w = cell >> 5;
|
| 543 |
+
unsigned int m = 1u << (cell & 31u);
|
| 544 |
+
atomicOr(&curr_active[w], m);
|
| 545 |
+
atomicOr(&curr_winner[w], m);
|
| 546 |
+
}
|
| 547 |
+
}
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
// BURST if no predicted.
|
| 551 |
+
if (!any_predicted) {
|
| 552 |
+
if (lane == 0) {
|
| 553 |
+
for (unsigned int k = 0u; k < cpc; k++) {
|
| 554 |
+
unsigned int cell = base_cell + k;
|
| 555 |
+
unsigned int w = cell >> 5;
|
| 556 |
+
unsigned int m = 1u << (cell & 31u);
|
| 557 |
+
atomicOr(&curr_active[w], m);
|
| 558 |
+
}
|
| 559 |
+
unsigned int win = base_cell;
|
| 560 |
+
unsigned int ww = win >> 5;
|
| 561 |
+
unsigned int wm = 1u << (win & 31u);
|
| 562 |
+
atomicOr(&curr_winner[ww], wm);
|
| 563 |
+
atomicAdd(&step_scratch[1], 1u);
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
if (cfg.learn) {
|
| 567 |
+
unsigned int target_seg;
|
| 568 |
+
unsigned int existing_syn;
|
| 569 |
+
if (best_seg_id_for_grow != 0xFFFFFFFFu) {
|
| 570 |
+
// Reuse best matching segment.
|
| 571 |
+
target_seg = best_seg_id_for_grow;
|
| 572 |
+
existing_syn = seg_syn_count[target_seg];
|
| 573 |
+
target_seg = __shfl_sync(0xffffffffu, target_seg, 0);
|
| 574 |
+
existing_syn = __shfl_sync(0xffffffffu, existing_syn, 0);
|
| 575 |
+
|
| 576 |
+
// Reinforce its existing synapses.
|
| 577 |
+
unsigned int syn_base = target_seg * SPS;
|
| 578 |
+
for (unsigned int s = lane; s < existing_syn; s += 32u) {
|
| 579 |
+
unsigned int presyn = syn_presyn[syn_base + s];
|
| 580 |
+
unsigned int w = prev_active[presyn >> 5];
|
| 581 |
+
unsigned int bit = (w >> (presyn & 31u)) & 1u;
|
| 582 |
+
int p = (int)tm_syn_perm[syn_base + s];
|
| 583 |
+
if (bit) {
|
| 584 |
+
int np = p + cfg.perm_inc_i16;
|
| 585 |
+
if (np > 32767) np = 32767;
|
| 586 |
+
tm_syn_perm[syn_base + s] = (short)np;
|
| 587 |
+
} else {
|
| 588 |
+
int np = p - cfg.perm_dec_i16;
|
| 589 |
+
if (np < 0) np = 0;
|
| 590 |
+
tm_syn_perm[syn_base + s] = (short)np;
|
| 591 |
+
}
|
| 592 |
+
}
|
| 593 |
+
} else {
|
| 594 |
+
// Allocate new segment on winner cell (cell 0 of col).
|
| 595 |
+
unsigned int new_seg = 0u;
|
| 596 |
+
if (lane == 0) {
|
| 597 |
+
unsigned int winner_cell = base_cell;
|
| 598 |
+
unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
|
| 599 |
+
if (slot >= MSC) slot = slot % MSC;
|
| 600 |
+
new_seg = winner_cell * MSC + slot;
|
| 601 |
+
seg_cell_id[new_seg] = winner_cell;
|
| 602 |
+
seg_syn_count[new_seg] = 0u;
|
| 603 |
+
}
|
| 604 |
+
target_seg = __shfl_sync(0xffffffffu, new_seg, 0);
|
| 605 |
+
existing_syn = 0u;
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
// Grow synapses to prev_winner cells — lane 0 serialized.
|
| 609 |
+
unsigned int room = (SPS > existing_syn) ? (SPS - existing_syn) : 0u;
|
| 610 |
+
unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;
|
| 611 |
+
if (lane == 0 && max_grow > 0u) {
|
| 612 |
+
unsigned int syn_base = target_seg * SPS;
|
| 613 |
+
unsigned int grown = 0u;
|
| 614 |
+
unsigned int start_off = (c * 2654435761u + cfg.iter_seed + t) % cfg.bits_words;
|
| 615 |
+
for (unsigned int w_off = 0u;
|
| 616 |
+
w_off < cfg.bits_words && grown < max_grow;
|
| 617 |
+
w_off++) {
|
| 618 |
+
unsigned int widx = (start_off + w_off) % cfg.bits_words;
|
| 619 |
+
unsigned int word = prev_winner[widx];
|
| 620 |
+
while (word != 0u && grown < max_grow) {
|
| 621 |
+
unsigned int bit_pos = __ffs(word) - 1u;
|
| 622 |
+
word &= ~(1u << bit_pos);
|
| 623 |
+
unsigned int cell_id = widx * 32u + bit_pos;
|
| 624 |
+
if (cell_id >= cfg.n_cells) continue;
|
| 625 |
+
bool exists = false;
|
| 626 |
+
for (unsigned int es = 0u; es < existing_syn + grown; es++) {
|
| 627 |
+
if (syn_presyn[syn_base + es] == cell_id) { exists = true; break; }
|
| 628 |
+
}
|
| 629 |
+
if (exists) continue;
|
| 630 |
+
unsigned int write_idx = existing_syn + grown;
|
| 631 |
+
if (write_idx >= SPS) break;
|
| 632 |
+
syn_presyn[syn_base + write_idx] = cell_id;
|
| 633 |
+
tm_syn_perm[syn_base + write_idx] = (short)cfg.initial_perm_i16;
|
| 634 |
+
grown++;
|
| 635 |
+
}
|
| 636 |
+
}
|
| 637 |
+
if (grown > 0u) {
|
| 638 |
+
seg_syn_count[target_seg] = existing_syn + grown;
|
| 639 |
+
}
|
| 640 |
+
}
|
| 641 |
+
}
|
| 642 |
+
}
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
// ---- BARRIER 3: TM writes complete before anomaly + next-step read ----
|
| 646 |
+
// Fence: flush curr_active/curr_winner bitsets + tm_syn_perm +
|
| 647 |
+
// seg_syn_count + syn_presyn before peers advance and consume them as
|
| 648 |
+
// prev_active/prev_winner at t+1.
|
| 649 |
+
__threadfence();
|
| 650 |
+
fused_grid_barrier(grid, barrier_counters, 0u, phase++, cfg.cooperative_grid_sync);
|
| 651 |
+
|
| 652 |
+
// Write anomaly for step t.
|
| 653 |
+
if (blockIdx.x == 0u && tid == 0u) {
|
| 654 |
+
unsigned int total = step_scratch[0];
|
| 655 |
+
unsigned int bad = step_scratch[1];
|
| 656 |
+
float anom = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
|
| 657 |
+
anom_out[t] = anom;
|
| 658 |
+
}
|
| 659 |
+
}
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
// Single-region kernel (legacy call site).
// Thin __global__ entry point: it forwards its by-value FusedPtrs and
// FusedConfig arguments unchanged to htm_fused_step_body and does no other
// work itself. The __launch_bounds__(256, 2) qualifier declares a maximum of
// 256 threads per block and requests at least 2 resident blocks per SM.
|
| 663 |
+
__global__ __launch_bounds__(256, 2)
|
| 664 |
+
void htm_fused_step(FusedPtrs P, FusedConfig cfg) {
|
| 665 |
+
htm_fused_step_body(P, cfg);
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
// Batched kernel: one cooperative launch for B regions. grid.y = B,
|
| 669 |
+
// grid.x = per-region block count. Each block reads its region's
|
| 670 |
+
// FusedPtrs from the device array via blockIdx.y.
|
| 671 |
+
__global__ __launch_bounds__(256, 2)
|
| 672 |
+
void htm_fused_step_batched(const FusedPtrs* __restrict__ P_arr, FusedConfig cfg) {
|
| 673 |
+
const FusedPtrs P = P_arr[blockIdx.y];
|
| 674 |
+
htm_fused_step_body(P, cfg);
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
} // extern "C"
|
overlay/htm_rust/src/gpu/tests.rs
CHANGED
|
@@ -1,643 +1,663 @@
|
|
| 1 |
-
//! Parity tests: GPU SP vs CPU SP reference.
|
| 2 |
-
//!
|
| 3 |
-
//! With matching seeds the two should produce bit-identical active-column sets
|
| 4 |
-
//! when `learn=false`, and remain bit-identical over repeated `learn=true`
|
| 5 |
-
//! steps because the Hebbian update is deterministic (no RNG once initialised).
|
| 6 |
-
//!
|
| 7 |
-
//! Run with: cargo test --release --features gpu
|
| 8 |
-
|
| 9 |
-
#![cfg(test)]
|
| 10 |
-
#![cfg(feature = "gpu")]
|
| 11 |
-
|
| 12 |
-
use crate::sp::{SpatialPooler, SpatialPoolerConfig};
|
| 13 |
-
use crate::gpu::sp_gpu::SpatialPoolerGpu;
|
| 14 |
-
use crate::gpu::tm_gpu::TemporalMemoryGpu;
|
| 15 |
-
use crate::gpu::fused::{
|
| 16 |
-
launch_fused, plan_fused_launch, FusedState,
|
| 17 |
-
};
|
| 18 |
-
use cudarc::driver::CudaSlice;
|
| 19 |
-
use rand::{Rng, SeedableRng};
|
| 20 |
-
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 21 |
-
|
| 22 |
-
fn make_sdr(rng: &mut Xoshiro256PlusPlus, bits: usize, sparsity: f32) -> Vec<u8> {
|
| 23 |
-
let on = ((sparsity * bits as f32) as usize).max(1);
|
| 24 |
-
let mut v = vec![0u8; bits];
|
| 25 |
-
let mut placed = 0;
|
| 26 |
-
while placed < on {
|
| 27 |
-
let i = rng.gen_range(0..bits);
|
| 28 |
-
if v[i] == 0 {
|
| 29 |
-
v[i] = 1;
|
| 30 |
-
placed += 1;
|
| 31 |
-
}
|
| 32 |
-
}
|
| 33 |
-
v
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
/// Parity without learning: a CPU `SpatialPooler` and a GPU pooler built from
/// an identically seeded CPU instance are fed the same 20 random SDRs with
/// `learn = false`; the active-column lists must be equal on every step.
#[test]
fn gpu_sp_matches_cpu_no_learn() {
    // Same config + seed twice: one instance stays the CPU reference, the
    // other only provides the initial state for the GPU pooler.
    let new_cpu = || SpatialPooler::new(SpatialPoolerConfig::default(), 1234);
    let mut reference = new_cpu();
    let gpu_seed = new_cpu();
    let mut device_sp = SpatialPoolerGpu::from_cpu(&gpu_seed)
        .expect("gpu init (CUDA device available)");
    device_sp.set_strict_parity(true);

    let input_width = SpatialPoolerConfig::default().input_bits;
    let mut sdr_rng = Xoshiro256PlusPlus::seed_from_u64(99);
    for step in 0..20 {
        // The GPU path consumes the raw 0/1 bytes; the CPU path wants bools.
        let raw = make_sdr(&mut sdr_rng, input_width, 0.02);
        let as_bools: Vec<bool> = raw.iter().map(|bit| *bit != 0).collect();

        let want: Vec<u32> = reference.compute(&as_bools, false);
        let got: Vec<u32> = device_sp.compute(&raw, false).expect("gpu compute");

        assert_eq!(
            want, got,
            "mismatch at step {step}: len cpu={} gpu={}",
            want.len(), got.len()
        );
    }
}
|
| 67 |
-
|
| 68 |
-
/// Parity with learning enabled: the CPU reference and the GPU pooler (built
/// from an identically seeded CPU instance) process the same 50 random SDRs
/// with `learn = true`; the active-column lists must be equal on every step.
#[test]
fn gpu_sp_matches_cpu_with_learn() {
    // Same config + seed twice: one instance stays the CPU reference, the
    // other only provides the initial state for the GPU pooler.
    let new_cpu = || SpatialPooler::new(SpatialPoolerConfig::default(), 5678);
    let mut reference = new_cpu();
    let gpu_seed = new_cpu();
    let mut device_sp = SpatialPoolerGpu::from_cpu(&gpu_seed).expect("gpu init");
    device_sp.set_strict_parity(true);

    let input_width = SpatialPoolerConfig::default().input_bits;
    let mut sdr_rng = Xoshiro256PlusPlus::seed_from_u64(42);
    for step in 0..50 {
        // The GPU path consumes the raw 0/1 bytes; the CPU path wants bools.
        let raw = make_sdr(&mut sdr_rng, input_width, 0.02);
        let as_bools: Vec<bool> = raw.iter().map(|bit| *bit != 0).collect();

        let want = reference.compute(&as_bools, true);
        let got = device_sp.compute(&raw, true).expect("gpu compute");

        assert_eq!(
            want, got,
            "mismatch at step {step} with learning"
        );
    }
}
|
| 97 |
-
|
| 98 |
-
/// End-to-end GPU pipeline (SP feeding TM via `step_batch_with_tm`): a
/// repeating three-symbol SDR sequence should drive the anomaly score down
/// over time.
///
/// Checks performed:
/// * every step reports exactly `k = round(sparsity * n_columns)` (at least 1)
///   active columns;
/// * the mean anomaly over the last 9 steps is strictly lower than the mean
///   over steps 3..9.
#[test]
fn gpu_tm_anomaly_decays_on_repeating_sequence() {
    let cfg = SpatialPoolerConfig::default();
    let bits = cfg.input_bits;
    let n_cols = cfg.n_columns;
    let cells_per_col = 32usize;

    let cpu_for_gpu = SpatialPooler::new(SpatialPoolerConfig::default(), 314);
    let mut sp = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
    let dev = sp.dev_ref().clone();
    let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col)
        .expect("gpu tm init");
    tm.reset().expect("tm reset");

    // Build 3 fixed SDRs, feed them in a repeating sequence.
    let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
    let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
    let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];

    // Warm up SP so columns are stable per symbol.
    for _ in 0..100 {
        for s in &seqs {
            let _ = sp.compute(s, true).expect("sp compute");
        }
    }

    // Build a long input buffer: 100 repetitions of [A,B,C] = 300 steps.
    let repeats = 100usize;
    let t = repeats * 3;
    let mut inputs_flat = vec![0u8; t * bits];
    for r in 0..repeats {
        for (i, s) in seqs.iter().enumerate() {
            let off = (r * 3 + i) * bits;
            inputs_flat[off..off + bits].copy_from_slice(s);
        }
    }
    let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod");

    let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc cols");
    let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");

    sp.step_batch_with_tm(
        &inputs_dev,
        t,
        bits,
        true,
        &mut cols_dev,
        &mut anom_dev,
        &mut tm,
    ).expect("step_batch_with_tm");

    let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
    let cols: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");

    // Active column count per step must equal k for every step.
    let k = ((cfg.sparsity * n_cols as f32).round() as usize).max(1);
    for ti in 0..t {
        let step_slice = &cols[ti * n_cols..(ti + 1) * n_cols];
        let n_on = step_slice.iter().filter(|&&b| b != 0).count();
        assert_eq!(n_on, k, "step {ti} has {n_on} active cols, expected {k}");
    }

    // Early window: steps 3..9, i.e. the second and third passes over
    // [A,B,C] (the very first pass, steps 0..3, is left out).
    let early_avg: f32 = anom[3..9].iter().sum::<f32>() / 6.0;
    // Late window: the final 9 steps (last three passes); anomaly should be
    // noticeably lower there.
    let late_avg: f32 = anom[(t - 9)..t].iter().sum::<f32>() / 9.0;
    eprintln!("gpu tm: early anomaly = {early_avg:.3}, late = {late_avg:.3}");
    assert!(
        late_avg < early_avg,
        "GPU TM should reduce anomaly on repeating sequence: early={early_avg:.3}, late={late_avg:.3}"
    );
}
|
| 175 |
-
|
| 176 |
-
/// Cluster-sync smoke test: runs the fused megakernel through `launch_fused`
/// on a tiny region and checks that the call returns and that the outputs are
/// sane (finite anomaly scores inside `[0, 1]`, per-step column counts within
/// bounds).
///
/// This goes through the same `launch_fused` entry point as the other fused
/// tests in this file, so a deadlock in the kernel's barriers would show up
/// here as a hang.
///
/// Skips (early return after an `eprintln!`) when the GPU spatial pooler,
/// the GPU temporal memory or the fused state cannot be created, so the test
/// still passes on machines without a usable GPU.
#[test]
fn cluster_sync_smoke_test() {
    // Build a tiny HTM region (1024 inputs, 256 columns, 4 cells/column).
    // This keeps VRAM usage minimal.
    let input_bits = 1024usize;
    let n_columns = 256usize;
    let cells_per_col = 4usize;

    let make_cfg = || SpatialPoolerConfig {
        input_bits,
        n_columns,
        sparsity: 0.04, // ~10 active cols out of 256
        ..SpatialPoolerConfig::default()
    };

    let cpu_ref = SpatialPooler::new(make_cfg(), 42);

    let mut sp = match SpatialPoolerGpu::from_cpu(&cpu_ref) {
        Ok(sp) => sp,
        Err(e) => {
            eprintln!("[cluster_sync_smoke_test] No GPU available ({e:?}) — skipping");
            return;
        }
    };

    let dev = sp.dev_ref().clone();

    // Probe CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH. The result is purely
    // informational: it is logged, and the test runs either way. Cluster
    // launch support (CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH) is not probed here.
    let cooperative_ok = matches!(
        dev.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
        Ok(v) if v > 0
    );
    if !cooperative_ok {
        eprintln!("[cluster_sync_smoke_test] CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH=0 — DLB path only, still running test");
        // We continue — the fallback path must not deadlock either.
    }

    let mut tm = match TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_col) {
        Ok(tm) => tm,
        Err(e) => {
            eprintln!("[cluster_sync_smoke_test] TemporalMemoryGpu::new failed ({e:?}) — skipping");
            return;
        }
    };
    tm.reset().expect("tm reset");

    let mut fused_st: FusedState = match FusedState::new(
        dev.clone(),
        n_columns,
        cells_per_col,
        sp.initial_threshold_estimate(),
    ) {
        Ok(f) => f,
        Err(e) => {
            eprintln!("[cluster_sync_smoke_test] FusedState::new failed ({e:?}) — skipping");
            return;
        }
    };
    fused_st.reset().expect("fused reset");

    // Build T=4 timesteps of all-zero input SDRs.
    let t = 4usize;
    let inputs_flat = vec![0u8; t * input_bits];
    let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod inputs");

    let mut cols_dev = dev.alloc_zeros::<u8>(t * n_columns).expect("alloc cols");
    let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");

    // The kernel is launched inline on this thread; there is NO in-process
    // timeout. If the kernel deadlocks, this test simply hangs and it is up
    // to the surrounding test harness / CI job timeout to kill it.
    launch_fused(
        &mut sp,
        &mut tm,
        &mut fused_st,
        &inputs_dev,
        &mut cols_dev,
        &mut anom_dev,
        t,
        input_bits,
        false, // learn=false for determinism
    ).expect("launch_fused (cluster_sync_smoke_test): deadlock or CUDA error");

    dev.synchronize().expect("device sync after launch_fused");

    // --- Correctness assertions ---

    let cols_host: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
    let anom_host: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");

    // Output buffers must be exactly the right size.
    assert_eq!(cols_host.len(), t * n_columns, "cols buffer size mismatch");
    assert_eq!(anom_host.len(), t, "anom buffer size mismatch");

    // Anomaly scores must be finite (NaN/Inf indicates numerical blow-up)
    // and, being a ratio of counts, must stay inside [0, 1].
    for (i, &a) in anom_host.iter().enumerate() {
        assert!(a.is_finite(), "anomaly[{i}] is not finite: {a}");
        assert!((0.0..=1.0).contains(&a), "anomaly[{i}] out of [0,1]: {a}");
    }

    // Active-column count per step: zero active columns is accepted. The
    // upper bound below holds by construction for a slice of `n_columns`
    // entries, so this loop mainly confirms that the read-back buffer can be
    // sliced step by step.
    for ti in 0..t {
        let n_on = cols_host[ti * n_columns..(ti + 1) * n_columns]
            .iter()
            .filter(|&&b| b != 0)
            .count();
        assert!(
            n_on <= n_columns,
            "step {ti}: active columns {n_on} > n_columns {n_columns} (buffer overrun?)"
        );
    }

    eprintln!(
        "[cluster_sync_smoke_test] PASSED: T={t}, n_cols={n_columns}, \
         input_bits={input_bits}, cooperative_supported={cooperative_ok}, \
         anom={anom_host:?}"
    );
}
|
| 323 |
-
|
| 324 |
-
/// Owned-vs-borrowed buffer parity for `step_batch_with_tm`.
///
/// Path A passes ordinary owned `CudaSlice`s. Path B leaks freshly allocated
/// slices to raw device pointers and re-wraps them with `upgrade_device_ptr`
/// inside `ManuallyDrop`, standing in for externally owned CUDA memory (the
/// zero-copy `step_many_cuda` route) without going through any PyO3 plumbing.
/// Both paths start from identically seeded SP/TM pairs and the same inputs,
/// so the column masks must be equal and the anomaly scores must agree to
/// within 1e-7.
#[test]
fn gpu_cuda_vs_numpy_parity() {
    use std::mem::ManuallyDrop;

    let defaults = SpatialPoolerConfig::default();
    let (bits, n_cols) = (defaults.input_bits, defaults.n_columns);
    let cells_per_col = 32usize;

    // Factory for an (SP, TM) pair; every call uses the same seed.
    let build = || -> (SpatialPoolerGpu, TemporalMemoryGpu) {
        let seed_sp = SpatialPooler::new(SpatialPoolerConfig::default(), 271828);
        let gpu_sp = SpatialPoolerGpu::from_cpu(&seed_sp).expect("gpu init");
        let mut gpu_tm = TemporalMemoryGpu::new(gpu_sp.dev_ref().clone(), n_cols, cells_per_col)
            .expect("tm init");
        gpu_tm.reset().expect("tm reset");
        (gpu_sp, gpu_tm)
    };

    // Deterministic input: `t` random SDRs laid out back to back.
    let t = 32usize;
    let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
    let mut inputs_flat = vec![0u8; t * bits];
    for step in 0..t {
        let lo = step * bits;
        inputs_flat[lo..lo + bits].copy_from_slice(&make_sdr(&mut rng, bits, 0.02));
    }

    // ---- Path A: owned CudaSlice buffers ----
    let (mut sp_a, mut tm_a) = build();
    let dev_a = sp_a.dev_ref().clone();
    let inputs_a: CudaSlice<u8> = dev_a.htod_sync_copy(&inputs_flat).expect("htod");
    let mut cols_a = dev_a.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_a");
    let mut anom_a = dev_a.alloc_zeros::<f32>(t).expect("alloc anom_a");
    sp_a.step_batch_with_tm(&inputs_a, t, bits, false, &mut cols_a, &mut anom_a, &mut tm_a)
        .expect("owned step_batch_with_tm");
    dev_a.synchronize().expect("sync a");
    let cols_a_host: Vec<u8> = dev_a.dtoh_sync_copy(&cols_a).expect("d2h cols_a");
    let anom_a_host: Vec<f32> = dev_a.dtoh_sync_copy(&anom_a).expect("d2h anom_a");

    // ---- Path B: borrowed views over raw device pointers ----
    let (mut sp_b, mut tm_b) = build();
    let dev_b = sp_b.dev_ref().clone();
    let inputs_alloc: CudaSlice<u8> = dev_b.htod_sync_copy(&inputs_flat).expect("htod");
    let cols_alloc = dev_b.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_b");
    let anom_alloc = dev_b.alloc_zeros::<f32>(t).expect("alloc anom_b");

    // Give up ownership: `leak` hands back the raw pointer and the memory is
    // not freed until it is re-owned at the bottom of this test.
    let inputs_ptr = inputs_alloc.leak();
    let cols_ptr = cols_alloc.leak();
    let anom_ptr = anom_alloc.leak();

    // Non-owning views: ManuallyDrop keeps these wrappers from freeing the
    // memory when they go out of scope.
    let inputs_view = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) });
    let mut cols_view = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) });
    let mut anom_view = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) });

    sp_b.step_batch_with_tm(&inputs_view, t, bits, false, &mut cols_view, &mut anom_view, &mut tm_b)
        .expect("borrowed step_batch_with_tm");
    dev_b.synchronize().expect("sync b");
    // Explicit `&*` deref: `dtoh_sync_copy` needs the inner `CudaSlice`, not
    // the `ManuallyDrop` wrapper.
    let cols_b_host: Vec<u8> = dev_b.dtoh_sync_copy(&*cols_view).expect("d2h cols_b");
    let anom_b_host: Vec<f32> = dev_b.dtoh_sync_copy(&*anom_view).expect("d2h anom_b");

    // Take ownership back so the allocations leaked above are released when
    // these bindings are dropped at the end of the function.
    let _inputs_reclaimed = unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) };
    let _cols_reclaimed = unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) };
    let _anom_reclaimed = unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) };

    assert_eq!(cols_a_host, cols_b_host, "active-column mask diverges between numpy and CAI paths");
    assert_eq!(anom_a_host.len(), anom_b_host.len());
    // Per-step anomaly comparison with a 1e-7 absolute tolerance.
    for (i, (a, b)) in anom_a_host.iter().zip(anom_b_host.iter()).enumerate() {
        assert!((a - b).abs() < 1e-7, "anomaly mismatch at step {i}: a={a} b={b}");
    }
}
|
| 409 |
-
|
| 410 |
-
/// Fused kernel: threshold-based activation should settle near the target
/// sparsity after warmup.
///
/// Procedure: run `launch_fused` for a 1000-step warmup on random 2%-sparse
/// SDRs, then for a 200-step measurement pass (both with `learn = true`).
/// Acceptance: the mean number of active columns per measured step lies in
/// `[0.25x, 4x]` of `sparsity * n_columns`. The band is deliberately wide;
/// the test only confirms that activation has neither collapsed to zero nor
/// blown up towards all-active.
#[test]
fn gpu_threshold_converges_to_sparsity() {
    let cfg = SpatialPoolerConfig::default();
    let bits = cfg.input_bits;
    let n_cols = cfg.n_columns;
    let cells_per_col = 32usize;
    let target = cfg.sparsity;

    let seed_sp = SpatialPooler::new(SpatialPoolerConfig::default(), 111);
    let mut sp = SpatialPoolerGpu::from_cpu(&seed_sp).expect("gpu sp init");
    let dev = sp.dev_ref().clone();
    let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
    let mut fused = FusedState::new(
        dev.clone(),
        n_cols,
        cells_per_col,
        sp.initial_threshold_estimate(),
    ).expect("fused init");
    tm.reset().expect("tm reset");
    fused.reset().expect("fused reset");

    // One RNG stream feeds both passes; each call appends `steps` fresh
    // 2%-sparse SDRs laid out back to back.
    let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
    let mut random_batch = |steps: usize| -> Vec<u8> {
        let mut flat = vec![0u8; steps * bits];
        for step in 0..steps {
            let lo = step * bits;
            flat[lo..lo + bits].copy_from_slice(&make_sdr(&mut rng, bits, 0.02));
        }
        flat
    };

    // Warmup pass: 1000 steps.
    let t_warm = 1000usize;
    let warm_host = random_batch(t_warm);
    let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&warm_host).expect("htod");
    let mut cols_dev = dev.alloc_zeros::<u8>(t_warm * n_cols).expect("alloc cols");
    let mut anom_dev = dev.alloc_zeros::<f32>(t_warm).expect("alloc anom");
    launch_fused(
        &mut sp, &mut tm, &mut fused,
        &inputs_dev, &mut cols_dev, &mut anom_dev,
        t_warm, bits, true,
    ).expect("warmup launch");
    dev.synchronize().expect("sync");

    // Measurement pass: 200 further steps.
    let t_meas = 200usize;
    let meas_host = random_batch(t_meas);
    let meas_dev: CudaSlice<u8> = dev.htod_sync_copy(&meas_host).expect("htod meas");
    let mut meas_cols = dev.alloc_zeros::<u8>(t_meas * n_cols).expect("alloc meas cols");
    let mut meas_anom = dev.alloc_zeros::<f32>(t_meas).expect("alloc meas anom");
    launch_fused(
        &mut sp, &mut tm, &mut fused,
        &meas_dev, &mut meas_cols, &mut meas_anom,
        t_meas, bits, true,
    ).expect("meas launch");
    dev.synchronize().expect("sync meas");

    // Average the per-step active-column counts of the measurement pass.
    let cols_host: Vec<u8> = dev.dtoh_sync_copy(&meas_cols).expect("d2h");
    let active_sum: f64 = (0..t_meas)
        .map(|ti| {
            cols_host[ti * n_cols..(ti + 1) * n_cols]
                .iter()
                .filter(|&&b| b != 0)
                .count() as f64
        })
        .sum::<f64>();
    let mean_active: f64 = active_sum / (t_meas as f64);
    let target_active = target as f64 * n_cols as f64;
    eprintln!(
        "threshold-activation convergence: mean_active/step = {mean_active:.1} \
         (target = {target_active:.1})"
    );
    // Generous band: only checks the threshold loop has not diverged to 0
    // or to all-active.
    assert!(
        mean_active >= 0.25 * target_active && mean_active <= 4.0 * target_active,
        "mean active {mean_active:.1} outside [0.25x, 4x] of target {target_active:.1}"
    );
}
|
| 493 |
-
|
| 494 |
-
/// Fused kernel: TM should learn a repeating sequence — anomaly decays.
|
| 495 |
-
#[test]
|
| 496 |
-
fn gpu_fused_tm_anomaly_decays_on_repeating_sequence() {
|
| 497 |
-
let cfg = SpatialPoolerConfig::default();
|
| 498 |
-
let bits = cfg.input_bits;
|
| 499 |
-
let n_cols = cfg.n_columns;
|
| 500 |
-
let cells_per_col = 32usize;
|
| 501 |
-
|
| 502 |
-
let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271);
|
| 503 |
-
let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
|
| 504 |
-
let dev = sp.dev_ref().clone();
|
| 505 |
-
let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
|
| 506 |
-
let mut fused = FusedState::new(
|
| 507 |
-
dev.clone(),
|
| 508 |
-
n_cols,
|
| 509 |
-
cells_per_col,
|
| 510 |
-
sp.initial_threshold_estimate(),
|
| 511 |
-
).expect("fused init");
|
| 512 |
-
tm.reset().expect("tm reset");
|
| 513 |
-
fused.reset().expect("fused reset");
|
| 514 |
-
|
| 515 |
-
let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
|
| 516 |
-
let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
|
| 517 |
-
let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
|
| 518 |
-
|
| 519 |
-
// Warmup SP threshold calibration with random SDRs first.
|
| 520 |
-
let warm = 300usize;
|
| 521 |
-
let mut warm_inputs = vec![0u8; warm * bits];
|
| 522 |
-
for ti in 0..warm {
|
| 523 |
-
let sdr = make_sdr(&mut rng, bits, 0.02);
|
| 524 |
-
warm_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
|
| 525 |
-
}
|
| 526 |
-
let warm_dev: CudaSlice<u8> = dev.htod_sync_copy(&warm_inputs).expect("htod warm");
|
| 527 |
-
let mut warm_cols = dev.alloc_zeros::<u8>(warm * n_cols).expect("alloc warm cols");
|
| 528 |
-
let mut warm_anom = dev.alloc_zeros::<f32>(warm).expect("alloc warm anom");
|
| 529 |
-
launch_fused(
|
| 530 |
-
&mut sp, &mut tm, &mut fused,
|
| 531 |
-
&warm_dev, &mut warm_cols, &mut warm_anom,
|
| 532 |
-
warm, bits, true,
|
| 533 |
-
).expect("warm launch");
|
| 534 |
-
dev.synchronize().expect("sync warm");
|
| 535 |
-
|
| 536 |
-
// Feed repeating A,B,C sequence for 100 reps.
|
| 537 |
-
let repeats = 100usize;
|
| 538 |
-
let t = repeats * 3;
|
| 539 |
-
let mut inputs = vec![0u8; t * bits];
|
| 540 |
-
for r in 0..repeats {
|
| 541 |
-
for (i, s) in seqs.iter().enumerate() {
|
| 542 |
-
let off = (r*3 + i) * bits;
|
| 543 |
-
inputs[off..off+bits].copy_from_slice(s);
|
| 544 |
-
}
|
| 545 |
-
}
|
| 546 |
-
let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod rep");
|
| 547 |
-
let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc rep cols");
|
| 548 |
-
let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc rep anom");
|
| 549 |
-
launch_fused(
|
| 550 |
-
&mut sp, &mut tm, &mut fused,
|
| 551 |
-
&inputs_dev, &mut cols_dev, &mut anom_dev,
|
| 552 |
-
t, bits, true,
|
| 553 |
-
).expect("rep launch");
|
| 554 |
-
dev.synchronize().expect("sync rep");
|
| 555 |
-
|
| 556 |
-
let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
|
| 557 |
-
let early_avg: f32 = anom[3..12].iter().sum::<f32>() / 9.0;
|
| 558 |
-
let late_avg: f32 = anom[(t-9)..t].iter().sum::<f32>() / 9.0;
|
| 559 |
-
eprintln!("fused TM anomaly: early={early_avg:.3} late={late_avg:.3}");
|
| 560 |
-
assert!(
|
| 561 |
-
late_avg < early_avg,
|
| 562 |
-
"anomaly must decay: early={early_avg:.3} late={late_avg:.3}"
|
| 563 |
-
);
|
| 564 |
-
assert!(
|
| 565 |
-
late_avg < 0.5,
|
| 566 |
-
"late anomaly must be < 0.5 (got {late_avg:.3})"
|
| 567 |
-
);
|
| 568 |
-
}
|
| 569 |
-
|
| 570 |
-
#[test]
|
| 571 |
-
fn gpu_sp_yields_k_winners() {
|
| 572 |
-
let cfg = SpatialPoolerConfig::default();
|
| 573 |
-
let bits = cfg.input_bits;
|
| 574 |
-
let n = cfg.n_columns;
|
| 575 |
-
let expected_k = ((cfg.sparsity * n as f32).round() as usize).max(1);
|
| 576 |
-
let cpu = SpatialPooler::new(SpatialPoolerConfig::default(), 7);
|
| 577 |
-
let mut gpu = SpatialPoolerGpu::from_cpu(&cpu).expect("gpu init");
|
| 578 |
-
|
| 579 |
-
let mut rng = Xoshiro256PlusPlus::seed_from_u64(1);
|
| 580 |
-
for _ in 0..10 {
|
| 581 |
-
let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
|
| 582 |
-
let active = gpu.compute(&sdr_u8, false).expect("gpu compute");
|
| 583 |
-
assert_eq!(active.len(), expected_k);
|
| 584 |
-
// Ensure sorted + unique.
|
| 585 |
-
for w in active.windows(2) {
|
| 586 |
-
assert!(w[0] < w[1], "duplicate or out-of-order winner indices");
|
| 587 |
-
}
|
| 588 |
-
}
|
| 589 |
-
}
|
| 590 |
-
|
| 591 |
-
#[test]
|
| 592 |
-
fn fused_launch_plan_uses_cooperative_grid_sync() {
|
| 593 |
-
let plan = plan_fused_launch(30, true, 30, None).expect("cooperative supported");
|
| 594 |
-
assert_eq!(plan.grid_dim_x, 16);
|
| 595 |
-
assert_eq!(plan.cooperative_grid_limit, 30);
|
| 596 |
-
}
|
| 597 |
-
|
| 598 |
-
#[test]
|
| 599 |
-
fn fused_launch_plan_scales_to_big_gpu() {
|
| 600 |
-
// H200-like: 132 SMs, high cooperative_grid_limit. Cap still applies.
|
| 601 |
-
let plan = plan_fused_launch(132, true, 1000, None).expect("cooperative supported");
|
| 602 |
-
assert_eq!(plan.grid_dim_x, 16); // capped by default override
|
| 603 |
-
let plan = plan_fused_launch(132, true, 1000, Some(64)).expect("cooperative supported");
|
| 604 |
-
assert_eq!(plan.grid_dim_x, 64); // override raises the cap
|
| 605 |
-
}
|
| 606 |
-
|
| 607 |
-
#[test]
|
| 608 |
-
fn fused_launch_plan_refuses_non_cooperative_devices() {
|
| 609 |
-
// The slow path was removed. Devices without cooperative launch fail fast.
|
| 610 |
-
let err = plan_fused_launch(30, false, 0, None).unwrap_err();
|
| 611 |
-
assert!(err.contains("cooperative launch"));
|
| 612 |
-
}
|
| 613 |
-
|
| 614 |
-
#[test]
|
| 615 |
-
fn fused_grid_cap_env_override_is_honored() {
|
| 616 |
-
let cfg = SpatialPoolerConfig::default();
|
| 617 |
-
let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 5252);
|
| 618 |
-
let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
|
| 619 |
-
let dev = sp.dev_ref().clone();
|
| 620 |
-
|
| 621 |
-
unsafe { std::env::set_var("HTM_FUSED_GRID_CAP", "12"); }
|
| 622 |
-
let fused = FusedState::new(
|
| 623 |
-
dev.clone(),
|
| 624 |
-
cfg.n_columns,
|
| 625 |
-
32usize,
|
| 626 |
-
sp.initial_threshold_estimate(),
|
| 627 |
-
).expect("fused init");
|
| 628 |
-
unsafe { std::env::remove_var("HTM_FUSED_GRID_CAP"); }
|
| 629 |
-
|
| 630 |
-
let sm_count = match dev.attribute(
|
| 631 |
-
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
|
| 632 |
-
) {
|
| 633 |
-
Ok(v) => v as u32,
|
| 634 |
-
Err(_) => 16u32,
|
| 635 |
-
};
|
| 636 |
-
let expected = sm_count.max(1).min(12);
|
| 637 |
-
assert_eq!(
|
| 638 |
-
fused.grid_dim_x,
|
| 639 |
-
expected,
|
| 640 |
-
"fused grid cap env override ignored: expected min(sm_count, 12) = {expected}, got {}",
|
| 641 |
-
fused.grid_dim_x,
|
| 642 |
-
);
|
| 643 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! Parity tests: GPU SP vs CPU SP reference.
|
| 2 |
+
//!
|
| 3 |
+
//! With matching seeds the two should produce bit-identical active-column sets
|
| 4 |
+
//! when `learn=false`, and remain bit-identical over repeated `learn=true`
|
| 5 |
+
//! steps because the Hebbian update is deterministic (no RNG once initialised).
|
| 6 |
+
//!
|
| 7 |
+
//! Run with: cargo test --release --features gpu
|
| 8 |
+
|
| 9 |
+
#![cfg(test)]
|
| 10 |
+
#![cfg(feature = "gpu")]
|
| 11 |
+
|
| 12 |
+
use crate::sp::{SpatialPooler, SpatialPoolerConfig};
|
| 13 |
+
use crate::gpu::sp_gpu::SpatialPoolerGpu;
|
| 14 |
+
use crate::gpu::tm_gpu::TemporalMemoryGpu;
|
| 15 |
+
use crate::gpu::fused::{
|
| 16 |
+
launch_fused, plan_batched_grid_dim, plan_fused_launch, FusedState,
|
| 17 |
+
};
|
| 18 |
+
use cudarc::driver::CudaSlice;
|
| 19 |
+
use rand::{Rng, SeedableRng};
|
| 20 |
+
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 21 |
+
|
| 22 |
+
fn make_sdr(rng: &mut Xoshiro256PlusPlus, bits: usize, sparsity: f32) -> Vec<u8> {
|
| 23 |
+
let on = ((sparsity * bits as f32) as usize).max(1);
|
| 24 |
+
let mut v = vec![0u8; bits];
|
| 25 |
+
let mut placed = 0;
|
| 26 |
+
while placed < on {
|
| 27 |
+
let i = rng.gen_range(0..bits);
|
| 28 |
+
if v[i] == 0 {
|
| 29 |
+
v[i] = 1;
|
| 30 |
+
placed += 1;
|
| 31 |
+
}
|
| 32 |
+
}
|
| 33 |
+
v
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
#[test]
|
| 37 |
+
fn gpu_sp_matches_cpu_no_learn() {
|
| 38 |
+
let cfg = SpatialPoolerConfig::default();
|
| 39 |
+
let bits = cfg.input_bits;
|
| 40 |
+
let mut cpu = SpatialPooler::new(
|
| 41 |
+
SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
|
| 42 |
+
1234,
|
| 43 |
+
);
|
| 44 |
+
let cpu_for_gpu = SpatialPooler::new(
|
| 45 |
+
SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
|
| 46 |
+
1234,
|
| 47 |
+
);
|
| 48 |
+
let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_for_gpu)
|
| 49 |
+
.expect("gpu init (CUDA device available)");
|
| 50 |
+
gpu.set_strict_parity(true);
|
| 51 |
+
|
| 52 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
|
| 53 |
+
for step in 0..20 {
|
| 54 |
+
let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
|
| 55 |
+
let sdr_bool: Vec<bool> = sdr_u8.iter().map(|&x| x != 0).collect();
|
| 56 |
+
|
| 57 |
+
let cpu_active: Vec<u32> = cpu.compute(&sdr_bool, false);
|
| 58 |
+
let gpu_active: Vec<u32> = gpu.compute(&sdr_u8, false).expect("gpu compute");
|
| 59 |
+
|
| 60 |
+
assert_eq!(
|
| 61 |
+
cpu_active, gpu_active,
|
| 62 |
+
"mismatch at step {step}: len cpu={} gpu={}",
|
| 63 |
+
cpu_active.len(), gpu_active.len()
|
| 64 |
+
);
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
#[test]
|
| 69 |
+
fn gpu_sp_matches_cpu_with_learn() {
|
| 70 |
+
let cfg = SpatialPoolerConfig::default();
|
| 71 |
+
let bits = cfg.input_bits;
|
| 72 |
+
let mut cpu = SpatialPooler::new(
|
| 73 |
+
SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
|
| 74 |
+
5678,
|
| 75 |
+
);
|
| 76 |
+
let cpu_for_gpu = SpatialPooler::new(
|
| 77 |
+
SpatialPoolerConfig { ..SpatialPoolerConfig::default() },
|
| 78 |
+
5678,
|
| 79 |
+
);
|
| 80 |
+
let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
|
| 81 |
+
gpu.set_strict_parity(true);
|
| 82 |
+
|
| 83 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(42);
|
| 84 |
+
for step in 0..50 {
|
| 85 |
+
let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
|
| 86 |
+
let sdr_bool: Vec<bool> = sdr_u8.iter().map(|&x| x != 0).collect();
|
| 87 |
+
|
| 88 |
+
let cpu_active = cpu.compute(&sdr_bool, true);
|
| 89 |
+
let gpu_active = gpu.compute(&sdr_u8, true).expect("gpu compute");
|
| 90 |
+
|
| 91 |
+
assert_eq!(
|
| 92 |
+
cpu_active, gpu_active,
|
| 93 |
+
"mismatch at step {step} with learning"
|
| 94 |
+
);
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
#[test]
|
| 99 |
+
fn gpu_tm_anomaly_decays_on_repeating_sequence() {
|
| 100 |
+
// End-to-end GPU pipeline: SP feeds TM; repeating SDR sequence should drive
|
| 101 |
+
// anomaly down over time.
|
| 102 |
+
use crate::gpu::HTMRegionGpu; // not pyclass methods; use internal constructor via Rust
|
| 103 |
+
// Easier: replicate the pipeline directly with SP + TM.
|
| 104 |
+
|
| 105 |
+
let cfg = SpatialPoolerConfig::default();
|
| 106 |
+
let bits = cfg.input_bits;
|
| 107 |
+
let n_cols = cfg.n_columns;
|
| 108 |
+
let cells_per_col = 32usize;
|
| 109 |
+
|
| 110 |
+
let cpu_for_gpu = SpatialPooler::new(SpatialPoolerConfig::default(), 314);
|
| 111 |
+
let mut sp = SpatialPoolerGpu::from_cpu(&cpu_for_gpu).expect("gpu init");
|
| 112 |
+
let dev = sp.dev_ref().clone();
|
| 113 |
+
let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col)
|
| 114 |
+
.expect("gpu tm init");
|
| 115 |
+
tm.reset().expect("tm reset");
|
| 116 |
+
|
| 117 |
+
// Build 3 fixed SDRs, feed them in a repeating sequence.
|
| 118 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
|
| 119 |
+
let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
|
| 120 |
+
let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
|
| 121 |
+
|
| 122 |
+
// Warm up SP so columns are stable per symbol.
|
| 123 |
+
for _ in 0..100 {
|
| 124 |
+
for s in &seqs {
|
| 125 |
+
let _ = sp.compute(s, true).expect("sp compute");
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// Build a long input buffer: 100 repetitions of [A,B,C] = 300 steps.
|
| 130 |
+
let repeats = 100usize;
|
| 131 |
+
let t = repeats * 3;
|
| 132 |
+
let mut inputs_flat = vec![0u8; t * bits];
|
| 133 |
+
for r in 0..repeats {
|
| 134 |
+
for (i, s) in seqs.iter().enumerate() {
|
| 135 |
+
let off = (r * 3 + i) * bits;
|
| 136 |
+
inputs_flat[off..off + bits].copy_from_slice(s);
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod");
|
| 140 |
+
|
| 141 |
+
let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc cols");
|
| 142 |
+
let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");
|
| 143 |
+
|
| 144 |
+
sp.step_batch_with_tm(
|
| 145 |
+
&inputs_dev,
|
| 146 |
+
t,
|
| 147 |
+
bits,
|
| 148 |
+
true,
|
| 149 |
+
&mut cols_dev,
|
| 150 |
+
&mut anom_dev,
|
| 151 |
+
&mut tm,
|
| 152 |
+
).expect("step_batch_with_tm");
|
| 153 |
+
|
| 154 |
+
let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
|
| 155 |
+
let cols: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
|
| 156 |
+
|
| 157 |
+
// Active column count per step must equal k for every step.
|
| 158 |
+
let k = ((cfg.sparsity * n_cols as f32).round() as usize).max(1);
|
| 159 |
+
for ti in 0..t {
|
| 160 |
+
let step_slice = &cols[ti * n_cols..(ti + 1) * n_cols];
|
| 161 |
+
let n_on = step_slice.iter().filter(|&&b| b != 0).count();
|
| 162 |
+
assert_eq!(n_on, k, "step {ti} has {n_on} active cols, expected {k}");
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
// Early window (steps 3..9, i.e. repetitions 2-3, skipping the first pass): anomaly should still be near 1.0 (little predicted yet).
|
| 166 |
+
let early_avg: f32 = anom[3..9].iter().sum::<f32>() / 6.0;
|
| 167 |
+
// Last repetitions: anomaly should be noticeably lower.
|
| 168 |
+
let late_avg: f32 = anom[(t - 9)..t].iter().sum::<f32>() / 9.0;
|
| 169 |
+
eprintln!("gpu tm: early anomaly = {early_avg:.3}, late = {late_avg:.3}");
|
| 170 |
+
assert!(
|
| 171 |
+
late_avg < early_avg,
|
| 172 |
+
"GPU TM should reduce anomaly on repeating sequence: early={early_avg:.3}, late={late_avg:.3}"
|
| 173 |
+
);
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
/// Cluster-sync smoke test: verifies that the fused megakernel (which relies on
|
| 177 |
+
/// hardware `cluster::sync()` / grid-barrier on H100/H200 Hopper) completes
|
| 178 |
+
/// without deadlock when called with real HTM state, and that output shapes are
|
| 179 |
+
/// sane (no NaN / Inf in anomaly scores, active-column count in plausible range).
|
| 180 |
+
///
|
| 181 |
+
/// This is an *integration* test, not a synthetic micro-benchmark: it exercises
|
| 182 |
+
/// exactly the same `launch_fused` code path used in production, so any
|
| 183 |
+
/// deadlock in the cooperative-grid or DLB barrier would surface here.
|
| 184 |
+
///
|
| 185 |
+
/// Skips gracefully (with an eprintln) when no GPU is available — the test
|
| 186 |
+
/// binary returns exit-code 0 in that case so CI still passes.
|
| 187 |
+
#[test]
|
| 188 |
+
fn cluster_sync_smoke_test() {
|
| 189 |
+
// Build a tiny HTM region (1024 inputs, 256 columns, 4 cells/column).
|
| 190 |
+
// This keeps VRAM usage minimal while still exercising all kernel paths.
|
| 191 |
+
let input_bits = 1024usize;
|
| 192 |
+
let n_columns = 256usize;
|
| 193 |
+
let cells_per_col = 4usize;
|
| 194 |
+
|
| 195 |
+
// Probe cooperative launch attribute before doing any real work.
|
| 196 |
+
// CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH = 120 (added in CUDA 11.8 for Hopper).
|
| 197 |
+
// cudarc exposes raw attribute querying; we check cooperative launch (95)
|
| 198 |
+
// as the guard — cluster launch is a superset and not separately probed
|
| 199 |
+
// here since cudarc doesn't expose attribute 120 symbolically yet.
|
| 200 |
+
// On pre-Hopper hardware the DLB barrier path is used instead and the
|
| 201 |
+
// test still validates no deadlock on that path.
|
| 202 |
+
|
| 203 |
+
let make_cfg = || SpatialPoolerConfig {
|
| 204 |
+
input_bits,
|
| 205 |
+
n_columns,
|
| 206 |
+
sparsity: 0.04, // ~10 active cols out of 256
|
| 207 |
+
..SpatialPoolerConfig::default()
|
| 208 |
+
};
|
| 209 |
+
|
| 210 |
+
let cpu_ref = SpatialPooler::new(make_cfg(), 42);
|
| 211 |
+
|
| 212 |
+
let mut sp = match SpatialPoolerGpu::from_cpu(&cpu_ref) {
|
| 213 |
+
Ok(sp) => sp,
|
| 214 |
+
Err(e) => {
|
| 215 |
+
eprintln!("[cluster_sync_smoke_test] No GPU available ({e:?}) — skipping");
|
| 216 |
+
return;
|
| 217 |
+
}
|
| 218 |
+
};
|
| 219 |
+
|
| 220 |
+
let dev = sp.dev_ref().clone();
|
| 221 |
+
|
| 222 |
+
// Check cooperative launch support; skip with a clear message if absent.
|
| 223 |
+
let cooperative_ok = matches!(
|
| 224 |
+
dev.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
|
| 225 |
+
Ok(v) if v > 0
|
| 226 |
+
);
|
| 227 |
+
if !cooperative_ok {
|
| 228 |
+
eprintln!("[cluster_sync_smoke_test] CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH=0 — DLB path only, still running test");
|
| 229 |
+
// We continue — the DLB path is the production fallback and must not deadlock either.
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
let mut tm = match TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_col) {
|
| 233 |
+
Ok(tm) => tm,
|
| 234 |
+
Err(e) => {
|
| 235 |
+
eprintln!("[cluster_sync_smoke_test] TemporalMemoryGpu::new failed ({e:?}) — skipping");
|
| 236 |
+
return;
|
| 237 |
+
}
|
| 238 |
+
};
|
| 239 |
+
tm.reset().expect("tm reset");
|
| 240 |
+
|
| 241 |
+
let mut fused_st: FusedState = match FusedState::new(
|
| 242 |
+
dev.clone(),
|
| 243 |
+
n_columns,
|
| 244 |
+
cells_per_col,
|
| 245 |
+
sp.initial_threshold_estimate(),
|
| 246 |
+
) {
|
| 247 |
+
Ok(f) => f,
|
| 248 |
+
Err(e) => {
|
| 249 |
+
eprintln!("[cluster_sync_smoke_test] FusedState::new failed ({e:?}) — skipping");
|
| 250 |
+
return;
|
| 251 |
+
}
|
| 252 |
+
};
|
| 253 |
+
fused_st.reset().expect("fused reset");
|
| 254 |
+
|
| 255 |
+
// Build T=4 timesteps of all-zero input SDRs.
|
| 256 |
+
let t = 4usize;
|
| 257 |
+
let inputs_flat = vec![0u8; t * input_bits];
|
| 258 |
+
let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs_flat).expect("htod inputs");
|
| 259 |
+
|
| 260 |
+
let mut cols_dev = dev.alloc_zeros::<u8>(t * n_columns).expect("alloc cols");
|
| 261 |
+
let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc anom");
|
| 262 |
+
|
| 263 |
+
// No in-process timeout guard wraps this launch. If the kernel
|
| 264 |
+
// deadlocks, the test process hangs and the CI job reports a timeout
|
| 265 |
+
// failure — we can't cancel a live CUDA kernel from Rust, but the
|
| 266 |
+
// launch_fused call itself should return within ~2 seconds on any sane GPU.
|
| 267 |
+
//
|
| 268 |
+
// We run the kernel inline (not in a separate thread) because CUDA contexts
|
| 269 |
+
// are not safely shareable across threads without explicit multi-threading
|
| 270 |
+
// setup. The ~2-second expectation is therefore not enforced here: if the kernel deadlocks,
|
| 271 |
+
// the test binary will hang and the CI timeout (typically 5 min) will kill it.
|
| 272 |
+
// For local dev, the deadlock would be immediately obvious.
|
| 273 |
+
|
| 274 |
+
launch_fused(
|
| 275 |
+
&mut sp,
|
| 276 |
+
&mut tm,
|
| 277 |
+
&mut fused_st,
|
| 278 |
+
&inputs_dev,
|
| 279 |
+
&mut cols_dev,
|
| 280 |
+
&mut anom_dev,
|
| 281 |
+
t,
|
| 282 |
+
input_bits,
|
| 283 |
+
false, // learn=false for determinism
|
| 284 |
+
).expect("launch_fused (cluster_sync_smoke_test): deadlock or CUDA error");
|
| 285 |
+
|
| 286 |
+
dev.synchronize().expect("device sync after launch_fused");
|
| 287 |
+
|
| 288 |
+
// --- Correctness assertions ---
|
| 289 |
+
|
| 290 |
+
let cols_host: Vec<u8> = dev.dtoh_sync_copy(&cols_dev).expect("d2h cols");
|
| 291 |
+
let anom_host: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
|
| 292 |
+
|
| 293 |
+
// Output buffers must be exactly the right size.
|
| 294 |
+
assert_eq!(cols_host.len(), t * n_columns, "cols buffer size mismatch");
|
| 295 |
+
assert_eq!(anom_host.len(), t, "anom buffer size mismatch");
|
| 296 |
+
|
| 297 |
+
// Anomaly scores must be finite (NaN/Inf indicates numerical blow-up).
|
| 298 |
+
for (i, &a) in anom_host.iter().enumerate() {
|
| 299 |
+
assert!(a.is_finite(), "anomaly[{i}] is not finite: {a}");
|
| 300 |
+
assert!(a >= 0.0 && a <= 1.0, "anomaly[{i}] out of [0,1]: {a}");
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
// Active-column count per step: threshold-based inhibition, so 0 is
|
| 304 |
+
// possible on cold start (before thresholds calibrate), but we assert
|
| 305 |
+
// <= n_columns to catch buffer overruns or completely wrong output.
|
| 306 |
+
for ti in 0..t {
|
| 307 |
+
let n_on = cols_host[ti * n_columns..(ti + 1) * n_columns]
|
| 308 |
+
.iter()
|
| 309 |
+
.filter(|&&b| b != 0)
|
| 310 |
+
.count();
|
| 311 |
+
assert!(
|
| 312 |
+
n_on <= n_columns,
|
| 313 |
+
"step {ti}: active columns {n_on} > n_columns {n_columns} (buffer overrun?)"
|
| 314 |
+
);
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
eprintln!(
|
| 318 |
+
"[cluster_sync_smoke_test] PASSED: T={t}, n_cols={n_columns}, \
|
| 319 |
+
input_bits={input_bits}, cooperative_supported={cooperative_ok}, \
|
| 320 |
+
anom={anom_host:?}"
|
| 321 |
+
);
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
/// Parity check: the CAI zero-copy path (`step_many_cuda`) must produce
|
| 325 |
+
/// bit-identical outputs to the numpy H2D/D2H path (`step_batch_with_tm`),
|
| 326 |
+
/// since the kernel pipeline is the same — only the I/O wrapping changes.
|
| 327 |
+
/// We skip the PyO3 CAI dict plumbing here and test the underlying
|
| 328 |
+
/// ManuallyDrop + upgrade_device_ptr pattern directly.
|
| 329 |
+
#[test]
|
| 330 |
+
fn gpu_cuda_vs_numpy_parity() {
|
| 331 |
+
use std::mem::ManuallyDrop;
|
| 332 |
+
|
| 333 |
+
let cfg = SpatialPoolerConfig::default();
|
| 334 |
+
let bits = cfg.input_bits;
|
| 335 |
+
let n_cols = cfg.n_columns;
|
| 336 |
+
let cells_per_col = 32usize;
|
| 337 |
+
|
| 338 |
+
// Build two identical (SP, TM) pairs from the same seed.
|
| 339 |
+
let build = || -> (SpatialPoolerGpu, TemporalMemoryGpu) {
|
| 340 |
+
let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271828);
|
| 341 |
+
let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu init");
|
| 342 |
+
let dev = sp.dev_ref().clone();
|
| 343 |
+
let mut tm = TemporalMemoryGpu::new(dev, n_cols, cells_per_col).expect("tm init");
|
| 344 |
+
tm.reset().expect("tm reset");
|
| 345 |
+
(sp, tm)
|
| 346 |
+
};
|
| 347 |
+
|
| 348 |
+
// Deterministic SDR sequence.
|
| 349 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
|
| 350 |
+
let t = 32usize;
|
| 351 |
+
let mut inputs_flat = vec![0u8; t * bits];
|
| 352 |
+
for i in 0..t {
|
| 353 |
+
let sdr = make_sdr(&mut rng, bits, 0.02);
|
| 354 |
+
inputs_flat[i * bits..(i + 1) * bits].copy_from_slice(&sdr);
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
// ---- Path A: owned CudaSlice (numpy-equivalent path) ----
|
| 358 |
+
let (mut sp_a, mut tm_a) = build();
|
| 359 |
+
let dev_a = sp_a.dev_ref().clone();
|
| 360 |
+
let inputs_a: CudaSlice<u8> = dev_a.htod_sync_copy(&inputs_flat).expect("htod");
|
| 361 |
+
let mut cols_a = dev_a.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_a");
|
| 362 |
+
let mut anom_a = dev_a.alloc_zeros::<f32>(t).expect("alloc anom_a");
|
| 363 |
+
sp_a.step_batch_with_tm(&inputs_a, t, bits, false, &mut cols_a, &mut anom_a, &mut tm_a)
|
| 364 |
+
.expect("owned step_batch_with_tm");
|
| 365 |
+
dev_a.synchronize().expect("sync a");
|
| 366 |
+
let cols_a_host: Vec<u8> = dev_a.dtoh_sync_copy(&cols_a).expect("d2h cols_a");
|
| 367 |
+
let anom_a_host: Vec<f32> = dev_a.dtoh_sync_copy(&anom_a).expect("d2h anom_a");
|
| 368 |
+
|
| 369 |
+
// ---- Path B: borrowed device pointers via upgrade_device_ptr ----
|
| 370 |
+
// We allocate fresh owned CudaSlices on a fresh device, then take their
|
| 371 |
+
// raw ptrs and re-wrap as ManuallyDrop borrowed views — mimicking what
|
| 372 |
+
// `step_many_cuda` does with torch-owned CUDA memory.
|
| 373 |
+
let (mut sp_b, mut tm_b) = build();
|
| 374 |
+
let dev_b = sp_b.dev_ref().clone();
|
| 375 |
+
let inputs_b_owned: CudaSlice<u8> = dev_b.htod_sync_copy(&inputs_flat).expect("htod");
|
| 376 |
+
let cols_b_owned = dev_b.alloc_zeros::<u8>(t * n_cols).expect("alloc cols_b");
|
| 377 |
+
let anom_b_owned = dev_b.alloc_zeros::<f32>(t).expect("alloc anom_b");
|
| 378 |
+
|
| 379 |
+
// Extract raw CUdeviceptrs (and leak the owners so their Drop doesn't free).
|
| 380 |
+
let inputs_ptr = inputs_b_owned.leak();
|
| 381 |
+
let cols_ptr = cols_b_owned.leak();
|
| 382 |
+
let anom_ptr = anom_b_owned.leak();
|
| 383 |
+
|
| 384 |
+
// Re-wrap as borrowed views.
|
| 385 |
+
let inputs_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) });
|
| 386 |
+
let mut cols_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) });
|
| 387 |
+
let mut anom_b = ManuallyDrop::new(unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) });
|
| 388 |
+
|
| 389 |
+
sp_b.step_batch_with_tm(&inputs_b, t, bits, false, &mut cols_b, &mut anom_b, &mut tm_b)
|
| 390 |
+
.expect("borrowed step_batch_with_tm");
|
| 391 |
+
dev_b.synchronize().expect("sync b");
|
| 392 |
+
// `ManuallyDrop` doesn't auto-coerce to `&CudaSlice<T>` for the DevicePtr
|
| 393 |
+
// trait bound on `dtoh_sync_copy`; explicit deref.
|
| 394 |
+
let cols_b_host: Vec<u8> = dev_b.dtoh_sync_copy(&*cols_b).expect("d2h cols_b");
|
| 395 |
+
let anom_b_host: Vec<f32> = dev_b.dtoh_sync_copy(&*anom_b).expect("d2h anom_b");
|
| 396 |
+
|
| 397 |
+
// Re-own so Drop actually frees (we leaked above).
|
| 398 |
+
let _inputs_owned_again = unsafe { dev_b.upgrade_device_ptr::<u8>(inputs_ptr, t * bits) };
|
| 399 |
+
let _cols_owned_again = unsafe { dev_b.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols) };
|
| 400 |
+
let _anom_owned_again = unsafe { dev_b.upgrade_device_ptr::<f32>(anom_ptr, t) };
|
| 401 |
+
|
| 402 |
+
assert_eq!(cols_a_host, cols_b_host, "active-column mask diverges between numpy and CAI paths");
|
| 403 |
+
assert_eq!(anom_a_host.len(), anom_b_host.len());
|
| 404 |
+
for (i, (a, b)) in anom_a_host.iter().zip(anom_b_host.iter()).enumerate() {
|
| 405 |
+
// Anomaly is a pure division of integer counts — bit-exact expected.
|
| 406 |
+
assert!((a - b).abs() < 1e-7, "anomaly mismatch at step {i}: a={a} b={b}");
|
| 407 |
+
}
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
/// Fused kernel: threshold activation should converge to near target sparsity
|
| 411 |
+
/// after a short warmup. Acceptance: mean activation rate per step lands in
|
| 412 |
+
/// [0.25*target, 4*target] after a 1000-step warmup. Because the threshold
|
| 413 |
+
/// starts conservative (=2.0) and the per-column adaptation rate is slow
|
| 414 |
+
/// (0.001), we allow a generous band — the test asserts directional
|
| 415 |
+
/// convergence toward the target, not tight matching.
|
| 416 |
+
#[test]
|
| 417 |
+
fn gpu_threshold_converges_to_sparsity() {
|
| 418 |
+
let cfg = SpatialPoolerConfig::default();
|
| 419 |
+
let bits = cfg.input_bits;
|
| 420 |
+
let n_cols = cfg.n_columns;
|
| 421 |
+
let cells_per_col = 32usize;
|
| 422 |
+
let target = cfg.sparsity; // 0.02 = 40 cols expected
|
| 423 |
+
|
| 424 |
+
let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 111);
|
| 425 |
+
let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
|
| 426 |
+
let dev = sp.dev_ref().clone();
|
| 427 |
+
let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
|
| 428 |
+
let mut fused = FusedState::new(
|
| 429 |
+
dev.clone(),
|
| 430 |
+
n_cols,
|
| 431 |
+
cells_per_col,
|
| 432 |
+
sp.initial_threshold_estimate(),
|
| 433 |
+
).expect("fused init");
|
| 434 |
+
tm.reset().expect("tm reset");
|
| 435 |
+
fused.reset().expect("fused reset");
|
| 436 |
+
|
| 437 |
+
// Warmup: 1000 random 2%-sparse SDRs.
|
| 438 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(31337);
|
| 439 |
+
let t_warm = 1000usize;
|
| 440 |
+
let mut inputs = vec![0u8; t_warm * bits];
|
| 441 |
+
for ti in 0..t_warm {
|
| 442 |
+
let sdr = make_sdr(&mut rng, bits, 0.02);
|
| 443 |
+
inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
|
| 444 |
+
}
|
| 445 |
+
let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod");
|
| 446 |
+
let mut cols_dev = dev.alloc_zeros::<u8>(t_warm * n_cols).expect("alloc cols");
|
| 447 |
+
let mut anom_dev = dev.alloc_zeros::<f32>(t_warm).expect("alloc anom");
|
| 448 |
+
launch_fused(
|
| 449 |
+
&mut sp, &mut tm, &mut fused,
|
| 450 |
+
&inputs_dev, &mut cols_dev, &mut anom_dev,
|
| 451 |
+
t_warm, bits, true,
|
| 452 |
+
).expect("warmup launch");
|
| 453 |
+
dev.synchronize().expect("sync");
|
| 454 |
+
|
| 455 |
+
// Measurement pass: another 200 steps, measure mean activation.
|
| 456 |
+
let t_meas = 200usize;
|
| 457 |
+
let mut meas_inputs = vec![0u8; t_meas * bits];
|
| 458 |
+
for ti in 0..t_meas {
|
| 459 |
+
let sdr = make_sdr(&mut rng, bits, 0.02);
|
| 460 |
+
meas_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
|
| 461 |
+
}
|
| 462 |
+
let meas_dev: CudaSlice<u8> = dev.htod_sync_copy(&meas_inputs).expect("htod meas");
|
| 463 |
+
let mut meas_cols = dev.alloc_zeros::<u8>(t_meas * n_cols).expect("alloc meas cols");
|
| 464 |
+
let mut meas_anom = dev.alloc_zeros::<f32>(t_meas).expect("alloc meas anom");
|
| 465 |
+
launch_fused(
|
| 466 |
+
&mut sp, &mut tm, &mut fused,
|
| 467 |
+
&meas_dev, &mut meas_cols, &mut meas_anom,
|
| 468 |
+
t_meas, bits, true,
|
| 469 |
+
).expect("meas launch");
|
| 470 |
+
dev.synchronize().expect("sync meas");
|
| 471 |
+
|
| 472 |
+
let cols_host: Vec<u8> = dev.dtoh_sync_copy(&meas_cols).expect("d2h");
|
| 473 |
+
let mut step_counts = Vec::with_capacity(t_meas);
|
| 474 |
+
for ti in 0..t_meas {
|
| 475 |
+
let n_on = cols_host[ti*n_cols..(ti+1)*n_cols]
|
| 476 |
+
.iter().filter(|&&b| b != 0).count();
|
| 477 |
+
step_counts.push(n_on);
|
| 478 |
+
}
|
| 479 |
+
let mean_active: f64 = step_counts.iter().map(|&c| c as f64).sum::<f64>()
|
| 480 |
+
/ (t_meas as f64);
|
| 481 |
+
let target_active = target as f64 * n_cols as f64;
|
| 482 |
+
eprintln!(
|
| 483 |
+
"threshold-activation convergence: mean_active/step = {mean_active:.1} \
|
| 484 |
+
(target = {target_active:.1})"
|
| 485 |
+
);
|
| 486 |
+
// Very generous band — we just want to confirm the threshold loop is
|
| 487 |
+
// functioning (not diverged to 0 or to all-active).
|
| 488 |
+
assert!(
|
| 489 |
+
mean_active >= 0.25 * target_active && mean_active <= 4.0 * target_active,
|
| 490 |
+
"mean active {mean_active:.1} outside [0.25x, 4x] of target {target_active:.1}"
|
| 491 |
+
);
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
/// Fused kernel: TM should learn a repeating sequence — anomaly decays.
|
| 495 |
+
#[test]
|
| 496 |
+
fn gpu_fused_tm_anomaly_decays_on_repeating_sequence() {
|
| 497 |
+
let cfg = SpatialPoolerConfig::default();
|
| 498 |
+
let bits = cfg.input_bits;
|
| 499 |
+
let n_cols = cfg.n_columns;
|
| 500 |
+
let cells_per_col = 32usize;
|
| 501 |
+
|
| 502 |
+
let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 271);
|
| 503 |
+
let mut sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
|
| 504 |
+
let dev = sp.dev_ref().clone();
|
| 505 |
+
let mut tm = TemporalMemoryGpu::new(dev.clone(), n_cols, cells_per_col).expect("tm init");
|
| 506 |
+
let mut fused = FusedState::new(
|
| 507 |
+
dev.clone(),
|
| 508 |
+
n_cols,
|
| 509 |
+
cells_per_col,
|
| 510 |
+
sp.initial_threshold_estimate(),
|
| 511 |
+
).expect("fused init");
|
| 512 |
+
tm.reset().expect("tm reset");
|
| 513 |
+
fused.reset().expect("fused reset");
|
| 514 |
+
|
| 515 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
|
| 516 |
+
let make = |rng: &mut Xoshiro256PlusPlus| make_sdr(rng, bits, 0.02);
|
| 517 |
+
let seqs = [make(&mut rng), make(&mut rng), make(&mut rng)];
|
| 518 |
+
|
| 519 |
+
// Warmup SP threshold calibration with random SDRs first.
|
| 520 |
+
let warm = 300usize;
|
| 521 |
+
let mut warm_inputs = vec![0u8; warm * bits];
|
| 522 |
+
for ti in 0..warm {
|
| 523 |
+
let sdr = make_sdr(&mut rng, bits, 0.02);
|
| 524 |
+
warm_inputs[ti*bits..(ti+1)*bits].copy_from_slice(&sdr);
|
| 525 |
+
}
|
| 526 |
+
let warm_dev: CudaSlice<u8> = dev.htod_sync_copy(&warm_inputs).expect("htod warm");
|
| 527 |
+
let mut warm_cols = dev.alloc_zeros::<u8>(warm * n_cols).expect("alloc warm cols");
|
| 528 |
+
let mut warm_anom = dev.alloc_zeros::<f32>(warm).expect("alloc warm anom");
|
| 529 |
+
launch_fused(
|
| 530 |
+
&mut sp, &mut tm, &mut fused,
|
| 531 |
+
&warm_dev, &mut warm_cols, &mut warm_anom,
|
| 532 |
+
warm, bits, true,
|
| 533 |
+
).expect("warm launch");
|
| 534 |
+
dev.synchronize().expect("sync warm");
|
| 535 |
+
|
| 536 |
+
// Feed repeating A,B,C sequence for 100 reps.
|
| 537 |
+
let repeats = 100usize;
|
| 538 |
+
let t = repeats * 3;
|
| 539 |
+
let mut inputs = vec![0u8; t * bits];
|
| 540 |
+
for r in 0..repeats {
|
| 541 |
+
for (i, s) in seqs.iter().enumerate() {
|
| 542 |
+
let off = (r*3 + i) * bits;
|
| 543 |
+
inputs[off..off+bits].copy_from_slice(s);
|
| 544 |
+
}
|
| 545 |
+
}
|
| 546 |
+
let inputs_dev: CudaSlice<u8> = dev.htod_sync_copy(&inputs).expect("htod rep");
|
| 547 |
+
let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols).expect("alloc rep cols");
|
| 548 |
+
let mut anom_dev = dev.alloc_zeros::<f32>(t).expect("alloc rep anom");
|
| 549 |
+
launch_fused(
|
| 550 |
+
&mut sp, &mut tm, &mut fused,
|
| 551 |
+
&inputs_dev, &mut cols_dev, &mut anom_dev,
|
| 552 |
+
t, bits, true,
|
| 553 |
+
).expect("rep launch");
|
| 554 |
+
dev.synchronize().expect("sync rep");
|
| 555 |
+
|
| 556 |
+
let anom: Vec<f32> = dev.dtoh_sync_copy(&anom_dev).expect("d2h anom");
|
| 557 |
+
let early_avg: f32 = anom[3..12].iter().sum::<f32>() / 9.0;
|
| 558 |
+
let late_avg: f32 = anom[(t-9)..t].iter().sum::<f32>() / 9.0;
|
| 559 |
+
eprintln!("fused TM anomaly: early={early_avg:.3} late={late_avg:.3}");
|
| 560 |
+
assert!(
|
| 561 |
+
late_avg < early_avg,
|
| 562 |
+
"anomaly must decay: early={early_avg:.3} late={late_avg:.3}"
|
| 563 |
+
);
|
| 564 |
+
assert!(
|
| 565 |
+
late_avg < 0.5,
|
| 566 |
+
"late anomaly must be < 0.5 (got {late_avg:.3})"
|
| 567 |
+
);
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
#[test]
|
| 571 |
+
fn gpu_sp_yields_k_winners() {
|
| 572 |
+
let cfg = SpatialPoolerConfig::default();
|
| 573 |
+
let bits = cfg.input_bits;
|
| 574 |
+
let n = cfg.n_columns;
|
| 575 |
+
let expected_k = ((cfg.sparsity * n as f32).round() as usize).max(1);
|
| 576 |
+
let cpu = SpatialPooler::new(SpatialPoolerConfig::default(), 7);
|
| 577 |
+
let mut gpu = SpatialPoolerGpu::from_cpu(&cpu).expect("gpu init");
|
| 578 |
+
|
| 579 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(1);
|
| 580 |
+
for _ in 0..10 {
|
| 581 |
+
let sdr_u8 = make_sdr(&mut rng, bits, 0.02);
|
| 582 |
+
let active = gpu.compute(&sdr_u8, false).expect("gpu compute");
|
| 583 |
+
assert_eq!(active.len(), expected_k);
|
| 584 |
+
// Ensure sorted + unique.
|
| 585 |
+
for w in active.windows(2) {
|
| 586 |
+
assert!(w[0] < w[1], "duplicate or out-of-order winner indices");
|
| 587 |
+
}
|
| 588 |
+
}
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
#[test]
|
| 592 |
+
fn fused_launch_plan_uses_cooperative_grid_sync() {
|
| 593 |
+
let plan = plan_fused_launch(30, true, 30, None).expect("cooperative supported");
|
| 594 |
+
assert_eq!(plan.grid_dim_x, 16);
|
| 595 |
+
assert_eq!(plan.cooperative_grid_limit, 30);
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
#[test]
|
| 599 |
+
fn fused_launch_plan_scales_to_big_gpu() {
|
| 600 |
+
// H200-like: 132 SMs, high cooperative_grid_limit. Cap still applies.
|
| 601 |
+
let plan = plan_fused_launch(132, true, 1000, None).expect("cooperative supported");
|
| 602 |
+
assert_eq!(plan.grid_dim_x, 16); // capped by default override
|
| 603 |
+
let plan = plan_fused_launch(132, true, 1000, Some(64)).expect("cooperative supported");
|
| 604 |
+
assert_eq!(plan.grid_dim_x, 64); // override raises the cap
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
#[test]
|
| 608 |
+
fn fused_launch_plan_refuses_non_cooperative_devices() {
|
| 609 |
+
// The slow path was removed. Devices without cooperative launch fail fast.
|
| 610 |
+
let err = plan_fused_launch(30, false, 0, None).unwrap_err();
|
| 611 |
+
assert!(err.contains("cooperative launch"));
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
#[test]
|
| 615 |
+
fn fused_grid_cap_env_override_is_honored() {
|
| 616 |
+
let cfg = SpatialPoolerConfig::default();
|
| 617 |
+
let cpu_ref = SpatialPooler::new(SpatialPoolerConfig::default(), 5252);
|
| 618 |
+
let sp = SpatialPoolerGpu::from_cpu(&cpu_ref).expect("gpu sp init");
|
| 619 |
+
let dev = sp.dev_ref().clone();
|
| 620 |
+
|
| 621 |
+
unsafe { std::env::set_var("HTM_FUSED_GRID_CAP", "12"); }
|
| 622 |
+
let fused = FusedState::new(
|
| 623 |
+
dev.clone(),
|
| 624 |
+
cfg.n_columns,
|
| 625 |
+
32usize,
|
| 626 |
+
sp.initial_threshold_estimate(),
|
| 627 |
+
).expect("fused init");
|
| 628 |
+
unsafe { std::env::remove_var("HTM_FUSED_GRID_CAP"); }
|
| 629 |
+
|
| 630 |
+
let sm_count = match dev.attribute(
|
| 631 |
+
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
|
| 632 |
+
) {
|
| 633 |
+
Ok(v) => v as u32,
|
| 634 |
+
Err(_) => 16u32,
|
| 635 |
+
};
|
| 636 |
+
let expected = sm_count.max(1).min(12);
|
| 637 |
+
assert_eq!(
|
| 638 |
+
fused.grid_dim_x,
|
| 639 |
+
expected,
|
| 640 |
+
"fused grid cap env override ignored: expected min(sm_count, 12) = {expected}, got {}",
|
| 641 |
+
fused.grid_dim_x,
|
| 642 |
+
);
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
#[test]
|
| 646 |
+
fn batched_grid_plan_clamps_a10g_batch32_under_cooperative_limit() {
|
| 647 |
+
// A10G observed in HF Jobs: cooperative_grid_limit=400, B=32.
|
| 648 |
+
// grid_x=16 requests 512 cooperative blocks and fails; clamp to 12.
|
| 649 |
+
let grid_x = plan_batched_grid_dim(16, 400, 32, false).expect("fits after clamp");
|
| 650 |
+
assert_eq!(grid_x, 12);
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
#[test]
|
| 654 |
+
fn batched_grid_plan_reports_oversized_batch() {
|
| 655 |
+
let err = plan_batched_grid_dim(16, 31, 32, false).unwrap_err();
|
| 656 |
+
assert!(err.contains("COOPERATIVE_LAUNCH_TOO_LARGE"));
|
| 657 |
+
}
|
| 658 |
+
|
| 659 |
+
#[test]
|
| 660 |
+
fn batched_grid_plan_does_not_clamp_cluster_launches() {
|
| 661 |
+
let grid_x = plan_batched_grid_dim(16, 31, 32, true).expect("cluster path bypasses cooperative limit");
|
| 662 |
+
assert_eq!(grid_x, 16);
|
| 663 |
+
}
|
overlay/htm_rust/src/lib.rs
CHANGED
|
@@ -1,198 +1,198 @@
|
|
| 1 |
-
//! pyo3 bindings for HTMRegion (Numenta BAMI-spec HTM).
|
| 2 |
-
//!
|
| 3 |
-
//! Exposed class:
|
| 4 |
-
//! HTMRegion(input_bits, n_columns, cells_per_column, seed) -> HTMRegion
|
| 5 |
-
//! .step(input_sdr: np.ndarray[bool; input_bits], learn: bool = True)
|
| 6 |
-
//! -> (active_columns: np.ndarray[bool; n_columns],
|
| 7 |
-
//! active_cells: np.ndarray[bool; n_columns*cells_per_column],
|
| 8 |
-
//! predicted_cells:np.ndarray[bool; n_columns*cells_per_column],
|
| 9 |
-
//! anomaly: float)
|
| 10 |
-
//! .reset()
|
| 11 |
-
//! .n_columns -> int
|
| 12 |
-
//! .cells_per_column -> int
|
| 13 |
-
//! .input_bits -> int
|
| 14 |
-
//!
|
| 15 |
-
//! GIL is dropped during the heavy compute via `py.allow_threads(...)` so the
|
| 16 |
-
//! region is effectively `Send` for Python-side threading.
|
| 17 |
-
|
| 18 |
-
// pyo3 0.22 `#[pymethods]` expansion inserts an implicit `.into()` on the
|
| 19 |
-
// returned `Result` to normalise the error type, which clippy reports as
|
| 20 |
-
// `useless_conversion` when our methods already return `PyErr`. The emitted
|
| 21 |
-
// code sits outside the user-written impl, so item-level allows don't reach
|
| 22 |
-
// it; the module-wide allow is the documented workaround.
|
| 23 |
-
#![allow(clippy::useless_conversion)]
|
| 24 |
-
|
| 25 |
-
mod region;
|
| 26 |
-
mod sp;
|
| 27 |
-
mod tm;
|
| 28 |
-
|
| 29 |
-
#[cfg(feature = "gpu")]
|
| 30 |
-
mod gpu;
|
| 31 |
-
|
| 32 |
-
use numpy::{
|
| 33 |
-
IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2,
|
| 34 |
-
PyUntypedArrayMethods,
|
| 35 |
-
};
|
| 36 |
-
use pyo3::prelude::*;
|
| 37 |
-
|
| 38 |
-
use crate::region::HTMRegionCore;
|
| 39 |
-
|
| 40 |
-
/// Result of one HTM step: (active_columns, active_cells, predicted_cells, anomaly).
|
| 41 |
-
type StepOutput<'py> = (
|
| 42 |
-
Bound<'py, PyArray1<bool>>,
|
| 43 |
-
Bound<'py, PyArray1<bool>>,
|
| 44 |
-
Bound<'py, PyArray1<bool>>,
|
| 45 |
-
f32,
|
| 46 |
-
);
|
| 47 |
-
|
| 48 |
-
#[pyclass(module = "htm_rust")]
|
| 49 |
-
pub struct HTMRegion {
|
| 50 |
-
core: HTMRegionCore,
|
| 51 |
-
}
|
| 52 |
-
|
| 53 |
-
#[pymethods]
|
| 54 |
-
impl HTMRegion {
|
| 55 |
-
/// Create a new HTM region.
|
| 56 |
-
///
|
| 57 |
-
/// Args:
|
| 58 |
-
/// input_bits: length of binary input SDR
|
| 59 |
-
/// n_columns: number of mini-columns in the SP (e.g. 2048)
|
| 60 |
-
/// cells_per_column: cells per column in the TM (e.g. 32)
|
| 61 |
-
/// seed: RNG seed for reproducibility
|
| 62 |
-
#[new]
|
| 63 |
-
#[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
|
| 64 |
-
fn new(
|
| 65 |
-
input_bits: usize,
|
| 66 |
-
n_columns: usize,
|
| 67 |
-
cells_per_column: usize,
|
| 68 |
-
seed: u64,
|
| 69 |
-
) -> PyResult<Self> {
|
| 70 |
-
if input_bits == 0 {
|
| 71 |
-
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 72 |
-
"input_bits must be > 0",
|
| 73 |
-
));
|
| 74 |
-
}
|
| 75 |
-
if n_columns == 0 {
|
| 76 |
-
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 77 |
-
"n_columns must be > 0",
|
| 78 |
-
));
|
| 79 |
-
}
|
| 80 |
-
if cells_per_column == 0 {
|
| 81 |
-
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 82 |
-
"cells_per_column must be > 0",
|
| 83 |
-
));
|
| 84 |
-
}
|
| 85 |
-
Ok(Self {
|
| 86 |
-
core: HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed),
|
| 87 |
-
})
|
| 88 |
-
}
|
| 89 |
-
|
| 90 |
-
#[getter]
|
| 91 |
-
fn input_bits(&self) -> usize { self.core.sp.cfg.input_bits }
|
| 92 |
-
|
| 93 |
-
#[getter]
|
| 94 |
-
fn n_columns(&self) -> usize { self.core.sp.cfg.n_columns }
|
| 95 |
-
|
| 96 |
-
#[getter]
|
| 97 |
-
fn cells_per_column(&self) -> usize { self.core.tm.cfg.cells_per_column }
|
| 98 |
-
|
| 99 |
-
/// Process one timestep.
|
| 100 |
-
///
|
| 101 |
-
/// Args:
|
| 102 |
-
/// input_sdr: 1-D numpy boolean array of length `input_bits`.
|
| 103 |
-
/// learn: if True, update SP permanences and TM synapses.
|
| 104 |
-
///
|
| 105 |
-
/// Returns:
|
| 106 |
-
/// (active_columns, active_cells, predicted_cells, anomaly)
|
| 107 |
-
#[pyo3(signature = (input_sdr, learn=true))]
|
| 108 |
-
fn step<'py>(
|
| 109 |
-
&mut self,
|
| 110 |
-
py: Python<'py>,
|
| 111 |
-
input_sdr: PyReadonlyArray1<'py, bool>,
|
| 112 |
-
learn: bool,
|
| 113 |
-
) -> PyResult<StepOutput<'py>> {
|
| 114 |
-
let expected = self.core.sp.cfg.input_bits;
|
| 115 |
-
let slice = input_sdr.as_slice()?;
|
| 116 |
-
let got = slice.len();
|
| 117 |
-
if got != expected {
|
| 118 |
-
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 119 |
-
"input_sdr length {got} != expected input_bits {expected}",
|
| 120 |
-
)));
|
| 121 |
-
}
|
| 122 |
-
|
| 123 |
-
// Copy input to an owned Vec so we can drop the GIL.
|
| 124 |
-
let input_vec: Vec<bool> = slice.to_vec();
|
| 125 |
-
|
| 126 |
-
let (active_cols, active_cells, predicted_cells, anomaly) =
|
| 127 |
-
py.allow_threads(|| self.core.step(&input_vec, learn));
|
| 128 |
-
|
| 129 |
-
let a: Bound<'py, PyArray1<bool>> = active_cols.into_pyarray_bound(py);
|
| 130 |
-
let c: Bound<'py, PyArray1<bool>> = active_cells.into_pyarray_bound(py);
|
| 131 |
-
let p: Bound<'py, PyArray1<bool>> = predicted_cells.into_pyarray_bound(py);
|
| 132 |
-
Ok((a, c, p, anomaly))
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
/// Clear TM predictive state. Does NOT unlearn synapses.
|
| 136 |
-
fn reset(&mut self) { self.core.reset(); }
|
| 137 |
-
|
| 138 |
-
/// Process T timesteps from a `(T, input_bits)` bool ndarray.
|
| 139 |
-
///
|
| 140 |
-
/// Returns:
|
| 141 |
-
/// cols: (T, n_columns) float32 0/1 active-column mask
|
| 142 |
-
/// anom: (T,) float32 anomaly scores
|
| 143 |
-
///
|
| 144 |
-
/// Single GIL release for the whole pass, avoiding T × Python-call overhead.
|
| 145 |
-
#[pyo3(signature = (inputs, learn=true))]
|
| 146 |
-
fn step_many<'py>(
|
| 147 |
-
&mut self,
|
| 148 |
-
py: Python<'py>,
|
| 149 |
-
inputs: PyReadonlyArray2<'py, bool>,
|
| 150 |
-
learn: bool,
|
| 151 |
-
) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
|
| 152 |
-
let shape = inputs.shape();
|
| 153 |
-
if shape.len() != 2 {
|
| 154 |
-
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 155 |
-
"inputs must be 2-D (T, input_bits)",
|
| 156 |
-
));
|
| 157 |
-
}
|
| 158 |
-
let t = shape[0];
|
| 159 |
-
let bits = shape[1];
|
| 160 |
-
let expected = self.core.sp.cfg.input_bits;
|
| 161 |
-
if bits != expected {
|
| 162 |
-
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 163 |
-
"inputs last dim {bits} != expected input_bits {expected}",
|
| 164 |
-
)));
|
| 165 |
-
}
|
| 166 |
-
let slice = inputs.as_slice()?;
|
| 167 |
-
let n_cols = self.core.sp.cfg.n_columns;
|
| 168 |
-
|
| 169 |
-
// Own the input buffer so we can drop the GIL.
|
| 170 |
-
let input_vec: Vec<bool> = slice.to_vec();
|
| 171 |
-
|
| 172 |
-
let (cols_u8, anom) =
|
| 173 |
-
py.allow_threads(|| self.core.step_many(&input_vec, bits, t, learn));
|
| 174 |
-
|
| 175 |
-
// Convert u8 mask to f32 for direct numpy consumption.
|
| 176 |
-
let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
|
| 177 |
-
|
| 178 |
-
// Build (T, n_cols) and (T,) arrays.
|
| 179 |
-
let cols_arr =
|
| 180 |
-
numpy::PyArray1::from_vec_bound(py, cols_f32)
|
| 181 |
-
.reshape([t, n_cols])
|
| 182 |
-
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
|
| 183 |
-
let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
|
| 184 |
-
Ok((cols_arr, anom_arr))
|
| 185 |
-
}
|
| 186 |
-
}
|
| 187 |
-
|
| 188 |
-
/// Python module entry point.
|
| 189 |
-
#[pymodule]
|
| 190 |
-
fn htm_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
| 191 |
-
m.add_class::<HTMRegion>()?;
|
| 192 |
-
#[cfg(feature = "gpu")]
|
| 193 |
-
{
|
| 194 |
-
gpu::register(m)?;
|
| 195 |
-
}
|
| 196 |
-
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
| 197 |
-
Ok(())
|
| 198 |
-
}
|
|
|
|
| 1 |
+
//! pyo3 bindings for HTMRegion (Numenta BAMI-spec HTM).
|
| 2 |
+
//!
|
| 3 |
+
//! Exposed class:
|
| 4 |
+
//! HTMRegion(input_bits, n_columns, cells_per_column, seed) -> HTMRegion
|
| 5 |
+
//! .step(input_sdr: np.ndarray[bool; input_bits], learn: bool = True)
|
| 6 |
+
//! -> (active_columns: np.ndarray[bool; n_columns],
|
| 7 |
+
//! active_cells: np.ndarray[bool; n_columns*cells_per_column],
|
| 8 |
+
//! predicted_cells:np.ndarray[bool; n_columns*cells_per_column],
|
| 9 |
+
//! anomaly: float)
|
| 10 |
+
//! .reset()
|
| 11 |
+
//! .n_columns -> int
|
| 12 |
+
//! .cells_per_column -> int
|
| 13 |
+
//! .input_bits -> int
|
| 14 |
+
//!
|
| 15 |
+
//! GIL is dropped during the heavy compute via `py.allow_threads(...)` so the
|
| 16 |
+
//! region is effectively `Send` for Python-side threading.
|
| 17 |
+
|
| 18 |
+
// pyo3 0.22 `#[pymethods]` expansion inserts an implicit `.into()` on the
|
| 19 |
+
// returned `Result` to normalise the error type, which clippy reports as
|
| 20 |
+
// `useless_conversion` when our methods already return `PyErr`. The emitted
|
| 21 |
+
// code sits outside the user-written impl, so item-level allows don't reach
|
| 22 |
+
// it; the module-wide allow is the documented workaround.
|
| 23 |
+
#![allow(clippy::useless_conversion)]
|
| 24 |
+
|
| 25 |
+
mod region;
|
| 26 |
+
mod sp;
|
| 27 |
+
mod tm;
|
| 28 |
+
|
| 29 |
+
#[cfg(feature = "gpu")]
|
| 30 |
+
mod gpu;
|
| 31 |
+
|
| 32 |
+
use numpy::{
|
| 33 |
+
IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2,
|
| 34 |
+
PyUntypedArrayMethods,
|
| 35 |
+
};
|
| 36 |
+
use pyo3::prelude::*;
|
| 37 |
+
|
| 38 |
+
use crate::region::HTMRegionCore;
|
| 39 |
+
|
| 40 |
+
/// Result of one HTM step: (active_columns, active_cells, predicted_cells, anomaly).
|
| 41 |
+
type StepOutput<'py> = (
|
| 42 |
+
Bound<'py, PyArray1<bool>>,
|
| 43 |
+
Bound<'py, PyArray1<bool>>,
|
| 44 |
+
Bound<'py, PyArray1<bool>>,
|
| 45 |
+
f32,
|
| 46 |
+
);
|
| 47 |
+
|
| 48 |
+
#[pyclass(module = "htm_rust")]
|
| 49 |
+
pub struct HTMRegion {
|
| 50 |
+
core: HTMRegionCore,
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
#[pymethods]
|
| 54 |
+
impl HTMRegion {
|
| 55 |
+
/// Create a new HTM region.
|
| 56 |
+
///
|
| 57 |
+
/// Args:
|
| 58 |
+
/// input_bits: length of binary input SDR
|
| 59 |
+
/// n_columns: number of mini-columns in the SP (e.g. 2048)
|
| 60 |
+
/// cells_per_column: cells per column in the TM (e.g. 32)
|
| 61 |
+
/// seed: RNG seed for reproducibility
|
| 62 |
+
#[new]
|
| 63 |
+
#[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
|
| 64 |
+
fn new(
|
| 65 |
+
input_bits: usize,
|
| 66 |
+
n_columns: usize,
|
| 67 |
+
cells_per_column: usize,
|
| 68 |
+
seed: u64,
|
| 69 |
+
) -> PyResult<Self> {
|
| 70 |
+
if input_bits == 0 {
|
| 71 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 72 |
+
"input_bits must be > 0",
|
| 73 |
+
));
|
| 74 |
+
}
|
| 75 |
+
if n_columns == 0 {
|
| 76 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 77 |
+
"n_columns must be > 0",
|
| 78 |
+
));
|
| 79 |
+
}
|
| 80 |
+
if cells_per_column == 0 {
|
| 81 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 82 |
+
"cells_per_column must be > 0",
|
| 83 |
+
));
|
| 84 |
+
}
|
| 85 |
+
Ok(Self {
|
| 86 |
+
core: HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed),
|
| 87 |
+
})
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
#[getter]
|
| 91 |
+
fn input_bits(&self) -> usize { self.core.sp.cfg.input_bits }
|
| 92 |
+
|
| 93 |
+
#[getter]
|
| 94 |
+
fn n_columns(&self) -> usize { self.core.sp.cfg.n_columns }
|
| 95 |
+
|
| 96 |
+
#[getter]
|
| 97 |
+
fn cells_per_column(&self) -> usize { self.core.tm.cfg.cells_per_column }
|
| 98 |
+
|
| 99 |
+
/// Process one timestep.
|
| 100 |
+
///
|
| 101 |
+
/// Args:
|
| 102 |
+
/// input_sdr: 1-D numpy boolean array of length `input_bits`.
|
| 103 |
+
/// learn: if True, update SP permanences and TM synapses.
|
| 104 |
+
///
|
| 105 |
+
/// Returns:
|
| 106 |
+
/// (active_columns, active_cells, predicted_cells, anomaly)
|
| 107 |
+
#[pyo3(signature = (input_sdr, learn=true))]
|
| 108 |
+
fn step<'py>(
|
| 109 |
+
&mut self,
|
| 110 |
+
py: Python<'py>,
|
| 111 |
+
input_sdr: PyReadonlyArray1<'py, bool>,
|
| 112 |
+
learn: bool,
|
| 113 |
+
) -> PyResult<StepOutput<'py>> {
|
| 114 |
+
let expected = self.core.sp.cfg.input_bits;
|
| 115 |
+
let slice = input_sdr.as_slice()?;
|
| 116 |
+
let got = slice.len();
|
| 117 |
+
if got != expected {
|
| 118 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 119 |
+
"input_sdr length {got} != expected input_bits {expected}",
|
| 120 |
+
)));
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
// Copy input to an owned Vec so we can drop the GIL.
|
| 124 |
+
let input_vec: Vec<bool> = slice.to_vec();
|
| 125 |
+
|
| 126 |
+
let (active_cols, active_cells, predicted_cells, anomaly) =
|
| 127 |
+
py.allow_threads(|| self.core.step(&input_vec, learn));
|
| 128 |
+
|
| 129 |
+
let a: Bound<'py, PyArray1<bool>> = active_cols.into_pyarray_bound(py);
|
| 130 |
+
let c: Bound<'py, PyArray1<bool>> = active_cells.into_pyarray_bound(py);
|
| 131 |
+
let p: Bound<'py, PyArray1<bool>> = predicted_cells.into_pyarray_bound(py);
|
| 132 |
+
Ok((a, c, p, anomaly))
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
/// Clear TM predictive state. Does NOT unlearn synapses.
|
| 136 |
+
fn reset(&mut self) { self.core.reset(); }
|
| 137 |
+
|
| 138 |
+
/// Process T timesteps from a `(T, input_bits)` bool ndarray.
|
| 139 |
+
///
|
| 140 |
+
/// Returns:
|
| 141 |
+
/// cols: (T, n_columns) float32 0/1 active-column mask
|
| 142 |
+
/// anom: (T,) float32 anomaly scores
|
| 143 |
+
///
|
| 144 |
+
/// Single GIL release for the whole pass, avoiding T × Python-call overhead.
|
| 145 |
+
#[pyo3(signature = (inputs, learn=true))]
|
| 146 |
+
fn step_many<'py>(
|
| 147 |
+
&mut self,
|
| 148 |
+
py: Python<'py>,
|
| 149 |
+
inputs: PyReadonlyArray2<'py, bool>,
|
| 150 |
+
learn: bool,
|
| 151 |
+
) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
|
| 152 |
+
let shape = inputs.shape();
|
| 153 |
+
if shape.len() != 2 {
|
| 154 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 155 |
+
"inputs must be 2-D (T, input_bits)",
|
| 156 |
+
));
|
| 157 |
+
}
|
| 158 |
+
let t = shape[0];
|
| 159 |
+
let bits = shape[1];
|
| 160 |
+
let expected = self.core.sp.cfg.input_bits;
|
| 161 |
+
if bits != expected {
|
| 162 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 163 |
+
"inputs last dim {bits} != expected input_bits {expected}",
|
| 164 |
+
)));
|
| 165 |
+
}
|
| 166 |
+
let slice = inputs.as_slice()?;
|
| 167 |
+
let n_cols = self.core.sp.cfg.n_columns;
|
| 168 |
+
|
| 169 |
+
// Own the input buffer so we can drop the GIL.
|
| 170 |
+
let input_vec: Vec<bool> = slice.to_vec();
|
| 171 |
+
|
| 172 |
+
let (cols_u8, anom) =
|
| 173 |
+
py.allow_threads(|| self.core.step_many(&input_vec, bits, t, learn));
|
| 174 |
+
|
| 175 |
+
// Convert u8 mask to f32 for direct numpy consumption.
|
| 176 |
+
let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
|
| 177 |
+
|
| 178 |
+
// Build (T, n_cols) and (T,) arrays.
|
| 179 |
+
let cols_arr =
|
| 180 |
+
numpy::PyArray1::from_vec_bound(py, cols_f32)
|
| 181 |
+
.reshape([t, n_cols])
|
| 182 |
+
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
|
| 183 |
+
let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
|
| 184 |
+
Ok((cols_arr, anom_arr))
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
/// Python module entry point.
|
| 189 |
+
#[pymodule]
|
| 190 |
+
fn htm_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
| 191 |
+
m.add_class::<HTMRegion>()?;
|
| 192 |
+
#[cfg(feature = "gpu")]
|
| 193 |
+
{
|
| 194 |
+
gpu::register(m)?;
|
| 195 |
+
}
|
| 196 |
+
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
| 197 |
+
Ok(())
|
| 198 |
+
}
|
overlay/htm_rust/src/region.rs
CHANGED
|
@@ -1,94 +1,94 @@
|
|
| 1 |
-
//! HTMRegion: compose SpatialPooler + TemporalMemory into a single step().
|
| 2 |
-
|
| 3 |
-
use crate::sp::{SpatialPooler, SpatialPoolerConfig};
|
| 4 |
-
use crate::tm::{TemporalMemory, TemporalMemoryConfig};
|
| 5 |
-
|
| 6 |
-
pub struct HTMRegionCore {
|
| 7 |
-
pub sp: SpatialPooler,
|
| 8 |
-
pub tm: TemporalMemory,
|
| 9 |
-
}
|
| 10 |
-
|
| 11 |
-
impl HTMRegionCore {
|
| 12 |
-
pub fn new(
|
| 13 |
-
input_bits: usize,
|
| 14 |
-
n_columns: usize,
|
| 15 |
-
cells_per_column: usize,
|
| 16 |
-
seed: u64,
|
| 17 |
-
) -> Self {
|
| 18 |
-
let defaults = SpatialPoolerConfig::default();
|
| 19 |
-
let sp_cfg = SpatialPoolerConfig {
|
| 20 |
-
input_bits,
|
| 21 |
-
n_columns,
|
| 22 |
-
// Scale potential_radius to at most the input size.
|
| 23 |
-
potential_radius: defaults.potential_radius.min(input_bits),
|
| 24 |
-
..defaults
|
| 25 |
-
};
|
| 26 |
-
|
| 27 |
-
let tm_cfg = TemporalMemoryConfig {
|
| 28 |
-
n_columns,
|
| 29 |
-
cells_per_column,
|
| 30 |
-
..TemporalMemoryConfig::default()
|
| 31 |
-
};
|
| 32 |
-
|
| 33 |
-
Self {
|
| 34 |
-
sp: SpatialPooler::new(sp_cfg, seed),
|
| 35 |
-
tm: TemporalMemory::new(tm_cfg, seed.wrapping_add(0x9E3779B97F4A7C15)),
|
| 36 |
-
}
|
| 37 |
-
}
|
| 38 |
-
|
| 39 |
-
/// Process one timestep. Returns (active_columns_mask,
|
| 40 |
-
/// active_cells_mask, predicted_cells_mask, anomaly).
|
| 41 |
-
pub fn step(
|
| 42 |
-
&mut self,
|
| 43 |
-
input_sdr: &[bool],
|
| 44 |
-
learn: bool,
|
| 45 |
-
) -> (Vec<bool>, Vec<bool>, Vec<bool>, f32) {
|
| 46 |
-
let active_cols = self.sp.compute(input_sdr, learn);
|
| 47 |
-
|
| 48 |
-
let mut active_cols_mask = vec![false; self.sp.cfg.n_columns];
|
| 49 |
-
for &c in &active_cols {
|
| 50 |
-
active_cols_mask[c as usize] = true;
|
| 51 |
-
}
|
| 52 |
-
|
| 53 |
-
let anomaly = self.tm.compute(&active_cols, learn);
|
| 54 |
-
|
| 55 |
-
// active_cells and predictive_cells are stored as Vec<bool> already.
|
| 56 |
-
let active_cells_mask = self.tm.active_cells.clone();
|
| 57 |
-
let predicted_cells_mask = self.tm.predictive_cells.clone();
|
| 58 |
-
|
| 59 |
-
(active_cols_mask, active_cells_mask, predicted_cells_mask, anomaly)
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
-
pub fn reset(&mut self) {
|
| 63 |
-
self.tm.reset();
|
| 64 |
-
}
|
| 65 |
-
|
| 66 |
-
/// Process T timesteps in one call. Returns flat `(T*n_columns)` active-column
|
| 67 |
-
/// mask (u8 0/1) and `(T,)` anomaly scores.
|
| 68 |
-
///
|
| 69 |
-
/// Amortises the per-step Python round-trip for training: one GIL release,
|
| 70 |
-
/// one copy-out. Used by `HTMLayer.step_many`.
|
| 71 |
-
pub fn step_many(
|
| 72 |
-
&mut self,
|
| 73 |
-
inputs_flat: &[bool],
|
| 74 |
-
input_bits: usize,
|
| 75 |
-
t: usize,
|
| 76 |
-
learn: bool,
|
| 77 |
-
) -> (Vec<u8>, Vec<f32>) {
|
| 78 |
-
let n_cols = self.sp.cfg.n_columns;
|
| 79 |
-
debug_assert_eq!(inputs_flat.len(), t * input_bits);
|
| 80 |
-
let mut cols = vec![0u8; t * n_cols];
|
| 81 |
-
let mut anom = vec![0f32; t];
|
| 82 |
-
for ti in 0..t {
|
| 83 |
-
let off = ti * input_bits;
|
| 84 |
-
let input = &inputs_flat[off..off + input_bits];
|
| 85 |
-
let active_cols = self.sp.compute(input, learn);
|
| 86 |
-
let co = ti * n_cols;
|
| 87 |
-
for &c in &active_cols {
|
| 88 |
-
cols[co + c as usize] = 1;
|
| 89 |
-
}
|
| 90 |
-
anom[ti] = self.tm.compute(&active_cols, learn);
|
| 91 |
-
}
|
| 92 |
-
(cols, anom)
|
| 93 |
-
}
|
| 94 |
-
}
|
|
|
|
| 1 |
+
//! HTMRegion: compose SpatialPooler + TemporalMemory into a single step().
|
| 2 |
+
|
| 3 |
+
use crate::sp::{SpatialPooler, SpatialPoolerConfig};
|
| 4 |
+
use crate::tm::{TemporalMemory, TemporalMemoryConfig};
|
| 5 |
+
|
| 6 |
+
pub struct HTMRegionCore {
|
| 7 |
+
pub sp: SpatialPooler,
|
| 8 |
+
pub tm: TemporalMemory,
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
impl HTMRegionCore {
|
| 12 |
+
pub fn new(
|
| 13 |
+
input_bits: usize,
|
| 14 |
+
n_columns: usize,
|
| 15 |
+
cells_per_column: usize,
|
| 16 |
+
seed: u64,
|
| 17 |
+
) -> Self {
|
| 18 |
+
let defaults = SpatialPoolerConfig::default();
|
| 19 |
+
let sp_cfg = SpatialPoolerConfig {
|
| 20 |
+
input_bits,
|
| 21 |
+
n_columns,
|
| 22 |
+
// Scale potential_radius to at most the input size.
|
| 23 |
+
potential_radius: defaults.potential_radius.min(input_bits),
|
| 24 |
+
..defaults
|
| 25 |
+
};
|
| 26 |
+
|
| 27 |
+
let tm_cfg = TemporalMemoryConfig {
|
| 28 |
+
n_columns,
|
| 29 |
+
cells_per_column,
|
| 30 |
+
..TemporalMemoryConfig::default()
|
| 31 |
+
};
|
| 32 |
+
|
| 33 |
+
Self {
|
| 34 |
+
sp: SpatialPooler::new(sp_cfg, seed),
|
| 35 |
+
tm: TemporalMemory::new(tm_cfg, seed.wrapping_add(0x9E3779B97F4A7C15)),
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
/// Process one timestep. Returns (active_columns_mask,
|
| 40 |
+
/// active_cells_mask, predicted_cells_mask, anomaly).
|
| 41 |
+
pub fn step(
|
| 42 |
+
&mut self,
|
| 43 |
+
input_sdr: &[bool],
|
| 44 |
+
learn: bool,
|
| 45 |
+
) -> (Vec<bool>, Vec<bool>, Vec<bool>, f32) {
|
| 46 |
+
let active_cols = self.sp.compute(input_sdr, learn);
|
| 47 |
+
|
| 48 |
+
let mut active_cols_mask = vec![false; self.sp.cfg.n_columns];
|
| 49 |
+
for &c in &active_cols {
|
| 50 |
+
active_cols_mask[c as usize] = true;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
let anomaly = self.tm.compute(&active_cols, learn);
|
| 54 |
+
|
| 55 |
+
// active_cells and predictive_cells are stored as Vec<bool> already.
|
| 56 |
+
let active_cells_mask = self.tm.active_cells.clone();
|
| 57 |
+
let predicted_cells_mask = self.tm.predictive_cells.clone();
|
| 58 |
+
|
| 59 |
+
(active_cols_mask, active_cells_mask, predicted_cells_mask, anomaly)
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
pub fn reset(&mut self) {
|
| 63 |
+
self.tm.reset();
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
/// Process T timesteps in one call. Returns flat `(T*n_columns)` active-column
|
| 67 |
+
/// mask (u8 0/1) and `(T,)` anomaly scores.
|
| 68 |
+
///
|
| 69 |
+
/// Amortises the per-step Python round-trip for training: one GIL release,
|
| 70 |
+
/// one copy-out. Used by `HTMLayer.step_many`.
|
| 71 |
+
pub fn step_many(
|
| 72 |
+
&mut self,
|
| 73 |
+
inputs_flat: &[bool],
|
| 74 |
+
input_bits: usize,
|
| 75 |
+
t: usize,
|
| 76 |
+
learn: bool,
|
| 77 |
+
) -> (Vec<u8>, Vec<f32>) {
|
| 78 |
+
let n_cols = self.sp.cfg.n_columns;
|
| 79 |
+
debug_assert_eq!(inputs_flat.len(), t * input_bits);
|
| 80 |
+
let mut cols = vec![0u8; t * n_cols];
|
| 81 |
+
let mut anom = vec![0f32; t];
|
| 82 |
+
for ti in 0..t {
|
| 83 |
+
let off = ti * input_bits;
|
| 84 |
+
let input = &inputs_flat[off..off + input_bits];
|
| 85 |
+
let active_cols = self.sp.compute(input, learn);
|
| 86 |
+
let co = ti * n_cols;
|
| 87 |
+
for &c in &active_cols {
|
| 88 |
+
cols[co + c as usize] = 1;
|
| 89 |
+
}
|
| 90 |
+
anom[ti] = self.tm.compute(&active_cols, learn);
|
| 91 |
+
}
|
| 92 |
+
(cols, anom)
|
| 93 |
+
}
|
| 94 |
+
}
|
overlay/htm_rust/src/sp.rs
CHANGED
|
@@ -1,302 +1,302 @@
|
|
| 1 |
-
//! Numenta BAMI-spec Spatial Pooler.
|
| 2 |
-
//!
|
| 3 |
-
//! Implements:
|
| 4 |
-
//! - 2048 (configurable) mini-columns with proximal dendrites
|
| 5 |
-
//! - `potential_synapses` (default 40) synapses per column sampled from
|
| 6 |
-
//! `potential_radius` (default 1024) random input bits
|
| 7 |
-
//! - Permanence in [0.0, 1.0] (f32), connected_threshold = 0.5
|
| 8 |
-
//! - syn_perm_active_inc = +0.04, syn_perm_inactive_dec = -0.008
|
| 9 |
-
//! - Global k-WTA inhibition (top `sparsity` fraction of columns)
|
| 10 |
-
//! - Boost factor with exponential duty-cycle tracking (Numenta formula)
|
| 11 |
-
//!
|
| 12 |
-
//! Reference: BAMI "Spatial Pooling Algorithm Details" (Numenta, 2017).
|
| 13 |
-
|
| 14 |
-
use rand::Rng;
|
| 15 |
-
use rand::SeedableRng;
|
| 16 |
-
use rand::seq::SliceRandom;
|
| 17 |
-
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 18 |
-
|
| 19 |
-
/// A single proximal dendrite: a sparse set of potential synapses onto
|
| 20 |
-
/// specific input bit indices, with per-synapse permanence values.
|
| 21 |
-
#[derive(Clone)]
|
| 22 |
-
pub struct ProximalDendrite {
|
| 23 |
-
/// Indices into the input SDR. Length == potential_synapses.
|
| 24 |
-
pub inputs: Vec<u32>,
|
| 25 |
-
/// Permanence for each potential synapse (same length as `inputs`).
|
| 26 |
-
pub perms: Vec<f32>,
|
| 27 |
-
}
|
| 28 |
-
|
| 29 |
-
pub struct SpatialPoolerConfig {
|
| 30 |
-
pub input_bits: usize,
|
| 31 |
-
pub n_columns: usize,
|
| 32 |
-
/// Size of the random input sample per column.
|
| 33 |
-
pub potential_radius: usize,
|
| 34 |
-
/// Number of potential synapses per column's proximal dendrite.
|
| 35 |
-
pub potential_synapses: usize,
|
| 36 |
-
pub connected_threshold: f32,
|
| 37 |
-
pub syn_perm_active_inc: f32,
|
| 38 |
-
pub syn_perm_inactive_dec: f32,
|
| 39 |
-
/// Target fraction of columns active per step (e.g. 0.02 for 2%).
|
| 40 |
-
pub sparsity: f32,
|
| 41 |
-
/// Duty cycle EMA period.
|
| 42 |
-
pub duty_cycle_period: f32,
|
| 43 |
-
/// Boost strength. Set to 0.0 to disable boosting.
|
| 44 |
-
pub boost_strength: f32,
|
| 45 |
-
/// Initial permanence span around the connected threshold.
|
| 46 |
-
pub init_perm_span: f32,
|
| 47 |
-
}
|
| 48 |
-
|
| 49 |
-
impl Default for SpatialPoolerConfig {
|
| 50 |
-
fn default() -> Self {
|
| 51 |
-
Self {
|
| 52 |
-
input_bits: 16384,
|
| 53 |
-
n_columns: 2048,
|
| 54 |
-
potential_radius: 1024,
|
| 55 |
-
potential_synapses: 40,
|
| 56 |
-
connected_threshold: 0.5,
|
| 57 |
-
syn_perm_active_inc: 0.04,
|
| 58 |
-
syn_perm_inactive_dec: 0.008,
|
| 59 |
-
sparsity: 0.02,
|
| 60 |
-
duty_cycle_period: 1000.0,
|
| 61 |
-
boost_strength: 1.0,
|
| 62 |
-
init_perm_span: 0.1,
|
| 63 |
-
}
|
| 64 |
-
}
|
| 65 |
-
}
|
| 66 |
-
|
| 67 |
-
pub struct SpatialPooler {
|
| 68 |
-
pub cfg: SpatialPoolerConfig,
|
| 69 |
-
pub columns: Vec<ProximalDendrite>,
|
| 70 |
-
/// Exponential moving average of "column was active" per step.
|
| 71 |
-
pub active_duty_cycle: Vec<f32>,
|
| 72 |
-
/// Exponential moving average of "overlap exceeded threshold" per step.
|
| 73 |
-
pub overlap_duty_cycle: Vec<f32>,
|
| 74 |
-
/// Boost factor per column.
|
| 75 |
-
pub boost: Vec<f32>,
|
| 76 |
-
rng: Xoshiro256PlusPlus,
|
| 77 |
-
iter_count: u64,
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
impl SpatialPooler {
|
| 81 |
-
pub fn new(cfg: SpatialPoolerConfig, seed: u64) -> Self {
|
| 82 |
-
assert!(cfg.input_bits >= cfg.potential_radius,
|
| 83 |
-
"input_bits ({}) must be >= potential_radius ({})",
|
| 84 |
-
cfg.input_bits, cfg.potential_radius);
|
| 85 |
-
assert!(cfg.potential_radius >= cfg.potential_synapses,
|
| 86 |
-
"potential_radius ({}) must be >= potential_synapses ({})",
|
| 87 |
-
cfg.potential_radius, cfg.potential_synapses);
|
| 88 |
-
|
| 89 |
-
let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
|
| 90 |
-
|
| 91 |
-
let mut columns = Vec::with_capacity(cfg.n_columns);
|
| 92 |
-
for _ in 0..cfg.n_columns {
|
| 93 |
-
// Sample `potential_radius` distinct input indices, then from those
|
| 94 |
-
// pick `potential_synapses` as the actual proximal synapses.
|
| 95 |
-
// Using partial Fisher-Yates via shuffle on a pool index range.
|
| 96 |
-
let mut pool: Vec<u32> = (0..cfg.input_bits as u32).collect();
|
| 97 |
-
// Efficient partial shuffle: swap the first `potential_radius`
|
| 98 |
-
// items with random items from the rest (Durstenfeld step).
|
| 99 |
-
for i in 0..cfg.potential_radius.min(pool.len()) {
|
| 100 |
-
let j = rng.gen_range(i..pool.len());
|
| 101 |
-
pool.swap(i, j);
|
| 102 |
-
}
|
| 103 |
-
let window = &mut pool[..cfg.potential_radius];
|
| 104 |
-
window.shuffle(&mut rng);
|
| 105 |
-
let mut inputs: Vec<u32> = window[..cfg.potential_synapses].to_vec();
|
| 106 |
-
inputs.sort_unstable();
|
| 107 |
-
|
| 108 |
-
let perms: Vec<f32> = (0..cfg.potential_synapses)
|
| 109 |
-
.map(|_| {
|
| 110 |
-
let delta: f32 = rng.gen_range(-cfg.init_perm_span..cfg.init_perm_span);
|
| 111 |
-
(cfg.connected_threshold + delta).clamp(0.0, 1.0)
|
| 112 |
-
})
|
| 113 |
-
.collect();
|
| 114 |
-
|
| 115 |
-
columns.push(ProximalDendrite { inputs, perms });
|
| 116 |
-
}
|
| 117 |
-
|
| 118 |
-
let n = cfg.n_columns;
|
| 119 |
-
Self {
|
| 120 |
-
cfg,
|
| 121 |
-
columns,
|
| 122 |
-
active_duty_cycle: vec![0.0; n],
|
| 123 |
-
overlap_duty_cycle: vec![0.0; n],
|
| 124 |
-
boost: vec![1.0; n],
|
| 125 |
-
rng,
|
| 126 |
-
iter_count: 0,
|
| 127 |
-
}
|
| 128 |
-
}
|
| 129 |
-
|
| 130 |
-
/// Process one step: compute overlaps, inhibit, learn (if `learn`), update
|
| 131 |
-
/// duty cycles and boosts. Returns the set of active column indices.
|
| 132 |
-
pub fn compute(&mut self, input: &[bool], learn: bool) -> Vec<u32> {
|
| 133 |
-
assert_eq!(input.len(), self.cfg.input_bits);
|
| 134 |
-
|
| 135 |
-
// 1) Overlap score per column (sum of CONNECTED synapses onto active inputs).
|
| 136 |
-
// Also track raw overlap for the overlap-duty-cycle.
|
| 137 |
-
let n = self.cfg.n_columns;
|
| 138 |
-
let mut overlaps: Vec<f32> = vec![0.0; n];
|
| 139 |
-
let mut raw_overlaps: Vec<u32> = vec![0; n];
|
| 140 |
-
|
| 141 |
-
for (ci, col) in self.columns.iter().enumerate() {
|
| 142 |
-
let mut s: u32 = 0;
|
| 143 |
-
for (syn_i, &inp) in col.inputs.iter().enumerate() {
|
| 144 |
-
if input[inp as usize] && col.perms[syn_i] >= self.cfg.connected_threshold {
|
| 145 |
-
s += 1;
|
| 146 |
-
}
|
| 147 |
-
}
|
| 148 |
-
raw_overlaps[ci] = s;
|
| 149 |
-
overlaps[ci] = (s as f32) * self.boost[ci];
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
// 2) Global k-WTA inhibition. Select top-k columns by boosted overlap.
|
| 153 |
-
let k = ((self.cfg.sparsity * n as f32).round() as usize).max(1);
|
| 154 |
-
let active: Vec<u32> = top_k(&overlaps, k);
|
| 155 |
-
|
| 156 |
-
// 3) Hebbian learning on active columns.
|
| 157 |
-
if learn {
|
| 158 |
-
for &ci in &active {
|
| 159 |
-
let col = &mut self.columns[ci as usize];
|
| 160 |
-
for (syn_i, &inp) in col.inputs.iter().enumerate() {
|
| 161 |
-
if input[inp as usize] {
|
| 162 |
-
col.perms[syn_i] =
|
| 163 |
-
(col.perms[syn_i] + self.cfg.syn_perm_active_inc).min(1.0);
|
| 164 |
-
} else {
|
| 165 |
-
col.perms[syn_i] =
|
| 166 |
-
(col.perms[syn_i] - self.cfg.syn_perm_inactive_dec).max(0.0);
|
| 167 |
-
}
|
| 168 |
-
}
|
| 169 |
-
}
|
| 170 |
-
}
|
| 171 |
-
|
| 172 |
-
// 4) Update duty cycles (EMA with period T -> alpha = 1/T).
|
| 173 |
-
let period = self.cfg.duty_cycle_period.max(1.0);
|
| 174 |
-
let alpha = 1.0 / period;
|
| 175 |
-
// Column is "overlapping enough" if raw overlap >= stimulus_threshold.
|
| 176 |
-
// Numenta uses min_overlap; we use 1 as a conservative floor.
|
| 177 |
-
let stimulus_threshold = 1.0_f32;
|
| 178 |
-
|
| 179 |
-
// Mark active columns.
|
| 180 |
-
let mut active_mask = vec![false; n];
|
| 181 |
-
for &ci in &active {
|
| 182 |
-
active_mask[ci as usize] = true;
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
for i in 0..n {
|
| 186 |
-
let active_sample = if active_mask[i] { 1.0 } else { 0.0 };
|
| 187 |
-
let overlap_sample = if (raw_overlaps[i] as f32) >= stimulus_threshold {
|
| 188 |
-
1.0
|
| 189 |
-
} else {
|
| 190 |
-
0.0
|
| 191 |
-
};
|
| 192 |
-
self.active_duty_cycle[i] =
|
| 193 |
-
(1.0 - alpha) * self.active_duty_cycle[i] + alpha * active_sample;
|
| 194 |
-
self.overlap_duty_cycle[i] =
|
| 195 |
-
(1.0 - alpha) * self.overlap_duty_cycle[i] + alpha * overlap_sample;
|
| 196 |
-
}
|
| 197 |
-
|
| 198 |
-
// 5) Boost factor: b_i = exp(-boost_strength * (duty_i - mean_duty)).
|
| 199 |
-
// Under-used columns (duty < mean) get boost > 1.
|
| 200 |
-
if learn && self.cfg.boost_strength > 0.0 {
|
| 201 |
-
let mean_duty: f32 =
|
| 202 |
-
self.active_duty_cycle.iter().sum::<f32>() / (n as f32);
|
| 203 |
-
for i in 0..n {
|
| 204 |
-
self.boost[i] =
|
| 205 |
-
(-self.cfg.boost_strength * (self.active_duty_cycle[i] - mean_duty)).exp();
|
| 206 |
-
}
|
| 207 |
-
|
| 208 |
-
// 6) Permanence bump for chronically under-stimulated columns.
|
| 209 |
-
// If overlap_duty_cycle[i] < min_pct_overlap * max_duty_in_neighborhood,
|
| 210 |
-
// bump all permanences by syn_perm_active_inc * 0.1.
|
| 211 |
-
// With global inhibition, "neighborhood" = all columns.
|
| 212 |
-
let max_overlap_duty = self
|
| 213 |
-
.overlap_duty_cycle
|
| 214 |
-
.iter()
|
| 215 |
-
.cloned()
|
| 216 |
-
.fold(0.0_f32, f32::max);
|
| 217 |
-
let min_pct_overlap_duty = 0.001_f32 * max_overlap_duty;
|
| 218 |
-
if max_overlap_duty > 0.0 {
|
| 219 |
-
for i in 0..n {
|
| 220 |
-
if self.overlap_duty_cycle[i] < min_pct_overlap_duty {
|
| 221 |
-
for p in &mut self.columns[i].perms {
|
| 222 |
-
*p = (*p + self.cfg.syn_perm_active_inc * 0.1).min(1.0);
|
| 223 |
-
}
|
| 224 |
-
}
|
| 225 |
-
}
|
| 226 |
-
}
|
| 227 |
-
}
|
| 228 |
-
|
| 229 |
-
self.iter_count = self.iter_count.wrapping_add(1);
|
| 230 |
-
let _ = &mut self.rng; // suppress unused-mut when learn=false
|
| 231 |
-
active
|
| 232 |
-
}
|
| 233 |
-
}
|
| 234 |
-
|
| 235 |
-
/// Return the indices of the top-k values in `scores`.
|
| 236 |
-
/// Ties broken by index order. Output is sorted ascending.
|
| 237 |
-
fn top_k(scores: &[f32], k: usize) -> Vec<u32> {
|
| 238 |
-
if k == 0 {
|
| 239 |
-
return Vec::new();
|
| 240 |
-
}
|
| 241 |
-
let mut idx: Vec<u32> = (0..scores.len() as u32).collect();
|
| 242 |
-
// Partial sort: put top-k at the front by descending score.
|
| 243 |
-
// Use select_nth_unstable_by on (desc score, asc index).
|
| 244 |
-
idx.select_nth_unstable_by(k - 1, |&a, &b| {
|
| 245 |
-
let sa = scores[a as usize];
|
| 246 |
-
let sb = scores[b as usize];
|
| 247 |
-
// Reverse for descending.
|
| 248 |
-
match sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal) {
|
| 249 |
-
std::cmp::Ordering::Equal => a.cmp(&b),
|
| 250 |
-
ord => ord,
|
| 251 |
-
}
|
| 252 |
-
});
|
| 253 |
-
let mut winners: Vec<u32> = idx[..k].to_vec();
|
| 254 |
-
winners.sort_unstable();
|
| 255 |
-
winners
|
| 256 |
-
}
|
| 257 |
-
|
| 258 |
-
// ---------------------------------------------------------------------------
|
| 259 |
-
// Tests
|
| 260 |
-
// ---------------------------------------------------------------------------
|
| 261 |
-
|
| 262 |
-
#[cfg(test)]
|
| 263 |
-
mod tests {
|
| 264 |
-
use super::*;
|
| 265 |
-
use rand::Rng;
|
| 266 |
-
use rand::SeedableRng;
|
| 267 |
-
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 268 |
-
|
| 269 |
-
#[test]
|
| 270 |
-
fn sp_sparsity_exact_2pct() {
|
| 271 |
-
// BAMI says "top ~2%"; with 2048 columns that's round(0.02*2048) = 41.
|
| 272 |
-
// The SP must produce *exactly* that count, no more, no less, and with
|
| 273 |
-
// no duplicate indices.
|
| 274 |
-
let cfg = SpatialPoolerConfig::default();
|
| 275 |
-
let expected_k = (cfg.sparsity * cfg.n_columns as f32).round() as usize;
|
| 276 |
-
assert!(expected_k > 0);
|
| 277 |
-
|
| 278 |
-
let input_bits = cfg.input_bits;
|
| 279 |
-
let mut sp = SpatialPooler::new(cfg, 42);
|
| 280 |
-
let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
|
| 281 |
-
|
| 282 |
-
for _ in 0..100 {
|
| 283 |
-
// 2% sparse random input SDR.
|
| 284 |
-
let on_bits = (0.02 * input_bits as f32) as usize;
|
| 285 |
-
let mut sdr = vec![false; input_bits];
|
| 286 |
-
for _ in 0..on_bits {
|
| 287 |
-
let i = rng.gen_range(0..input_bits);
|
| 288 |
-
sdr[i] = true;
|
| 289 |
-
}
|
| 290 |
-
let active = sp.compute(&sdr, true);
|
| 291 |
-
assert_eq!(
|
| 292 |
-
active.len(),
|
| 293 |
-
expected_k,
|
| 294 |
-
"SP must emit exactly {expected_k} active columns"
|
| 295 |
-
);
|
| 296 |
-
let mut a = active.clone();
|
| 297 |
-
a.sort_unstable();
|
| 298 |
-
a.dedup();
|
| 299 |
-
assert_eq!(a.len(), expected_k);
|
| 300 |
-
}
|
| 301 |
-
}
|
| 302 |
-
}
|
|
|
|
| 1 |
+
//! Numenta BAMI-spec Spatial Pooler.
|
| 2 |
+
//!
|
| 3 |
+
//! Implements:
|
| 4 |
+
//! - 2048 (configurable) mini-columns with proximal dendrites
|
| 5 |
+
//! - `potential_synapses` (default 40) synapses per column sampled from
|
| 6 |
+
//! `potential_radius` (default 1024) random input bits
|
| 7 |
+
//! - Permanence in [0.0, 1.0] (f32), connected_threshold = 0.5
|
| 8 |
+
//! - syn_perm_active_inc = +0.04, syn_perm_inactive_dec = -0.008
|
| 9 |
+
//! - Global k-WTA inhibition (top `sparsity` fraction of columns)
|
| 10 |
+
//! - Boost factor with exponential duty-cycle tracking (Numenta formula)
|
| 11 |
+
//!
|
| 12 |
+
//! Reference: BAMI "Spatial Pooling Algorithm Details" (Numenta, 2017).
|
| 13 |
+
|
| 14 |
+
use rand::Rng;
|
| 15 |
+
use rand::SeedableRng;
|
| 16 |
+
use rand::seq::SliceRandom;
|
| 17 |
+
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 18 |
+
|
| 19 |
+
/// One column's proximal dendrite.
///
/// The dendrite is stored as two parallel vectors: entry `i` of `inputs` is
/// the input-SDR bit that potential synapse `i` listens to, and entry `i` of
/// `perms` is that synapse's permanence.
#[derive(Clone)]
pub struct ProximalDendrite {
    /// Input-SDR bit index for every potential synapse
    /// (`potential_synapses` entries in total).
    pub inputs: Vec<u32>,
    /// Permanence of every potential synapse, index-aligned with `inputs`.
    pub perms: Vec<f32>,
}
|
| 28 |
+
|
| 29 |
+
/// Tunable parameters of the spatial pooler.
pub struct SpatialPoolerConfig {
    /// Length of the input SDR; `compute` asserts its input has this length.
    pub input_bits: usize,
    /// Number of columns in the pooler.
    pub n_columns: usize,
    /// How many distinct input bits are randomly sampled for each column
    /// before its synapses are chosen from that sample.
    pub potential_radius: usize,
    /// How many potential synapses each column's proximal dendrite holds.
    pub potential_synapses: usize,
    /// A synapse counts toward overlap once its permanence is at least this.
    pub connected_threshold: f32,
    /// Permanence added to a winning column's synapse whose input bit is on.
    pub syn_perm_active_inc: f32,
    /// Permanence subtracted from a winning column's synapse whose input bit
    /// is off.
    pub syn_perm_inactive_dec: f32,
    /// Fraction of columns that should win each step (0.02 means 2%).
    pub sparsity: f32,
    /// Period of the duty-cycle exponential moving average.
    pub duty_cycle_period: f32,
    /// Strength of boosting; 0.0 switches boosting off.
    pub boost_strength: f32,
    /// Half-width of the band around `connected_threshold` from which initial
    /// permanences are drawn.
    pub init_perm_span: f32,
}
|
| 48 |
+
|
| 49 |
+
impl Default for SpatialPoolerConfig {
|
| 50 |
+
fn default() -> Self {
|
| 51 |
+
Self {
|
| 52 |
+
input_bits: 16384,
|
| 53 |
+
n_columns: 2048,
|
| 54 |
+
potential_radius: 1024,
|
| 55 |
+
potential_synapses: 40,
|
| 56 |
+
connected_threshold: 0.5,
|
| 57 |
+
syn_perm_active_inc: 0.04,
|
| 58 |
+
syn_perm_inactive_dec: 0.008,
|
| 59 |
+
sparsity: 0.02,
|
| 60 |
+
duty_cycle_period: 1000.0,
|
| 61 |
+
boost_strength: 1.0,
|
| 62 |
+
init_perm_span: 0.1,
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
pub struct SpatialPooler {
|
| 68 |
+
pub cfg: SpatialPoolerConfig,
|
| 69 |
+
pub columns: Vec<ProximalDendrite>,
|
| 70 |
+
/// Exponential moving average of "column was active" per step.
|
| 71 |
+
pub active_duty_cycle: Vec<f32>,
|
| 72 |
+
/// Exponential moving average of "overlap exceeded threshold" per step.
|
| 73 |
+
pub overlap_duty_cycle: Vec<f32>,
|
| 74 |
+
/// Boost factor per column.
|
| 75 |
+
pub boost: Vec<f32>,
|
| 76 |
+
rng: Xoshiro256PlusPlus,
|
| 77 |
+
iter_count: u64,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
impl SpatialPooler {
|
| 81 |
+
pub fn new(cfg: SpatialPoolerConfig, seed: u64) -> Self {
|
| 82 |
+
assert!(cfg.input_bits >= cfg.potential_radius,
|
| 83 |
+
"input_bits ({}) must be >= potential_radius ({})",
|
| 84 |
+
cfg.input_bits, cfg.potential_radius);
|
| 85 |
+
assert!(cfg.potential_radius >= cfg.potential_synapses,
|
| 86 |
+
"potential_radius ({}) must be >= potential_synapses ({})",
|
| 87 |
+
cfg.potential_radius, cfg.potential_synapses);
|
| 88 |
+
|
| 89 |
+
let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
|
| 90 |
+
|
| 91 |
+
let mut columns = Vec::with_capacity(cfg.n_columns);
|
| 92 |
+
for _ in 0..cfg.n_columns {
|
| 93 |
+
// Sample `potential_radius` distinct input indices, then from those
|
| 94 |
+
// pick `potential_synapses` as the actual proximal synapses.
|
| 95 |
+
// Using partial Fisher-Yates via shuffle on a pool index range.
|
| 96 |
+
let mut pool: Vec<u32> = (0..cfg.input_bits as u32).collect();
|
| 97 |
+
// Efficient partial shuffle: swap the first `potential_radius`
|
| 98 |
+
// items with random items from the rest (Durstenfeld step).
|
| 99 |
+
for i in 0..cfg.potential_radius.min(pool.len()) {
|
| 100 |
+
let j = rng.gen_range(i..pool.len());
|
| 101 |
+
pool.swap(i, j);
|
| 102 |
+
}
|
| 103 |
+
let window = &mut pool[..cfg.potential_radius];
|
| 104 |
+
window.shuffle(&mut rng);
|
| 105 |
+
let mut inputs: Vec<u32> = window[..cfg.potential_synapses].to_vec();
|
| 106 |
+
inputs.sort_unstable();
|
| 107 |
+
|
| 108 |
+
let perms: Vec<f32> = (0..cfg.potential_synapses)
|
| 109 |
+
.map(|_| {
|
| 110 |
+
let delta: f32 = rng.gen_range(-cfg.init_perm_span..cfg.init_perm_span);
|
| 111 |
+
(cfg.connected_threshold + delta).clamp(0.0, 1.0)
|
| 112 |
+
})
|
| 113 |
+
.collect();
|
| 114 |
+
|
| 115 |
+
columns.push(ProximalDendrite { inputs, perms });
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
let n = cfg.n_columns;
|
| 119 |
+
Self {
|
| 120 |
+
cfg,
|
| 121 |
+
columns,
|
| 122 |
+
active_duty_cycle: vec![0.0; n],
|
| 123 |
+
overlap_duty_cycle: vec![0.0; n],
|
| 124 |
+
boost: vec![1.0; n],
|
| 125 |
+
rng,
|
| 126 |
+
iter_count: 0,
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/// Process one step: compute overlaps, inhibit, learn (if `learn`), update
|
| 131 |
+
/// duty cycles and boosts. Returns the set of active column indices.
|
| 132 |
+
pub fn compute(&mut self, input: &[bool], learn: bool) -> Vec<u32> {
|
| 133 |
+
assert_eq!(input.len(), self.cfg.input_bits);
|
| 134 |
+
|
| 135 |
+
// 1) Overlap score per column (sum of CONNECTED synapses onto active inputs).
|
| 136 |
+
// Also track raw overlap for the overlap-duty-cycle.
|
| 137 |
+
let n = self.cfg.n_columns;
|
| 138 |
+
let mut overlaps: Vec<f32> = vec![0.0; n];
|
| 139 |
+
let mut raw_overlaps: Vec<u32> = vec![0; n];
|
| 140 |
+
|
| 141 |
+
for (ci, col) in self.columns.iter().enumerate() {
|
| 142 |
+
let mut s: u32 = 0;
|
| 143 |
+
for (syn_i, &inp) in col.inputs.iter().enumerate() {
|
| 144 |
+
if input[inp as usize] && col.perms[syn_i] >= self.cfg.connected_threshold {
|
| 145 |
+
s += 1;
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
raw_overlaps[ci] = s;
|
| 149 |
+
overlaps[ci] = (s as f32) * self.boost[ci];
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
// 2) Global k-WTA inhibition. Select top-k columns by boosted overlap.
|
| 153 |
+
let k = ((self.cfg.sparsity * n as f32).round() as usize).max(1);
|
| 154 |
+
let active: Vec<u32> = top_k(&overlaps, k);
|
| 155 |
+
|
| 156 |
+
// 3) Hebbian learning on active columns.
|
| 157 |
+
if learn {
|
| 158 |
+
for &ci in &active {
|
| 159 |
+
let col = &mut self.columns[ci as usize];
|
| 160 |
+
for (syn_i, &inp) in col.inputs.iter().enumerate() {
|
| 161 |
+
if input[inp as usize] {
|
| 162 |
+
col.perms[syn_i] =
|
| 163 |
+
(col.perms[syn_i] + self.cfg.syn_perm_active_inc).min(1.0);
|
| 164 |
+
} else {
|
| 165 |
+
col.perms[syn_i] =
|
| 166 |
+
(col.perms[syn_i] - self.cfg.syn_perm_inactive_dec).max(0.0);
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
// 4) Update duty cycles (EMA with period T -> alpha = 1/T).
|
| 173 |
+
let period = self.cfg.duty_cycle_period.max(1.0);
|
| 174 |
+
let alpha = 1.0 / period;
|
| 175 |
+
// Column is "overlapping enough" if raw overlap >= stimulus_threshold.
|
| 176 |
+
// Numenta uses min_overlap; we use 1 as a conservative floor.
|
| 177 |
+
let stimulus_threshold = 1.0_f32;
|
| 178 |
+
|
| 179 |
+
// Mark active columns.
|
| 180 |
+
let mut active_mask = vec![false; n];
|
| 181 |
+
for &ci in &active {
|
| 182 |
+
active_mask[ci as usize] = true;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
for i in 0..n {
|
| 186 |
+
let active_sample = if active_mask[i] { 1.0 } else { 0.0 };
|
| 187 |
+
let overlap_sample = if (raw_overlaps[i] as f32) >= stimulus_threshold {
|
| 188 |
+
1.0
|
| 189 |
+
} else {
|
| 190 |
+
0.0
|
| 191 |
+
};
|
| 192 |
+
self.active_duty_cycle[i] =
|
| 193 |
+
(1.0 - alpha) * self.active_duty_cycle[i] + alpha * active_sample;
|
| 194 |
+
self.overlap_duty_cycle[i] =
|
| 195 |
+
(1.0 - alpha) * self.overlap_duty_cycle[i] + alpha * overlap_sample;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// 5) Boost factor: b_i = exp(-boost_strength * (duty_i - mean_duty)).
|
| 199 |
+
// Under-used columns (duty < mean) get boost > 1.
|
| 200 |
+
if learn && self.cfg.boost_strength > 0.0 {
|
| 201 |
+
let mean_duty: f32 =
|
| 202 |
+
self.active_duty_cycle.iter().sum::<f32>() / (n as f32);
|
| 203 |
+
for i in 0..n {
|
| 204 |
+
self.boost[i] =
|
| 205 |
+
(-self.cfg.boost_strength * (self.active_duty_cycle[i] - mean_duty)).exp();
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
// 6) Permanence bump for chronically under-stimulated columns.
|
| 209 |
+
// If overlap_duty_cycle[i] < min_pct_overlap * max_duty_in_neighborhood,
|
| 210 |
+
// bump all permanences by syn_perm_active_inc * 0.1.
|
| 211 |
+
// With global inhibition, "neighborhood" = all columns.
|
| 212 |
+
let max_overlap_duty = self
|
| 213 |
+
.overlap_duty_cycle
|
| 214 |
+
.iter()
|
| 215 |
+
.cloned()
|
| 216 |
+
.fold(0.0_f32, f32::max);
|
| 217 |
+
let min_pct_overlap_duty = 0.001_f32 * max_overlap_duty;
|
| 218 |
+
if max_overlap_duty > 0.0 {
|
| 219 |
+
for i in 0..n {
|
| 220 |
+
if self.overlap_duty_cycle[i] < min_pct_overlap_duty {
|
| 221 |
+
for p in &mut self.columns[i].perms {
|
| 222 |
+
*p = (*p + self.cfg.syn_perm_active_inc * 0.1).min(1.0);
|
| 223 |
+
}
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
self.iter_count = self.iter_count.wrapping_add(1);
|
| 230 |
+
let _ = &mut self.rng; // suppress unused-mut when learn=false
|
| 231 |
+
active
|
| 232 |
+
}
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
/// Return the indices of the top-k values in `scores`.
///
/// Ties are broken by index order (lower index wins) and the output is sorted
/// ascending. `k` is capped at `scores.len()`, so asking for more winners
/// than there are scores returns every index instead of panicking; `k == 0`
/// or an empty `scores` returns an empty vector. NaN scores rank below every
/// real score, which keeps the ordering total.
fn top_k(scores: &[f32], k: usize) -> Vec<u32> {
    let k = k.min(scores.len());
    if k == 0 {
        return Vec::new();
    }

    // Ranking key: NaN is demoted to negative infinity so that every pair of
    // keys is comparable.
    let key = |i: u32| -> f32 {
        let s = scores[i as usize];
        if s.is_nan() {
            f32::NEG_INFINITY
        } else {
            s
        }
    };

    let mut idx: Vec<u32> = (0..scores.len() as u32).collect();
    // Partial selection on (descending score, ascending index): afterwards
    // the first k entries are the winners, in unspecified order.
    idx.select_nth_unstable_by(k - 1, |&a, &b| {
        key(b)
            .partial_cmp(&key(a))
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.cmp(&b))
    });

    idx.truncate(k);
    idx.sort_unstable();
    idx
}
|
| 257 |
+
|
| 258 |
+
// ---------------------------------------------------------------------------
|
| 259 |
+
// Tests
|
| 260 |
+
// ---------------------------------------------------------------------------
|
| 261 |
+
|
| 262 |
+
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;

    /// BAMI asks for "top ~2%" of columns; with the default 2048 columns that
    /// is round(0.02 * 2048) = 41. Every `compute` call must return exactly
    /// that many column indices, with no duplicates.
    #[test]
    fn sp_sparsity_exact_2pct() {
        let cfg = SpatialPoolerConfig::default();
        let input_bits = cfg.input_bits;
        let expected_k = (cfg.sparsity * cfg.n_columns as f32).round() as usize;
        assert!(expected_k > 0);

        let mut sp = SpatialPooler::new(cfg, 42);
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);

        // Number of random draws per input SDR (2% of the input width);
        // repeated draws may land on the same bit.
        let on_bits = (0.02 * input_bits as f32) as usize;

        for _ in 0..100 {
            let mut sdr = vec![false; input_bits];
            for _ in 0..on_bits {
                let bit = rng.gen_range(0..input_bits);
                sdr[bit] = true;
            }

            let active = sp.compute(&sdr, true);
            assert_eq!(
                active.len(),
                expected_k,
                "SP must emit exactly {expected_k} active columns"
            );

            // Sorting and deduplicating must not shrink the list.
            let mut unique = active.clone();
            unique.sort_unstable();
            unique.dedup();
            assert_eq!(unique.len(), expected_k);
        }
    }
}
|
overlay/htm_rust/src/tm.rs
CHANGED
|
@@ -1,545 +1,545 @@
|
|
| 1 |
-
//! Numenta BAMI-spec Temporal Memory.
|
| 2 |
-
//!
|
| 3 |
-
//! Key parameters (Numenta defaults):
|
| 4 |
-
//! - cells_per_column = 32
|
| 5 |
-
//! - max_segments_per_cell = 255
|
| 6 |
-
//! - max_synapses_per_segment = 32
|
| 7 |
-
//! - activation_threshold = 15 (CONNECTED synapses onto active cells)
|
| 8 |
-
//! - learning_threshold = 13 (POTENTIAL synapses onto active cells)
|
| 9 |
-
//! (often called `minThreshold` / match threshold in BAMI)
|
| 10 |
-
//! - initial_permanence = 0.21
|
| 11 |
-
//! - connected_permanence = 0.50
|
| 12 |
-
//! - permanence_increment = 0.10
|
| 13 |
-
//! - permanence_decrement = 0.10
|
| 14 |
-
//! - predicted_segment_decrement = 0.10 (decay for segments that predicted
|
| 15 |
-
//! inactive columns; called `predictedSegmentDecrement` in BAMI)
|
| 16 |
-
//! - max_new_synapse_count = 20 (max synapses to grow on a new/reinforced seg)
|
| 17 |
-
//!
|
| 18 |
-
//! Algorithm (one step):
|
| 19 |
-
//! Given `active_columns` from the Spatial Pooler, and segment activity
|
| 20 |
-
//! caches `active_segments` and `matching_segments` computed *at the end of
|
| 21 |
-
//! the previous step*:
|
| 22 |
-
//!
|
| 23 |
-
//! 1. For each active column:
|
| 24 |
-
//! - If it contains any predicted cell (any cell with an active segment
|
| 25 |
-
//! from the previous depolarization), mark those cells active and
|
| 26 |
-
//! learn on the segment that predicted it.
|
| 27 |
-
//! - Else BURST the column: mark all cells in it active, and grow a new
|
| 28 |
-
//! segment on the best-matching cell in the column (or, if none,
|
| 29 |
-
//! on the cell with the fewest segments).
|
| 30 |
-
//! 2. For every column that was predicted but did NOT become active
|
| 31 |
-
//! (matching segments on inactive columns), apply the
|
| 32 |
-
//! `predicted_segment_decrement` decay so spurious predictions fade.
|
| 33 |
-
//! 3. Winner cells = active cells chosen for learning (1 per active column).
|
| 34 |
-
//! 4. Compute segment activity for NEXT step:
|
| 35 |
-
//! - A segment's CONNECTED activity = #synapses with perm >= connected_perm
|
| 36 |
-
//! whose presynaptic cell is in `active_cells`. If >= activation_threshold
|
| 37 |
-
//! -> segment is "active" -> its cell is "predicted".
|
| 38 |
-
//! - A segment's POTENTIAL activity = #synapses whose presynaptic cell is
|
| 39 |
-
//! in `active_cells` (regardless of permanence). If >= learning_threshold
|
| 40 |
-
//! -> segment is "matching".
|
| 41 |
-
//!
|
| 42 |
-
//! Anomaly score = (active columns with no prior predicted cells)
|
| 43 |
-
//! / (# active columns).
|
| 44 |
-
|
| 45 |
-
use rand::Rng;
|
| 46 |
-
use rand::SeedableRng;
|
| 47 |
-
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 48 |
-
|
| 49 |
-
type CellIdx = u32;
|
| 50 |
-
type SegmentIdx = u32;
|
| 51 |
-
|
| 52 |
-
#[derive(Clone)]
|
| 53 |
-
pub struct Synapse {
|
| 54 |
-
pub presynaptic_cell: CellIdx,
|
| 55 |
-
pub permanence: f32,
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
#[derive(Clone)]
|
| 59 |
-
pub struct Segment {
|
| 60 |
-
pub cell: CellIdx,
|
| 61 |
-
pub synapses: Vec<Synapse>,
|
| 62 |
-
/// Cached counters; recomputed each step.
|
| 63 |
-
pub num_active_connected: u32,
|
| 64 |
-
pub num_active_potential: u32,
|
| 65 |
-
/// Simple "last iter touched" stat for least-used cell selection.
|
| 66 |
-
pub last_used_iteration: u64,
|
| 67 |
-
}
|
| 68 |
-
|
| 69 |
-
pub struct TemporalMemoryConfig {
|
| 70 |
-
pub n_columns: usize,
|
| 71 |
-
pub cells_per_column: usize,
|
| 72 |
-
pub activation_threshold: u32,
|
| 73 |
-
pub learning_threshold: u32,
|
| 74 |
-
pub initial_permanence: f32,
|
| 75 |
-
pub connected_permanence: f32,
|
| 76 |
-
pub permanence_increment: f32,
|
| 77 |
-
pub permanence_decrement: f32,
|
| 78 |
-
pub predicted_segment_decrement: f32,
|
| 79 |
-
pub max_segments_per_cell: usize,
|
| 80 |
-
pub max_synapses_per_segment: usize,
|
| 81 |
-
pub max_new_synapse_count: usize,
|
| 82 |
-
}
|
| 83 |
-
|
| 84 |
-
impl Default for TemporalMemoryConfig {
|
| 85 |
-
fn default() -> Self {
|
| 86 |
-
Self {
|
| 87 |
-
n_columns: 2048,
|
| 88 |
-
cells_per_column: 32,
|
| 89 |
-
activation_threshold: 15,
|
| 90 |
-
learning_threshold: 13,
|
| 91 |
-
initial_permanence: 0.21,
|
| 92 |
-
connected_permanence: 0.50,
|
| 93 |
-
permanence_increment: 0.10,
|
| 94 |
-
permanence_decrement: 0.10,
|
| 95 |
-
predicted_segment_decrement: 0.10,
|
| 96 |
-
max_segments_per_cell: 255,
|
| 97 |
-
max_synapses_per_segment: 32,
|
| 98 |
-
max_new_synapse_count: 20,
|
| 99 |
-
}
|
| 100 |
-
}
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
pub struct TemporalMemory {
|
| 104 |
-
pub cfg: TemporalMemoryConfig,
|
| 105 |
-
/// All segments in the region. Indexed by SegmentIdx.
|
| 106 |
-
pub segments: Vec<Segment>,
|
| 107 |
-
/// For each cell, the list of segments that belong to it.
|
| 108 |
-
pub cell_segments: Vec<Vec<SegmentIdx>>,
|
| 109 |
-
/// Active cells in the current step.
|
| 110 |
-
pub active_cells: Vec<bool>,
|
| 111 |
-
/// Winner cells (subset of active_cells, 1 per active column) for learning.
|
| 112 |
-
pub winner_cells: Vec<bool>,
|
| 113 |
-
/// Predictive cells for the current step = cells whose segment became
|
| 114 |
-
/// active at the end of the previous step.
|
| 115 |
-
pub predictive_cells: Vec<bool>,
|
| 116 |
-
/// Cached list of segment indices that were "active" last compute().
|
| 117 |
-
active_segments_prev: Vec<SegmentIdx>,
|
| 118 |
-
/// Cached list of segment indices that were "matching" last compute().
|
| 119 |
-
matching_segments_prev: Vec<SegmentIdx>,
|
| 120 |
-
rng: Xoshiro256PlusPlus,
|
| 121 |
-
iter_count: u64,
|
| 122 |
-
}
|
| 123 |
-
|
| 124 |
-
impl TemporalMemory {
|
| 125 |
-
pub fn new(cfg: TemporalMemoryConfig, seed: u64) -> Self {
|
| 126 |
-
let total = cfg.n_columns * cfg.cells_per_column;
|
| 127 |
-
Self {
|
| 128 |
-
cell_segments: vec![Vec::new(); total],
|
| 129 |
-
active_cells: vec![false; total],
|
| 130 |
-
winner_cells: vec![false; total],
|
| 131 |
-
predictive_cells: vec![false; total],
|
| 132 |
-
cfg,
|
| 133 |
-
segments: Vec::new(),
|
| 134 |
-
active_segments_prev: Vec::new(),
|
| 135 |
-
matching_segments_prev: Vec::new(),
|
| 136 |
-
rng: Xoshiro256PlusPlus::seed_from_u64(seed),
|
| 137 |
-
iter_count: 0,
|
| 138 |
-
}
|
| 139 |
-
}
|
| 140 |
-
|
| 141 |
-
pub fn reset(&mut self) {
|
| 142 |
-
for v in self.active_cells.iter_mut() { *v = false; }
|
| 143 |
-
for v in self.winner_cells.iter_mut() { *v = false; }
|
| 144 |
-
for v in self.predictive_cells.iter_mut() { *v = false; }
|
| 145 |
-
self.active_segments_prev.clear();
|
| 146 |
-
self.matching_segments_prev.clear();
|
| 147 |
-
}
|
| 148 |
-
|
| 149 |
-
#[inline]
|
| 150 |
-
fn col_of(&self, cell: CellIdx) -> usize {
|
| 151 |
-
(cell as usize) / self.cfg.cells_per_column
|
| 152 |
-
}
|
| 153 |
-
|
| 154 |
-
#[inline]
|
| 155 |
-
fn cells_in_col(&self, col: usize) -> std::ops::Range<CellIdx> {
|
| 156 |
-
let base = (col * self.cfg.cells_per_column) as CellIdx;
|
| 157 |
-
base..(base + self.cfg.cells_per_column as CellIdx)
|
| 158 |
-
}
|
| 159 |
-
|
| 160 |
-
/// Process one step.
|
| 161 |
-
///
|
| 162 |
-
/// `active_columns` is the set of column indices activated by the Spatial
|
| 163 |
-
/// Pooler this step. Returns the anomaly score in [0, 1].
|
| 164 |
-
pub fn compute(&mut self, active_columns: &[u32], learn: bool) -> f32 {
|
| 165 |
-
self.iter_count = self.iter_count.wrapping_add(1);
|
| 166 |
-
|
| 167 |
-
// Snapshot previous-step cell activity (for learning on segments).
|
| 168 |
-
let prev_active_cells = self.active_cells.clone();
|
| 169 |
-
let prev_winner_cells = self.winner_cells.clone();
|
| 170 |
-
|
| 171 |
-
// Move current "predictive" (computed at the end of the last step)
|
| 172 |
-
// into local variables; we'll overwrite predictive_cells later.
|
| 173 |
-
let predictive_prev = self.predictive_cells.clone();
|
| 174 |
-
|
| 175 |
-
// Group active segments and matching segments by column of their
|
| 176 |
-
// owning cell, for the columns that are active this step.
|
| 177 |
-
let n_cols = self.cfg.n_columns;
|
| 178 |
-
|
| 179 |
-
// active_segs_by_col[col] = segment indices whose cell is in col and
|
| 180 |
-
// which were "active" in the previous depolarization.
|
| 181 |
-
// matching_segs_by_col[col] = similarly for "matching".
|
| 182 |
-
let mut active_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols];
|
| 183 |
-
let mut matching_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols];
|
| 184 |
-
for &seg in &self.active_segments_prev {
|
| 185 |
-
let col = self.col_of(self.segments[seg as usize].cell);
|
| 186 |
-
active_segs_by_col[col].push(seg);
|
| 187 |
-
}
|
| 188 |
-
for &seg in &self.matching_segments_prev {
|
| 189 |
-
let col = self.col_of(self.segments[seg as usize].cell);
|
| 190 |
-
matching_segs_by_col[col].push(seg);
|
| 191 |
-
}
|
| 192 |
-
|
| 193 |
-
// Columns that are active this step (for O(1) lookup).
|
| 194 |
-
let mut active_col_mask = vec![false; n_cols];
|
| 195 |
-
for &c in active_columns { active_col_mask[c as usize] = true; }
|
| 196 |
-
|
| 197 |
-
// Zero out current cell activations.
|
| 198 |
-
for v in self.active_cells.iter_mut() { *v = false; }
|
| 199 |
-
for v in self.winner_cells.iter_mut() { *v = false; }
|
| 200 |
-
|
| 201 |
-
// Track anomaly.
|
| 202 |
-
let mut unpredicted_cols = 0u32;
|
| 203 |
-
|
| 204 |
-
// We'll collect (segment, learn_mode) pairs for segment reinforcement
|
| 205 |
-
// so we can batch-apply permanence adjustments using prev_active_cells.
|
| 206 |
-
// learn_mode: "reinforce_correctly_predicted", "punish_incorrectly_matched"
|
| 207 |
-
enum LearnOp {
|
| 208 |
-
Reinforce(SegmentIdx), // correctly predicted
|
| 209 |
-
Grow { // bursting column: grow on chosen segment
|
| 210 |
-
segment: SegmentIdx,
|
| 211 |
-
#[allow(dead_code)]
|
| 212 |
-
winner_cell: CellIdx,
|
| 213 |
-
},
|
| 214 |
-
Punish(SegmentIdx), // matching segment on inactive column
|
| 215 |
-
}
|
| 216 |
-
let mut ops: Vec<LearnOp> = Vec::new();
|
| 217 |
-
|
| 218 |
-
// ---- 1) Process active columns ----
|
| 219 |
-
for &col in active_columns {
|
| 220 |
-
let col = col as usize;
|
| 221 |
-
let active_segs = &active_segs_by_col[col];
|
| 222 |
-
if !active_segs.is_empty() {
|
| 223 |
-
// "Activate predicted column": each cell with an active segment
|
| 224 |
-
// becomes active and is a winner; reinforce that segment.
|
| 225 |
-
let mut seen_cells: Vec<CellIdx> = Vec::new();
|
| 226 |
-
for &seg_i in active_segs {
|
| 227 |
-
let seg = &self.segments[seg_i as usize];
|
| 228 |
-
let cell = seg.cell;
|
| 229 |
-
if !seen_cells.contains(&cell) {
|
| 230 |
-
self.active_cells[cell as usize] = true;
|
| 231 |
-
self.winner_cells[cell as usize] = true;
|
| 232 |
-
seen_cells.push(cell);
|
| 233 |
-
}
|
| 234 |
-
if learn {
|
| 235 |
-
ops.push(LearnOp::Reinforce(seg_i));
|
| 236 |
-
}
|
| 237 |
-
}
|
| 238 |
-
} else {
|
| 239 |
-
// ----- BURST -----
|
| 240 |
-
unpredicted_cols += 1;
|
| 241 |
-
for c in self.cells_in_col(col) {
|
| 242 |
-
self.active_cells[c as usize] = true;
|
| 243 |
-
}
|
| 244 |
-
// Pick a winner cell + segment for learning.
|
| 245 |
-
if learn {
|
| 246 |
-
let matching = &matching_segs_by_col[col];
|
| 247 |
-
let (winner_cell, target_segment) = if !matching.is_empty() {
|
| 248 |
-
// Best-matching segment = highest num_active_potential.
|
| 249 |
-
let mut best = matching[0];
|
| 250 |
-
let mut best_score = self.segments[best as usize].num_active_potential;
|
| 251 |
-
for &s in &matching[1..] {
|
| 252 |
-
let score = self.segments[s as usize].num_active_potential;
|
| 253 |
-
if score > best_score {
|
| 254 |
-
best_score = score;
|
| 255 |
-
best = s;
|
| 256 |
-
}
|
| 257 |
-
}
|
| 258 |
-
let wc = self.segments[best as usize].cell;
|
| 259 |
-
(wc, Some(best))
|
| 260 |
-
} else {
|
| 261 |
-
// Least-used cell in column, then grow a new segment.
|
| 262 |
-
let winner = self.least_used_cell(col);
|
| 263 |
-
(winner, None)
|
| 264 |
-
};
|
| 265 |
-
self.winner_cells[winner_cell as usize] = true;
|
| 266 |
-
let segment_id = match target_segment {
|
| 267 |
-
Some(s) => s,
|
| 268 |
-
None => {
|
| 269 |
-
// Create a fresh empty segment on winner cell.
|
| 270 |
-
self.create_segment(winner_cell)
|
| 271 |
-
}
|
| 272 |
-
};
|
| 273 |
-
ops.push(LearnOp::Grow { segment: segment_id, winner_cell });
|
| 274 |
-
} else {
|
| 275 |
-
// No learning: still pick some winner cell (arbitrary)
|
| 276 |
-
// so downstream code that inspects winner_cells isn't empty.
|
| 277 |
-
let matching = &matching_segs_by_col[col];
|
| 278 |
-
let winner_cell = if !matching.is_empty() {
|
| 279 |
-
self.segments[matching[0] as usize].cell
|
| 280 |
-
} else {
|
| 281 |
-
self.least_used_cell(col)
|
| 282 |
-
};
|
| 283 |
-
self.winner_cells[winner_cell as usize] = true;
|
| 284 |
-
}
|
| 285 |
-
}
|
| 286 |
-
}
|
| 287 |
-
|
| 288 |
-
// ---- 2) Punish matching segments on INACTIVE columns ----
|
| 289 |
-
if learn && self.cfg.predicted_segment_decrement > 0.0 {
|
| 290 |
-
for &seg_i in &self.matching_segments_prev {
|
| 291 |
-
let col = self.col_of(self.segments[seg_i as usize].cell);
|
| 292 |
-
if !active_col_mask[col] {
|
| 293 |
-
ops.push(LearnOp::Punish(seg_i));
|
| 294 |
-
}
|
| 295 |
-
}
|
| 296 |
-
}
|
| 297 |
-
|
| 298 |
-
// ---- 3) Apply learning ----
|
| 299 |
-
if learn {
|
| 300 |
-
for op in ops {
|
| 301 |
-
match op {
|
| 302 |
-
LearnOp::Reinforce(seg_i) => {
|
| 303 |
-
self.reinforce_segment(seg_i, &prev_active_cells);
|
| 304 |
-
// Optionally grow up to N new synapses to winner cells
|
| 305 |
-
// of the previous step.
|
| 306 |
-
self.grow_synapses_on_segment(seg_i, &prev_winner_cells);
|
| 307 |
-
}
|
| 308 |
-
LearnOp::Grow { segment, winner_cell: _ } => {
|
| 309 |
-
self.reinforce_segment(segment, &prev_active_cells);
|
| 310 |
-
self.grow_synapses_on_segment(segment, &prev_winner_cells);
|
| 311 |
-
}
|
| 312 |
-
LearnOp::Punish(seg_i) => {
|
| 313 |
-
let dec = self.cfg.predicted_segment_decrement;
|
| 314 |
-
for syn in &mut self.segments[seg_i as usize].synapses {
|
| 315 |
-
if prev_active_cells[syn.presynaptic_cell as usize] {
|
| 316 |
-
syn.permanence = (syn.permanence - dec).max(0.0);
|
| 317 |
-
}
|
| 318 |
-
}
|
| 319 |
-
}
|
| 320 |
-
}
|
| 321 |
-
}
|
| 322 |
-
}
|
| 323 |
-
|
| 324 |
-
// ---- 4) Compute segment activity & predictive cells for NEXT step ----
|
| 325 |
-
// We have to use the *current* active_cells (just set above).
|
| 326 |
-
let mut next_active_segs: Vec<SegmentIdx> = Vec::new();
|
| 327 |
-
let mut next_matching_segs: Vec<SegmentIdx> = Vec::new();
|
| 328 |
-
for v in self.predictive_cells.iter_mut() { *v = false; }
|
| 329 |
-
|
| 330 |
-
let conn = self.cfg.connected_permanence;
|
| 331 |
-
let act_thr = self.cfg.activation_threshold;
|
| 332 |
-
let learn_thr = self.cfg.learning_threshold;
|
| 333 |
-
|
| 334 |
-
for (seg_i, seg) in self.segments.iter_mut().enumerate() {
|
| 335 |
-
let mut n_conn: u32 = 0;
|
| 336 |
-
let mut n_pot: u32 = 0;
|
| 337 |
-
for syn in &seg.synapses {
|
| 338 |
-
if self.active_cells[syn.presynaptic_cell as usize] {
|
| 339 |
-
n_pot += 1;
|
| 340 |
-
if syn.permanence >= conn { n_conn += 1; }
|
| 341 |
-
}
|
| 342 |
-
}
|
| 343 |
-
seg.num_active_connected = n_conn;
|
| 344 |
-
seg.num_active_potential = n_pot;
|
| 345 |
-
if n_conn >= act_thr {
|
| 346 |
-
next_active_segs.push(seg_i as SegmentIdx);
|
| 347 |
-
self.predictive_cells[seg.cell as usize] = true;
|
| 348 |
-
}
|
| 349 |
-
if n_pot >= learn_thr {
|
| 350 |
-
next_matching_segs.push(seg_i as SegmentIdx);
|
| 351 |
-
}
|
| 352 |
-
}
|
| 353 |
-
self.active_segments_prev = next_active_segs;
|
| 354 |
-
self.matching_segments_prev = next_matching_segs;
|
| 355 |
-
|
| 356 |
-
// Keep predictive_prev unused-guard; we no longer need it but
|
| 357 |
-
// retained to document intent.
|
| 358 |
-
let _ = predictive_prev;
|
| 359 |
-
|
| 360 |
-
// Anomaly.
|
| 361 |
-
if active_columns.is_empty() {
|
| 362 |
-
0.0
|
| 363 |
-
} else {
|
| 364 |
-
(unpredicted_cols as f32) / (active_columns.len() as f32)
|
| 365 |
-
}
|
| 366 |
-
}
|
| 367 |
-
|
| 368 |
-
/// Reinforce synapses on `seg`: +inc if presynaptic is active last step,
|
| 369 |
-
/// -dec otherwise.
|
| 370 |
-
fn reinforce_segment(&mut self, seg_i: SegmentIdx, prev_active_cells: &[bool]) {
|
| 371 |
-
let inc = self.cfg.permanence_increment;
|
| 372 |
-
let dec = self.cfg.permanence_decrement;
|
| 373 |
-
let seg = &mut self.segments[seg_i as usize];
|
| 374 |
-
seg.last_used_iteration = self.iter_count;
|
| 375 |
-
for syn in &mut seg.synapses {
|
| 376 |
-
if prev_active_cells[syn.presynaptic_cell as usize] {
|
| 377 |
-
syn.permanence = (syn.permanence + inc).min(1.0);
|
| 378 |
-
} else {
|
| 379 |
-
syn.permanence = (syn.permanence - dec).max(0.0);
|
| 380 |
-
}
|
| 381 |
-
}
|
| 382 |
-
}
|
| 383 |
-
|
| 384 |
-
/// Grow up to `max_new_synapse_count - current_potential` new synapses
|
| 385 |
-
/// from previous winner cells that are not already connected to this seg.
|
| 386 |
-
fn grow_synapses_on_segment(
|
| 387 |
-
&mut self,
|
| 388 |
-
seg_i: SegmentIdx,
|
| 389 |
-
prev_winner_cells: &[bool],
|
| 390 |
-
) {
|
| 391 |
-
let initial_perm = self.cfg.initial_permanence;
|
| 392 |
-
let cap = self.cfg.max_synapses_per_segment;
|
| 393 |
-
let max_new = self.cfg.max_new_synapse_count;
|
| 394 |
-
|
| 395 |
-
// Gather candidate cells (prev winners not already presynaptic to this seg).
|
| 396 |
-
let already: Vec<CellIdx> = self.segments[seg_i as usize]
|
| 397 |
-
.synapses
|
| 398 |
-
.iter()
|
| 399 |
-
.map(|s| s.presynaptic_cell)
|
| 400 |
-
.collect();
|
| 401 |
-
let mut candidates: Vec<CellIdx> = Vec::new();
|
| 402 |
-
for (cell_i, &b) in prev_winner_cells.iter().enumerate() {
|
| 403 |
-
if b && !already.contains(&(cell_i as CellIdx)) {
|
| 404 |
-
candidates.push(cell_i as CellIdx);
|
| 405 |
-
}
|
| 406 |
-
}
|
| 407 |
-
|
| 408 |
-
// How many can we add?
|
| 409 |
-
let current_len = self.segments[seg_i as usize].synapses.len();
|
| 410 |
-
let room = cap.saturating_sub(current_len);
|
| 411 |
-
let mut to_add = max_new.min(candidates.len()).min(room);
|
| 412 |
-
|
| 413 |
-
// Random sample without replacement from candidates.
|
| 414 |
-
while to_add > 0 {
|
| 415 |
-
let idx = self.rng.gen_range(0..candidates.len());
|
| 416 |
-
let pre = candidates.swap_remove(idx);
|
| 417 |
-
self.segments[seg_i as usize].synapses.push(Synapse {
|
| 418 |
-
presynaptic_cell: pre,
|
| 419 |
-
permanence: initial_perm,
|
| 420 |
-
});
|
| 421 |
-
to_add -= 1;
|
| 422 |
-
}
|
| 423 |
-
}
|
| 424 |
-
|
| 425 |
-
/// Provide an empty segment on `cell` and return its index.
///
/// While the cell owns fewer than `cfg.max_segments_per_cell` segments a
/// new one is appended to the region. Otherwise the cell's least recently
/// used segment (smallest `last_used_iteration`, earliest entry on ties) is
/// wiped and its index is reused, so the cell's segment list is unchanged.
///
/// Panics if the cap is reached while the cell owns no segments at all
/// (i.e. `max_segments_per_cell == 0`).
fn create_segment(&mut self, cell: CellIdx) -> SegmentIdx {
    let now = self.iter_count;
    let owned = &mut self.cell_segments[cell as usize];

    if owned.len() < self.cfg.max_segments_per_cell {
        // Still below the cap: append a brand-new segment.
        let fresh_id = self.segments.len() as SegmentIdx;
        self.segments.push(Segment {
            cell,
            synapses: Vec::with_capacity(self.cfg.max_new_synapse_count),
            num_active_connected: 0,
            num_active_potential: 0,
            last_used_iteration: now,
        });
        owned.push(fresh_id);
        return fresh_id;
    }

    // At the cap: pick the least recently used segment of this cell.
    let segments = &self.segments;
    let victim_id = owned
        .iter()
        .copied()
        .min_by_key(|&sid| segments[sid as usize].last_used_iteration)
        .expect("cell_segs non-empty");

    // Wipe it in place; it keeps both its index and its owning cell.
    let victim = &mut self.segments[victim_id as usize];
    victim.synapses.clear();
    victim.num_active_connected = 0;
    victim.num_active_potential = 0;
    victim.last_used_iteration = now;
    victim_id
}
|
| 457 |
-
|
| 458 |
-
fn least_used_cell(&mut self, col: usize) -> CellIdx {
|
| 459 |
-
// Cell with the fewest segments; break ties randomly.
|
| 460 |
-
let mut min_segs = usize::MAX;
|
| 461 |
-
let mut candidates: Vec<CellIdx> = Vec::new();
|
| 462 |
-
for c in self.cells_in_col(col) {
|
| 463 |
-
let n = self.cell_segments[c as usize].len();
|
| 464 |
-
if n < min_segs {
|
| 465 |
-
min_segs = n;
|
| 466 |
-
candidates.clear();
|
| 467 |
-
candidates.push(c);
|
| 468 |
-
} else if n == min_segs {
|
| 469 |
-
candidates.push(c);
|
| 470 |
-
}
|
| 471 |
-
}
|
| 472 |
-
let idx = self.rng.gen_range(0..candidates.len());
|
| 473 |
-
candidates[idx]
|
| 474 |
-
}
|
| 475 |
-
}
|
| 476 |
-
|
| 477 |
-
// ---------------------------------------------------------------------------
|
| 478 |
-
// Tests
|
| 479 |
-
// ---------------------------------------------------------------------------
|
| 480 |
-
|
| 481 |
-
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sp::{SpatialPooler, SpatialPoolerConfig};
    use rand::Rng;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;

    /// Build a random boolean SDR of `input_bits` positions with 2% of them set.
    fn random_sdr(rng: &mut Xoshiro256PlusPlus, input_bits: usize) -> Vec<bool> {
        let target = (0.02 * input_bits as f32) as usize;
        let mut bits = vec![false; input_bits];
        let mut placed = 0;
        while placed < target {
            let i = rng.gen_range(0..input_bits);
            if !bits[i] {
                bits[i] = true;
                placed += 1;
            }
        }
        bits
    }

    /// Arithmetic mean of a slice of anomaly scores.
    fn mean(values: &[f32]) -> f32 {
        values.iter().sum::<f32>() / (values.len() as f32)
    }

    /// Repeating A -> B -> C -> A -> ... must drive the anomaly score down.
    #[test]
    fn tm_learns_repeating_sequence() {
        // Passes over the sequence whose anomaly scores are compared.
        const EARLY_PASS: usize = 10;
        const LATE_PASS: usize = 249;

        let mut sp = SpatialPooler::new(SpatialPoolerConfig::default(), 123);
        let mut tm = TemporalMemory::new(TemporalMemoryConfig::default(), 456);

        // Three fixed random SDRs stand in for the symbols A, B and C.
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
        let input_bits = sp.cfg.input_bits;
        let seqs = [
            random_sdr(&mut rng, input_bits),
            random_sdr(&mut rng, input_bits),
            random_sdr(&mut rng, input_bits),
        ];

        // Warm up the SP first so each symbol maps to reliable columns.
        for _ in 0..200 {
            for s in &seqs {
                sp.compute(s, true);
            }
        }

        // Start the TM from a clean prediction state.
        tm.reset();

        // Train the TM with the SP frozen, sampling one early and one late pass.
        let mut early_anoms: Vec<f32> = Vec::new();
        let mut late_anoms: Vec<f32> = Vec::new();
        for iter in 0..250 {
            for s in &seqs {
                let active = sp.compute(s, false);
                let anomaly = tm.compute(&active, true);
                match iter {
                    EARLY_PASS => early_anoms.push(anomaly),
                    LATE_PASS => late_anoms.push(anomaly),
                    _ => {}
                }
            }
        }

        let early = mean(&early_anoms);
        let late = mean(&late_anoms);
        println!("early_anomaly={early}, late_anomaly={late}");
        assert!(
            late < 0.5 * early + 1e-6,
            "late anomaly ({late}) should be < 0.5 * early anomaly ({early})"
        );
    }
}
|
|
|
|
| 1 |
+
//! Numenta BAMI-spec Temporal Memory.
|
| 2 |
+
//!
|
| 3 |
+
//! Key parameters (Numenta defaults):
|
| 4 |
+
//! - cells_per_column = 32
|
| 5 |
+
//! - max_segments_per_cell = 255
|
| 6 |
+
//! - max_synapses_per_segment = 32
|
| 7 |
+
//! - activation_threshold = 15 (CONNECTED synapses onto active cells)
|
| 8 |
+
//! - learning_threshold = 13 (POTENTIAL synapses onto active cells)
|
| 9 |
+
//! (often called `minThreshold` / match threshold in BAMI)
|
| 10 |
+
//! - initial_permanence = 0.21
|
| 11 |
+
//! - connected_permanence = 0.50
|
| 12 |
+
//! - permanence_increment = 0.10
|
| 13 |
+
//! - permanence_decrement = 0.10
|
| 14 |
+
//! - predicted_segment_decrement = 0.10 (decay for segments that predicted
|
| 15 |
+
//! inactive columns; called `predictedSegmentDecrement` in BAMI)
|
| 16 |
+
//! - max_new_synapse_count = 20 (max synapses to grow on a new/reinforced seg)
|
| 17 |
+
//!
|
| 18 |
+
//! Algorithm (one step):
|
| 19 |
+
//! Given `active_columns` from the Spatial Pooler, and segment activity
|
| 20 |
+
//! caches `active_segments` and `matching_segments` computed *at the end of
|
| 21 |
+
//! the previous step*:
|
| 22 |
+
//!
|
| 23 |
+
//! 1. For each active column:
|
| 24 |
+
//! - If it contains any predicted cell (any cell with an active segment
|
| 25 |
+
//! from the previous depolarization), mark those cells active and
|
| 26 |
+
//! learn on the segment that predicted it.
|
| 27 |
+
//! - Else BURST the column: mark all cells in it active, and grow a new
|
| 28 |
+
//! segment on the best-matching cell in the column (or, if none,
|
| 29 |
+
//! on the cell with the fewest segments).
|
| 30 |
+
//! 2. For every column that was predicted but did NOT become active
|
| 31 |
+
//! (matching segments on inactive columns), apply the
|
| 32 |
+
//! `predicted_segment_decrement` decay so spurious predictions fade.
|
| 33 |
+
//! 3. Winner cells = active cells chosen for learning (every predicted cell in a predicted column; exactly one cell per bursting column).
|
| 34 |
+
//! 4. Compute segment activity for NEXT step:
|
| 35 |
+
//! - A segment's CONNECTED activity = #synapses with perm >= connected_perm
|
| 36 |
+
//! whose presynaptic cell is in `active_cells`. If >= activation_threshold
|
| 37 |
+
//! -> segment is "active" -> its cell is "predicted".
|
| 38 |
+
//! - A segment's POTENTIAL activity = #synapses whose presynaptic cell is
|
| 39 |
+
//! in `active_cells` (regardless of permanence). If >= learning_threshold
|
| 40 |
+
//! -> segment is "matching".
|
| 41 |
+
//!
|
| 42 |
+
//! Anomaly score = (active columns with no prior predicted cells)
|
| 43 |
+
//! / (# active columns).
|
| 44 |
+
|
| 45 |
+
use rand::Rng;
|
| 46 |
+
use rand::SeedableRng;
|
| 47 |
+
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 48 |
+
|
| 49 |
+
type CellIdx = u32;
|
| 50 |
+
type SegmentIdx = u32;
|
| 51 |
+
|
| 52 |
+
#[derive(Clone)]
|
| 53 |
+
pub struct Synapse {
|
| 54 |
+
pub presynaptic_cell: CellIdx,
|
| 55 |
+
pub permanence: f32,
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
#[derive(Clone)]
|
| 59 |
+
pub struct Segment {
|
| 60 |
+
pub cell: CellIdx,
|
| 61 |
+
pub synapses: Vec<Synapse>,
|
| 62 |
+
/// Cached counters; recomputed each step.
|
| 63 |
+
pub num_active_connected: u32,
|
| 64 |
+
pub num_active_potential: u32,
|
| 65 |
+
/// Simple "last iter touched" stat for least-used cell selection.
|
| 66 |
+
pub last_used_iteration: u64,
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
/// Tunable parameters of a [`TemporalMemory`] region.
///
/// `Clone` and `Debug` are derived so a configuration can be duplicated and
/// logged by callers.
#[derive(Clone, Debug)]
pub struct TemporalMemoryConfig {
    /// Number of columns in the region.
    pub n_columns: usize,
    /// Number of cells in each column.
    pub cells_per_column: usize,
    /// Minimum count of CONNECTED synapses onto active cells for a segment
    /// to be "active" (its cell becomes predictive).
    pub activation_threshold: u32,
    /// Minimum count of POTENTIAL synapses onto active cells for a segment
    /// to be "matching" (eligible for learning / punishment).
    pub learning_threshold: u32,
    /// Permanence given to newly grown synapses.
    pub initial_permanence: f32,
    /// Permanence at or above which a synapse counts as connected.
    pub connected_permanence: f32,
    /// Amount added to synapses onto previously active cells when a segment
    /// is reinforced.
    pub permanence_increment: f32,
    /// Amount subtracted from the other synapses when a segment is
    /// reinforced.
    pub permanence_decrement: f32,
    /// Amount subtracted from synapses of matching segments whose column did
    /// not become active; values `<= 0.0` disable the punishment.
    pub predicted_segment_decrement: f32,
    /// Cap on segments per cell; beyond it the least recently used segment
    /// is recycled.
    pub max_segments_per_cell: usize,
    /// Cap on synapses per segment.
    pub max_synapses_per_segment: usize,
    /// Cap on synapses grown on one segment in a single learning step.
    pub max_new_synapse_count: usize,
}
|
| 83 |
+
|
| 84 |
+
impl Default for TemporalMemoryConfig {
|
| 85 |
+
fn default() -> Self {
|
| 86 |
+
Self {
|
| 87 |
+
n_columns: 2048,
|
| 88 |
+
cells_per_column: 32,
|
| 89 |
+
activation_threshold: 15,
|
| 90 |
+
learning_threshold: 13,
|
| 91 |
+
initial_permanence: 0.21,
|
| 92 |
+
connected_permanence: 0.50,
|
| 93 |
+
permanence_increment: 0.10,
|
| 94 |
+
permanence_decrement: 0.10,
|
| 95 |
+
predicted_segment_decrement: 0.10,
|
| 96 |
+
max_segments_per_cell: 255,
|
| 97 |
+
max_synapses_per_segment: 32,
|
| 98 |
+
max_new_synapse_count: 20,
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
pub struct TemporalMemory {
|
| 104 |
+
pub cfg: TemporalMemoryConfig,
|
| 105 |
+
/// All segments in the region. Indexed by SegmentIdx.
|
| 106 |
+
pub segments: Vec<Segment>,
|
| 107 |
+
/// For each cell, the list of segments that belong to it.
|
| 108 |
+
pub cell_segments: Vec<Vec<SegmentIdx>>,
|
| 109 |
+
/// Active cells in the current step.
|
| 110 |
+
pub active_cells: Vec<bool>,
|
| 111 |
+
/// Winner cells (subset of active_cells, 1 per active column) for learning.
|
| 112 |
+
pub winner_cells: Vec<bool>,
|
| 113 |
+
/// Predictive cells for the current step = cells whose segment became
|
| 114 |
+
/// active at the end of the previous step.
|
| 115 |
+
pub predictive_cells: Vec<bool>,
|
| 116 |
+
/// Cached list of segment indices that were "active" last compute().
|
| 117 |
+
active_segments_prev: Vec<SegmentIdx>,
|
| 118 |
+
/// Cached list of segment indices that were "matching" last compute().
|
| 119 |
+
matching_segments_prev: Vec<SegmentIdx>,
|
| 120 |
+
rng: Xoshiro256PlusPlus,
|
| 121 |
+
iter_count: u64,
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
impl TemporalMemory {
|
| 125 |
+
pub fn new(cfg: TemporalMemoryConfig, seed: u64) -> Self {
|
| 126 |
+
let total = cfg.n_columns * cfg.cells_per_column;
|
| 127 |
+
Self {
|
| 128 |
+
cell_segments: vec![Vec::new(); total],
|
| 129 |
+
active_cells: vec![false; total],
|
| 130 |
+
winner_cells: vec![false; total],
|
| 131 |
+
predictive_cells: vec![false; total],
|
| 132 |
+
cfg,
|
| 133 |
+
segments: Vec::new(),
|
| 134 |
+
active_segments_prev: Vec::new(),
|
| 135 |
+
matching_segments_prev: Vec::new(),
|
| 136 |
+
rng: Xoshiro256PlusPlus::seed_from_u64(seed),
|
| 137 |
+
iter_count: 0,
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
pub fn reset(&mut self) {
|
| 142 |
+
for v in self.active_cells.iter_mut() { *v = false; }
|
| 143 |
+
for v in self.winner_cells.iter_mut() { *v = false; }
|
| 144 |
+
for v in self.predictive_cells.iter_mut() { *v = false; }
|
| 145 |
+
self.active_segments_prev.clear();
|
| 146 |
+
self.matching_segments_prev.clear();
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
#[inline]
|
| 150 |
+
fn col_of(&self, cell: CellIdx) -> usize {
|
| 151 |
+
(cell as usize) / self.cfg.cells_per_column
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
#[inline]
|
| 155 |
+
fn cells_in_col(&self, col: usize) -> std::ops::Range<CellIdx> {
|
| 156 |
+
let base = (col * self.cfg.cells_per_column) as CellIdx;
|
| 157 |
+
base..(base + self.cfg.cells_per_column as CellIdx)
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/// Process one step.
|
| 161 |
+
///
|
| 162 |
+
/// `active_columns` is the set of column indices activated by the Spatial
|
| 163 |
+
/// Pooler this step. Returns the anomaly score in [0, 1].
|
| 164 |
+
pub fn compute(&mut self, active_columns: &[u32], learn: bool) -> f32 {
|
| 165 |
+
self.iter_count = self.iter_count.wrapping_add(1);
|
| 166 |
+
|
| 167 |
+
// Snapshot previous-step cell activity (for learning on segments).
|
| 168 |
+
let prev_active_cells = self.active_cells.clone();
|
| 169 |
+
let prev_winner_cells = self.winner_cells.clone();
|
| 170 |
+
|
| 171 |
+
// Move current "predictive" (computed at the end of the last step)
|
| 172 |
+
// into local variables; we'll overwrite predictive_cells later.
|
| 173 |
+
let predictive_prev = self.predictive_cells.clone();
|
| 174 |
+
|
| 175 |
+
// Group active segments and matching segments by column of their
|
| 176 |
+
// owning cell, for the columns that are active this step.
|
| 177 |
+
let n_cols = self.cfg.n_columns;
|
| 178 |
+
|
| 179 |
+
// active_segs_by_col[col] = segment indices whose cell is in col and
|
| 180 |
+
// which were "active" in the previous depolarization.
|
| 181 |
+
// matching_segs_by_col[col] = similarly for "matching".
|
| 182 |
+
let mut active_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols];
|
| 183 |
+
let mut matching_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols];
|
| 184 |
+
for &seg in &self.active_segments_prev {
|
| 185 |
+
let col = self.col_of(self.segments[seg as usize].cell);
|
| 186 |
+
active_segs_by_col[col].push(seg);
|
| 187 |
+
}
|
| 188 |
+
for &seg in &self.matching_segments_prev {
|
| 189 |
+
let col = self.col_of(self.segments[seg as usize].cell);
|
| 190 |
+
matching_segs_by_col[col].push(seg);
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
// Columns that are active this step (for O(1) lookup).
|
| 194 |
+
let mut active_col_mask = vec![false; n_cols];
|
| 195 |
+
for &c in active_columns { active_col_mask[c as usize] = true; }
|
| 196 |
+
|
| 197 |
+
// Zero out current cell activations.
|
| 198 |
+
for v in self.active_cells.iter_mut() { *v = false; }
|
| 199 |
+
for v in self.winner_cells.iter_mut() { *v = false; }
|
| 200 |
+
|
| 201 |
+
// Track anomaly.
|
| 202 |
+
let mut unpredicted_cols = 0u32;
|
| 203 |
+
|
| 204 |
+
// We'll collect (segment, learn_mode) pairs for segment reinforcement
|
| 205 |
+
// so we can batch-apply permanence adjustments using prev_active_cells.
|
| 206 |
+
// learn_mode: "reinforce_correctly_predicted", "punish_incorrectly_matched"
|
| 207 |
+
enum LearnOp {
|
| 208 |
+
Reinforce(SegmentIdx), // correctly predicted
|
| 209 |
+
Grow { // bursting column: grow on chosen segment
|
| 210 |
+
segment: SegmentIdx,
|
| 211 |
+
#[allow(dead_code)]
|
| 212 |
+
winner_cell: CellIdx,
|
| 213 |
+
},
|
| 214 |
+
Punish(SegmentIdx), // matching segment on inactive column
|
| 215 |
+
}
|
| 216 |
+
let mut ops: Vec<LearnOp> = Vec::new();
|
| 217 |
+
|
| 218 |
+
// ---- 1) Process active columns ----
|
| 219 |
+
for &col in active_columns {
|
| 220 |
+
let col = col as usize;
|
| 221 |
+
let active_segs = &active_segs_by_col[col];
|
| 222 |
+
if !active_segs.is_empty() {
|
| 223 |
+
// "Activate predicted column": each cell with an active segment
|
| 224 |
+
// becomes active and is a winner; reinforce that segment.
|
| 225 |
+
let mut seen_cells: Vec<CellIdx> = Vec::new();
|
| 226 |
+
for &seg_i in active_segs {
|
| 227 |
+
let seg = &self.segments[seg_i as usize];
|
| 228 |
+
let cell = seg.cell;
|
| 229 |
+
if !seen_cells.contains(&cell) {
|
| 230 |
+
self.active_cells[cell as usize] = true;
|
| 231 |
+
self.winner_cells[cell as usize] = true;
|
| 232 |
+
seen_cells.push(cell);
|
| 233 |
+
}
|
| 234 |
+
if learn {
|
| 235 |
+
ops.push(LearnOp::Reinforce(seg_i));
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
} else {
|
| 239 |
+
// ----- BURST -----
|
| 240 |
+
unpredicted_cols += 1;
|
| 241 |
+
for c in self.cells_in_col(col) {
|
| 242 |
+
self.active_cells[c as usize] = true;
|
| 243 |
+
}
|
| 244 |
+
// Pick a winner cell + segment for learning.
|
| 245 |
+
if learn {
|
| 246 |
+
let matching = &matching_segs_by_col[col];
|
| 247 |
+
let (winner_cell, target_segment) = if !matching.is_empty() {
|
| 248 |
+
// Best-matching segment = highest num_active_potential.
|
| 249 |
+
let mut best = matching[0];
|
| 250 |
+
let mut best_score = self.segments[best as usize].num_active_potential;
|
| 251 |
+
for &s in &matching[1..] {
|
| 252 |
+
let score = self.segments[s as usize].num_active_potential;
|
| 253 |
+
if score > best_score {
|
| 254 |
+
best_score = score;
|
| 255 |
+
best = s;
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
let wc = self.segments[best as usize].cell;
|
| 259 |
+
(wc, Some(best))
|
| 260 |
+
} else {
|
| 261 |
+
// Least-used cell in column, then grow a new segment.
|
| 262 |
+
let winner = self.least_used_cell(col);
|
| 263 |
+
(winner, None)
|
| 264 |
+
};
|
| 265 |
+
self.winner_cells[winner_cell as usize] = true;
|
| 266 |
+
let segment_id = match target_segment {
|
| 267 |
+
Some(s) => s,
|
| 268 |
+
None => {
|
| 269 |
+
// Create a fresh empty segment on winner cell.
|
| 270 |
+
self.create_segment(winner_cell)
|
| 271 |
+
}
|
| 272 |
+
};
|
| 273 |
+
ops.push(LearnOp::Grow { segment: segment_id, winner_cell });
|
| 274 |
+
} else {
|
| 275 |
+
// No learning: still pick some winner cell (arbitrary)
|
| 276 |
+
// so downstream code that inspects winner_cells isn't empty.
|
| 277 |
+
let matching = &matching_segs_by_col[col];
|
| 278 |
+
let winner_cell = if !matching.is_empty() {
|
| 279 |
+
self.segments[matching[0] as usize].cell
|
| 280 |
+
} else {
|
| 281 |
+
self.least_used_cell(col)
|
| 282 |
+
};
|
| 283 |
+
self.winner_cells[winner_cell as usize] = true;
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
// ---- 2) Punish matching segments on INACTIVE columns ----
|
| 289 |
+
if learn && self.cfg.predicted_segment_decrement > 0.0 {
|
| 290 |
+
for &seg_i in &self.matching_segments_prev {
|
| 291 |
+
let col = self.col_of(self.segments[seg_i as usize].cell);
|
| 292 |
+
if !active_col_mask[col] {
|
| 293 |
+
ops.push(LearnOp::Punish(seg_i));
|
| 294 |
+
}
|
| 295 |
+
}
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
// ---- 3) Apply learning ----
|
| 299 |
+
if learn {
|
| 300 |
+
for op in ops {
|
| 301 |
+
match op {
|
| 302 |
+
LearnOp::Reinforce(seg_i) => {
|
| 303 |
+
self.reinforce_segment(seg_i, &prev_active_cells);
|
| 304 |
+
// Optionally grow up to N new synapses to winner cells
|
| 305 |
+
// of the previous step.
|
| 306 |
+
self.grow_synapses_on_segment(seg_i, &prev_winner_cells);
|
| 307 |
+
}
|
| 308 |
+
LearnOp::Grow { segment, winner_cell: _ } => {
|
| 309 |
+
self.reinforce_segment(segment, &prev_active_cells);
|
| 310 |
+
self.grow_synapses_on_segment(segment, &prev_winner_cells);
|
| 311 |
+
}
|
| 312 |
+
LearnOp::Punish(seg_i) => {
|
| 313 |
+
let dec = self.cfg.predicted_segment_decrement;
|
| 314 |
+
for syn in &mut self.segments[seg_i as usize].synapses {
|
| 315 |
+
if prev_active_cells[syn.presynaptic_cell as usize] {
|
| 316 |
+
syn.permanence = (syn.permanence - dec).max(0.0);
|
| 317 |
+
}
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
}
|
| 321 |
+
}
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
// ---- 4) Compute segment activity & predictive cells for NEXT step ----
|
| 325 |
+
// We have to use the *current* active_cells (just set above).
|
| 326 |
+
let mut next_active_segs: Vec<SegmentIdx> = Vec::new();
|
| 327 |
+
let mut next_matching_segs: Vec<SegmentIdx> = Vec::new();
|
| 328 |
+
for v in self.predictive_cells.iter_mut() { *v = false; }
|
| 329 |
+
|
| 330 |
+
let conn = self.cfg.connected_permanence;
|
| 331 |
+
let act_thr = self.cfg.activation_threshold;
|
| 332 |
+
let learn_thr = self.cfg.learning_threshold;
|
| 333 |
+
|
| 334 |
+
for (seg_i, seg) in self.segments.iter_mut().enumerate() {
|
| 335 |
+
let mut n_conn: u32 = 0;
|
| 336 |
+
let mut n_pot: u32 = 0;
|
| 337 |
+
for syn in &seg.synapses {
|
| 338 |
+
if self.active_cells[syn.presynaptic_cell as usize] {
|
| 339 |
+
n_pot += 1;
|
| 340 |
+
if syn.permanence >= conn { n_conn += 1; }
|
| 341 |
+
}
|
| 342 |
+
}
|
| 343 |
+
seg.num_active_connected = n_conn;
|
| 344 |
+
seg.num_active_potential = n_pot;
|
| 345 |
+
if n_conn >= act_thr {
|
| 346 |
+
next_active_segs.push(seg_i as SegmentIdx);
|
| 347 |
+
self.predictive_cells[seg.cell as usize] = true;
|
| 348 |
+
}
|
| 349 |
+
if n_pot >= learn_thr {
|
| 350 |
+
next_matching_segs.push(seg_i as SegmentIdx);
|
| 351 |
+
}
|
| 352 |
+
}
|
| 353 |
+
self.active_segments_prev = next_active_segs;
|
| 354 |
+
self.matching_segments_prev = next_matching_segs;
|
| 355 |
+
|
| 356 |
+
// Keep predictive_prev unused-guard; we no longer need it but
|
| 357 |
+
// retained to document intent.
|
| 358 |
+
let _ = predictive_prev;
|
| 359 |
+
|
| 360 |
+
// Anomaly.
|
| 361 |
+
if active_columns.is_empty() {
|
| 362 |
+
0.0
|
| 363 |
+
} else {
|
| 364 |
+
(unpredicted_cols as f32) / (active_columns.len() as f32)
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
/// Reinforce synapses on `seg`: +inc if presynaptic is active last step,
|
| 369 |
+
/// -dec otherwise.
|
| 370 |
+
fn reinforce_segment(&mut self, seg_i: SegmentIdx, prev_active_cells: &[bool]) {
|
| 371 |
+
let inc = self.cfg.permanence_increment;
|
| 372 |
+
let dec = self.cfg.permanence_decrement;
|
| 373 |
+
let seg = &mut self.segments[seg_i as usize];
|
| 374 |
+
seg.last_used_iteration = self.iter_count;
|
| 375 |
+
for syn in &mut seg.synapses {
|
| 376 |
+
if prev_active_cells[syn.presynaptic_cell as usize] {
|
| 377 |
+
syn.permanence = (syn.permanence + inc).min(1.0);
|
| 378 |
+
} else {
|
| 379 |
+
syn.permanence = (syn.permanence - dec).max(0.0);
|
| 380 |
+
}
|
| 381 |
+
}
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
/// Grow up to `max_new_synapse_count - current_potential` new synapses
|
| 385 |
+
/// from previous winner cells that are not already connected to this seg.
|
| 386 |
+
fn grow_synapses_on_segment(
|
| 387 |
+
&mut self,
|
| 388 |
+
seg_i: SegmentIdx,
|
| 389 |
+
prev_winner_cells: &[bool],
|
| 390 |
+
) {
|
| 391 |
+
let initial_perm = self.cfg.initial_permanence;
|
| 392 |
+
let cap = self.cfg.max_synapses_per_segment;
|
| 393 |
+
let max_new = self.cfg.max_new_synapse_count;
|
| 394 |
+
|
| 395 |
+
// Gather candidate cells (prev winners not already presynaptic to this seg).
|
| 396 |
+
let already: Vec<CellIdx> = self.segments[seg_i as usize]
|
| 397 |
+
.synapses
|
| 398 |
+
.iter()
|
| 399 |
+
.map(|s| s.presynaptic_cell)
|
| 400 |
+
.collect();
|
| 401 |
+
let mut candidates: Vec<CellIdx> = Vec::new();
|
| 402 |
+
for (cell_i, &b) in prev_winner_cells.iter().enumerate() {
|
| 403 |
+
if b && !already.contains(&(cell_i as CellIdx)) {
|
| 404 |
+
candidates.push(cell_i as CellIdx);
|
| 405 |
+
}
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
// How many can we add?
|
| 409 |
+
let current_len = self.segments[seg_i as usize].synapses.len();
|
| 410 |
+
let room = cap.saturating_sub(current_len);
|
| 411 |
+
let mut to_add = max_new.min(candidates.len()).min(room);
|
| 412 |
+
|
| 413 |
+
// Random sample without replacement from candidates.
|
| 414 |
+
while to_add > 0 {
|
| 415 |
+
let idx = self.rng.gen_range(0..candidates.len());
|
| 416 |
+
let pre = candidates.swap_remove(idx);
|
| 417 |
+
self.segments[seg_i as usize].synapses.push(Synapse {
|
| 418 |
+
presynaptic_cell: pre,
|
| 419 |
+
permanence: initial_perm,
|
| 420 |
+
});
|
| 421 |
+
to_add -= 1;
|
| 422 |
+
}
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
fn create_segment(&mut self, cell: CellIdx) -> SegmentIdx {
|
| 426 |
+
// Enforce per-cell segment cap by evicting least-recently-used segment
|
| 427 |
+
// if necessary.
|
| 428 |
+
let cell_segs = &mut self.cell_segments[cell as usize];
|
| 429 |
+
if cell_segs.len() >= self.cfg.max_segments_per_cell {
|
| 430 |
+
// Find LRU segment.
|
| 431 |
+
let (lru_pos, &lru_id) = cell_segs
|
| 432 |
+
.iter()
|
| 433 |
+
.enumerate()
|
| 434 |
+
.min_by_key(|(_, &sid)| self.segments[sid as usize].last_used_iteration)
|
| 435 |
+
.expect("cell_segs non-empty");
|
| 436 |
+
// Clear that segment in place and reuse its index.
|
| 437 |
+
self.segments[lru_id as usize].synapses.clear();
|
| 438 |
+
self.segments[lru_id as usize].num_active_connected = 0;
|
| 439 |
+
self.segments[lru_id as usize].num_active_potential = 0;
|
| 440 |
+
self.segments[lru_id as usize].last_used_iteration = self.iter_count;
|
| 441 |
+
// Keep at same position in cell_segs.
|
| 442 |
+
let _ = lru_pos;
|
| 443 |
+
return lru_id;
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
let new_id = self.segments.len() as SegmentIdx;
|
| 447 |
+
self.segments.push(Segment {
|
| 448 |
+
cell,
|
| 449 |
+
synapses: Vec::with_capacity(self.cfg.max_new_synapse_count),
|
| 450 |
+
num_active_connected: 0,
|
| 451 |
+
num_active_potential: 0,
|
| 452 |
+
last_used_iteration: self.iter_count,
|
| 453 |
+
});
|
| 454 |
+
cell_segs.push(new_id);
|
| 455 |
+
new_id
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
/// Pick the cell in column `col` that owns the fewest segments.
///
/// When several cells share the minimum, one of them is drawn uniformly
/// from `self.rng`.
fn least_used_cell(&mut self, col: usize) -> CellIdx {
    let mut fewest = usize::MAX;
    let mut tied: Vec<CellIdx> = Vec::new();

    for cell in self.cells_in_col(col) {
        let count = self.cell_segments[cell as usize].len();
        match count.cmp(&fewest) {
            // New strict minimum: restart the tie list with this cell.
            std::cmp::Ordering::Less => {
                fewest = count;
                tied.clear();
                tied.push(cell);
            }
            std::cmp::Ordering::Equal => tied.push(cell),
            std::cmp::Ordering::Greater => {}
        }
    }

    let pick = self.rng.gen_range(0..tied.len());
    tied[pick]
}
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
// ---------------------------------------------------------------------------
|
| 478 |
+
// Tests
|
| 479 |
+
// ---------------------------------------------------------------------------
|
| 480 |
+
|
| 481 |
+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sp::{SpatialPooler, SpatialPoolerConfig};
    use rand::Rng;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;

    /// Build a random bit vector of `width` bits with 2% of them set,
    /// redrawing positions that are already on.
    fn random_sdr(rng: &mut Xoshiro256PlusPlus, width: usize) -> Vec<bool> {
        let target = (0.02 * width as f32) as usize;
        let mut bits = vec![false; width];
        let mut set = 0;
        while set < target {
            let i = rng.gen_range(0..width);
            if !bits[i] {
                bits[i] = true;
                set += 1;
            }
        }
        bits
    }

    /// Arithmetic mean of a slice of anomaly scores.
    fn mean(xs: &[f32]) -> f32 {
        xs.iter().sum::<f32>() / (xs.len() as f32)
    }

    #[test]
    fn tm_learns_repeating_sequence() {
        // Feeding A -> B -> C on repeat should push the anomaly score down.
        let mut sp = SpatialPooler::new(SpatialPoolerConfig::default(), 123);
        let mut tm = TemporalMemory::new(TemporalMemoryConfig::default(), 456);

        // Three fixed random SDRs, drawn in order from one seeded RNG.
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
        let input_bits = sp.cfg.input_bits;
        let seqs = [
            random_sdr(&mut rng, input_bits),
            random_sdr(&mut rng, input_bits),
            random_sdr(&mut rng, input_bits),
        ];

        // Let the spatial pooler learn first so each symbol maps to
        // reliable columns.
        for _ in 0..200 {
            for s in &seqs {
                sp.compute(s, true);
            }
        }

        // Start the temporal memory from a clean prediction state.
        tm.reset();

        // Sample the anomaly on one early pass (10) and the final pass (249).
        let mut early_anoms: Vec<f32> = Vec::new();
        let mut late_anoms: Vec<f32> = Vec::new();
        for iter in 0..250 {
            for s in &seqs {
                let active = sp.compute(s, false);
                let anomaly = tm.compute(&active, true);
                match iter {
                    10 => early_anoms.push(anomaly),
                    249 => late_anoms.push(anomaly),
                    _ => {}
                }
            }
        }

        let early = mean(&early_anoms);
        let late = mean(&late_anoms);
        println!("early_anomaly={early}, late_anomaly={late}");
        assert!(
            late < 0.5 * early + 1e-6,
            "late anomaly ({late}) should be < 0.5 * early anomaly ({early})"
        );
    }
}
|
overlay/hydra/__init__.py
CHANGED
|
@@ -1,31 +1,37 @@
|
|
| 1 |
-
"""HYDRA training package.
|
| 2 |
-
|
| 3 |
-
Thin facade re-exporting the public API used by train.py, the test suite,
|
| 4 |
-
and external research scripts. Imports are lazy where possible to keep
|
| 5 |
-
`import hydra` cheap (prepare.py and mamba-ssm are the heavy deps).
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from hydra.config import PostSemClawConfig
|
| 9 |
-
from hydra.engram import GPUEngram
|
| 10 |
-
from hydra.
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
#
|
| 14 |
-
#
|
| 15 |
-
def __getattr__(name: str):
|
| 16 |
-
if name == "
|
| 17 |
-
from hydra.
|
| 18 |
-
return
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
"
|
| 30 |
-
"
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HYDRA training package.
|
| 2 |
+
|
| 3 |
+
Thin facade re-exporting the public API used by train.py, the test suite,
|
| 4 |
+
and external research scripts. Imports are lazy where possible to keep
|
| 5 |
+
`import hydra` cheap (prepare.py and mamba-ssm are the heavy deps).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from hydra.config import PostSemClawConfig
|
| 9 |
+
from hydra.engram import GPUEngram
|
| 10 |
+
from hydra.optimizer import MuonAdamW, adamw_step_fused, muon_step_fused
|
| 11 |
+
|
| 12 |
+
# Heavy imports are resolved lazily so `import hydra` and `import hydra.hyena_block`
|
| 13 |
+
# keep working in local CPU/test environments that do not have the container-only
|
| 14 |
+
# mamba-ssm wheel stack installed.
|
| 15 |
+
def __getattr__(name: str):
    """Resolve the lazily exported names of this package on first access.

    Python calls this hook (PEP 562) only when normal module attribute lookup
    fails, so the imports of ``hydra.model`` and ``hydra.training`` below are
    deferred until one of the lazy names is actually requested.

    Args:
        name: Attribute being looked up on the package.

    Returns:
        ``PostSemClawModel`` or ``norm`` from ``hydra.model``, or
        ``config_from_dict`` from ``hydra.training``.

    Raises:
        AttributeError: If ``name`` is not one of the lazy exports.
    """
    if name == "PostSemClawModel":
        from hydra.model import PostSemClawModel as _model
        return _model
    if name == "norm":
        from hydra.model import norm as _norm
        return _norm
    if name == "config_from_dict":
        from hydra.training import config_from_dict as _cfd
        return _cfd
    # Mirror the interpreter's own wording so tracebacks name the module, not
    # just the missing attribute.
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
__all__ = [
|
| 29 |
+
"PostSemClawConfig",
|
| 30 |
+
"GPUEngram",
|
| 31 |
+
"PostSemClawModel",
|
| 32 |
+
"norm",
|
| 33 |
+
"MuonAdamW",
|
| 34 |
+
"adamw_step_fused",
|
| 35 |
+
"muon_step_fused",
|
| 36 |
+
"config_from_dict",
|
| 37 |
+
]
|
overlay/hydra/config.py
CHANGED
|
@@ -1,220 +1,225 @@
|
|
| 1 |
-
"""HYDRA training configuration — dataclass + env-var constants.
|
| 2 |
-
|
| 3 |
-
Extracted from the monolithic train.py as part of W1 modularization. All
|
| 4 |
-
env-var reads and the PostSemClawConfig dataclass live here. The training
|
| 5 |
-
body imports these constants; zero behavior change from the extraction.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from __future__ import annotations
|
| 9 |
-
|
| 10 |
-
import os
|
| 11 |
-
from dataclasses import dataclass, field
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def _parse_hyena_layers_env() -> tuple[int, ...]:
|
| 15 |
-
"""Parse HYDRA_HYENA_LAYERS env var into a sorted tuple of layer indices.
|
| 16 |
-
|
| 17 |
-
Used as the default_factory for PostSemClawConfig.hyena_layers so a fresh
|
| 18 |
-
config construction reads the current env var, but once constructed the
|
| 19 |
-
value is first-class and travels with checkpoints (see asdict(config) in
|
| 20 |
-
save_ckpt). Ckpt-load sets the dataclass field explicitly, overriding the
|
| 21 |
-
env-var default.
|
| 22 |
-
|
| 23 |
-
Returns empty tuple when env var is unset/empty (byte-identical to
|
| 24 |
-
pre-port behavior: no Hyena layers).
|
| 25 |
-
"""
|
| 26 |
-
raw = os.environ.get("HYDRA_HYENA_LAYERS", "")
|
| 27 |
-
if not raw:
|
| 28 |
-
return ()
|
| 29 |
-
return tuple(sorted({int(s.strip()) for s in raw.split(",") if s.strip()}))
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def _parse_gdn_layers_env() -> tuple[int, ...]:
|
| 33 |
-
"""Parse HYDRA_GDN_LAYERS env var into a sorted tuple of layer indices.
|
| 34 |
-
|
| 35 |
-
Same contract as _parse_hyena_layers_env: layers whose index is listed
|
| 36 |
-
here use GatedDeltaNet (fla.layers.GatedDeltaNet) as a drop-in
|
| 37 |
-
replacement for Mamba3. Empty tuple = no GDN layers (byte-identical
|
| 38 |
-
to baseline).
|
| 39 |
-
"""
|
| 40 |
-
raw = os.environ.get("HYDRA_GDN_LAYERS", "")
|
| 41 |
-
if not raw:
|
| 42 |
-
return ()
|
| 43 |
-
return tuple(sorted({int(s.strip()) for s in raw.split(",") if s.strip()}))
|
| 44 |
-
|
| 45 |
-
# ---------------------------------------------------------------------------
|
| 46 |
-
# CUDA env — set before importing torch in entry point. Kept here so any
|
| 47 |
-
# module that `from hydra.config import ...` also benefits (import order is
|
| 48 |
-
# top-down in Python, and train.py used to set these at module top).
|
| 49 |
-
# ---------------------------------------------------------------------------
|
| 50 |
-
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
|
| 51 |
-
if "/usr/local/cuda/bin" not in os.environ.get("PATH", ""):
|
| 52 |
-
os.environ["PATH"] = "/usr/local/cuda/bin:" + os.environ.get("PATH", "")
|
| 53 |
-
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
# ---------------------------------------------------------------------------
|
| 57 |
-
# Model Configuration
|
| 58 |
-
# ---------------------------------------------------------------------------
|
| 59 |
-
|
| 60 |
-
@dataclass
|
| 61 |
-
class PostSemClawConfig:
|
| 62 |
-
"""Full-architecture model config. Defaults reflect Phase-1 baseline;
|
| 63 |
-
the training entry overrides d_model/n_layer/etc. from env vars."""
|
| 64 |
-
# Sequence
|
| 65 |
-
sequence_len: int = 2048
|
| 66 |
-
vocab_size: int = 8192 # Must match prepare.py VOCAB_SIZE
|
| 67 |
-
|
| 68 |
-
# Mamba-3 SSM
|
| 69 |
-
n_layer: int = 6
|
| 70 |
-
d_model: int = 384
|
| 71 |
-
d_state: int = 64 # SSM state dimension
|
| 72 |
-
headdim: int = 48 # head dimension for SSM
|
| 73 |
-
n_heads: int = 8 # d_model // headdim
|
| 74 |
-
expand: int = 2 # inner_dim = expand * d_model
|
| 75 |
-
|
| 76 |
-
# Engram (conditional memory with Hebbian writes)
|
| 77 |
-
engram_n_columns: int = 4096
|
| 78 |
-
engram_key_dim: int = 64
|
| 79 |
-
engram_layer_idx: int = 1 # which layer gets engram (0-indexed, mid-layer)
|
| 80 |
-
|
| 81 |
-
# SemanticFoldingSDR (offline retina with STE; no-bypass, runs every step)
|
| 82 |
-
sdr_n_bits: int = 16384 # retina width
|
| 83 |
-
# Default 327 = 2% sparsity (Webber/Numenta canonical). Override with
|
| 84 |
-
# HYDRA_SDR_TARGET_ACTIVE env var; value MUST match subsystems/sdr_retina.py
|
| 85 |
-
# TARGET_ACTIVE (same env var is read there, so just setting it once works).
|
| 86 |
-
sdr_target_active: int = int(os.environ.get("HYDRA_SDR_TARGET_ACTIVE", "327"))
|
| 87 |
-
sdr_delta_rank: int = 32 # low-rank STE delta rank
|
| 88 |
-
sdr_som_warmup: int = 500
|
| 89 |
-
sdr_som_interval: int = 100
|
| 90 |
-
|
| 91 |
-
# HTMLayer (Rust-backed, Hebbian; no-bypass, runs every step)
|
| 92 |
-
htm_n_columns: int = 2048
|
| 93 |
-
htm_cells_per_column: int = 32
|
| 94 |
-
|
| 95 |
-
# Hyena supplement layer indices (sorted tuple). Defaults to the
|
| 96 |
-
# HYDRA_HYENA_LAYERS env var at config-construction time, but once
|
| 97 |
-
# persisted in a checkpoint the value is first-class and survives even
|
| 98 |
-
# when the env var is unset at resume time. This fixes the ckpt-reload
|
| 99 |
-
# crash path where a model trained with `HYDRA_HYENA_LAYERS=3,7` saves
|
| 100 |
-
# HyenaBlock params but a fresh process without the env var would try
|
| 101 |
-
# to build a pure-Mamba3 architecture and reject the state_dict as
|
| 102 |
-
# `Missing/Unexpected key(s)`.
|
| 103 |
-
hyena_layers: tuple[int, ...] = field(default_factory=_parse_hyena_layers_env)
|
| 104 |
-
|
| 105 |
-
# GatedDeltaNet supplement layer indices (sorted tuple). Same semantics
|
| 106 |
-
# as hyena_layers — a layer index listed here uses GDNBlock (fla-backed
|
| 107 |
-
# Gated DeltaNet) instead of Mamba3. Selections are mutually exclusive
|
| 108 |
-
# with hyena_layers at construction time (hyena wins on overlap; the
|
| 109 |
-
# model loop checks hyena first).
|
| 110 |
-
gdn_layers: tuple[int, ...] = field(default_factory=_parse_gdn_layers_env)
|
| 111 |
-
|
| 112 |
-
# Label smoothing + Z-loss
|
| 113 |
-
label_smoothing: float =
|
| 114 |
-
z_loss_weight: float = 1e-4
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
# ---------------------------------------------------------------------------
|
| 118 |
-
# Hyperparameters (autoresearch agent modifies these via env vars)
|
| 119 |
-
# ---------------------------------------------------------------------------
|
| 120 |
-
|
| 121 |
-
# Model architecture
|
| 122 |
-
D_MODEL = int(os.environ.get("HYDRA_D_MODEL", "256"))
|
| 123 |
-
N_LAYER = int(os.environ.get("HYDRA_N_LAYER", "4"))
|
| 124 |
-
D_STATE = int(os.environ.get("HYDRA_D_STATE", "64"))
|
| 125 |
-
HEADDIM = int(os.environ.get("HYDRA_HEADDIM", "32"))
|
| 126 |
-
N_HEADS = D_MODEL // HEADDIM
|
| 127 |
-
EXPAND = int(os.environ.get("HYDRA_EXPAND", "2"))
|
| 128 |
-
|
| 129 |
-
# Engram
|
| 130 |
-
ENGRAM_N_COLUMNS = int(os.environ.get("HYDRA_ENGRAM_N_COLUMNS", "1024"))
|
| 131 |
-
ENGRAM_KEY_DIM = 64
|
| 132 |
-
ENGRAM_LAYER_IDX = int(os.environ.get("HYDRA_ENGRAM_LAYER_IDX", "1"))
|
| 133 |
-
|
| 134 |
-
# Optimization
|
| 135 |
-
DEVICE_BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
|
| 136 |
-
TOTAL_BATCH_SIZE = int(os.environ.get("HYDRA_TOTAL_BATCH", "32768"))
|
| 137 |
-
MATRIX_LR = float(os.environ.get("HYDRA_MATRIX_LR", "0.12"))
|
| 138 |
-
EMBEDDING_LR = float(os.environ.get("HYDRA_EMBED_LR", "1.0"))
|
| 139 |
-
UNEMBEDDING_LR = float(os.environ.get("HYDRA_UNEMBED_LR", "0.005"))
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
#
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
#
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
#
|
| 164 |
-
#
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
#
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
#
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
#
|
| 176 |
-
#
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
#
|
| 180 |
-
|
| 181 |
-
#
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
#
|
| 187 |
-
#
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
#
|
| 192 |
-
#
|
| 193 |
-
#
|
| 194 |
-
#
|
| 195 |
-
# -
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
#
|
| 200 |
-
#
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
#
|
| 205 |
-
#
|
| 206 |
-
#
|
| 207 |
-
#
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
#
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
#
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
#
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HYDRA training configuration — dataclass + env-var constants.
|
| 2 |
+
|
| 3 |
+
Extracted from the monolithic train.py as part of W1 modularization. All
|
| 4 |
+
env-var reads and the PostSemClawConfig dataclass live here. The training
|
| 5 |
+
body imports these constants; zero behavior change from the extraction.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _parse_layer_indices_env(var_name: str) -> tuple[int, ...]:
    """Parse a comma-separated env var into a sorted, de-duplicated int tuple.

    Args:
        var_name: Name of the environment variable to read.

    Returns:
        The distinct integers listed in the variable, in ascending order.
        Empty when the variable is unset, empty, or holds only separators
        and whitespace.

    Raises:
        ValueError: If any non-blank token is not an integer. The message
            names the variable and echoes its raw value.
    """
    raw = os.environ.get(var_name, "")
    if not raw:
        return ()
    try:
        # Blank tokens (e.g. from a trailing comma) are skipped, not errors.
        indices = {int(tok.strip()) for tok in raw.split(",") if tok.strip()}
    except ValueError as err:
        raise ValueError(
            f"{var_name} must be a comma-separated list of integer layer "
            f"indices, got {raw!r}"
        ) from err
    return tuple(sorted(indices))


def _parse_hyena_layers_env() -> tuple[int, ...]:
    """Parse HYDRA_HYENA_LAYERS env var into a sorted tuple of layer indices.

    Used as the default_factory for PostSemClawConfig.hyena_layers so a fresh
    config construction reads the current env var, but once constructed the
    value is first-class and travels with checkpoints (see asdict(config) in
    save_ckpt). Ckpt-load sets the dataclass field explicitly, overriding the
    env-var default.

    Returns empty tuple when env var is unset/empty (no Hyena layers).

    Raises:
        ValueError: If the variable contains a non-integer token.
    """
    return _parse_layer_indices_env("HYDRA_HYENA_LAYERS")


def _parse_gdn_layers_env() -> tuple[int, ...]:
    """Parse HYDRA_GDN_LAYERS env var into a sorted tuple of layer indices.

    Same contract as _parse_hyena_layers_env: layers whose index is listed
    here use GatedDeltaNet (fla.layers.GatedDeltaNet) as a drop-in
    replacement for Mamba3. Empty tuple = no GDN layers.

    Raises:
        ValueError: If the variable contains a non-integer token.
    """
    return _parse_layer_indices_env("HYDRA_GDN_LAYERS")
|
| 44 |
+
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
# CUDA env — set before importing torch in entry point. Kept here so any
|
| 47 |
+
# module that `from hydra.config import ...` also benefits (import order is
|
| 48 |
+
# top-down in Python, and train.py used to set these at module top).
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
# Import-time side effects: make sure the CUDA toolchain is discoverable and
# pick an allocator config, without overriding values the caller already set.
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")

_CUDA_BIN = "/usr/local/cuda/bin"
_current_path = os.environ.get("PATH", "")
# Compare whole PATH entries, not substrings: a substring test is fooled by
# entries such as "/opt/usr/local/cuda/bin-extra".
if _CUDA_BIN not in (entry.rstrip("/") for entry in _current_path.split(os.pathsep)):
    # Never leave a trailing separator: an empty PATH entry means "current
    # directory" to POSIX shells and exec*p lookups.
    os.environ["PATH"] = (
        f"{_CUDA_BIN}{os.pathsep}{_current_path}" if _current_path else _CUDA_BIN
    )

os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
# Model Configuration
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
|
| 60 |
+
@dataclass
class PostSemClawConfig:
    """Architecture hyperparameters for the full model.

    The literal defaults describe the Phase-1 baseline; the training entry
    overrides d_model/n_layer/etc. from env vars. Fields built with
    ``default_factory`` read their env var each time a config is constructed,
    whereas ``sdr_target_active`` reads its env var once, when this class
    body is executed.
    """

    # --- Sequence ---------------------------------------------------------
    sequence_len: int = 2048
    vocab_size: int = 8192  # keep equal to prepare.py VOCAB_SIZE

    # --- Mamba-3 SSM ------------------------------------------------------
    n_layer: int = 6
    d_model: int = 384
    d_state: int = 64  # SSM state dimension
    headdim: int = 48  # per-head width of the SSM
    n_heads: int = 8  # d_model // headdim
    expand: int = 2  # inner_dim = expand * d_model

    # --- Engram (conditional memory with Hebbian writes) -------------------
    engram_n_columns: int = 4096
    engram_key_dim: int = 64
    engram_layer_idx: int = 1  # 0-indexed layer that hosts the engram

    # --- SemanticFoldingSDR (offline retina with STE; runs every step) -----
    sdr_n_bits: int = 16384  # retina width
    # 327 active bits out of 16384 is 2% sparsity (Webber/Numenta canonical).
    # HYDRA_SDR_TARGET_ACTIVE overrides it; the value MUST match
    # subsystems/sdr_retina.py TARGET_ACTIVE, which reads the same env var,
    # so setting the variable once covers both.
    sdr_target_active: int = int(os.getenv("HYDRA_SDR_TARGET_ACTIVE", "327"))
    sdr_delta_rank: int = 32  # rank of the low-rank STE delta
    sdr_som_warmup: int = 500
    sdr_som_interval: int = 100

    # --- HTMLayer (Rust-backed, Hebbian; runs every step) -------------------
    htm_n_columns: int = 2048
    htm_cells_per_column: int = 32

    # --- Hyena supplement layers --------------------------------------------
    # Sorted tuple of layer indices. The default comes from the
    # HYDRA_HYENA_LAYERS env var at construction time; once stored in a
    # checkpoint the value is first-class and survives a resume where the env
    # var is unset. That avoids the reload crash in which a model trained with
    # `HYDRA_HYENA_LAYERS=3,7` saves HyenaBlock params, but a fresh process
    # without the env var builds a pure-Mamba3 architecture and rejects the
    # state_dict with `Missing/Unexpected key(s)`.
    hyena_layers: tuple[int, ...] = field(default_factory=_parse_hyena_layers_env)

    # --- GatedDeltaNet supplement layers --------------------------------------
    # Sorted tuple of layer indices with the same semantics as hyena_layers:
    # a listed layer uses GDNBlock (fla-backed Gated DeltaNet) instead of
    # Mamba3. The two selections are mutually exclusive at construction time
    # (hyena wins on overlap; the model loop checks hyena first).
    gdn_layers: tuple[int, ...] = field(default_factory=_parse_gdn_layers_env)

    # --- Label smoothing + Z-loss ---------------------------------------------
    label_smoothing: float = field(
        default_factory=lambda: float(os.getenv("HYDRA_LABEL_SMOOTHING", "0.0"))
    )
    z_loss_weight: float = field(
        default_factory=lambda: float(os.getenv("HYDRA_Z_LOSS_WEIGHT", "1e-4"))
    )
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ---------------------------------------------------------------------------
|
| 118 |
+
# Hyperparameters (autoresearch agent modifies these via env vars)
|
| 119 |
+
# ---------------------------------------------------------------------------
|
| 120 |
+
|
| 121 |
+
# Model architecture (each value is read from its env var once, at import)
D_MODEL = int(os.getenv("HYDRA_D_MODEL", "256"))
N_LAYER = int(os.getenv("HYDRA_N_LAYER", "4"))
D_STATE = int(os.getenv("HYDRA_D_STATE", "64"))
HEADDIM = int(os.getenv("HYDRA_HEADDIM", "32"))
N_HEADS = D_MODEL // HEADDIM  # floor division: any remainder is dropped
EXPAND = int(os.getenv("HYDRA_EXPAND", "2"))

# Engram
ENGRAM_N_COLUMNS = int(os.getenv("HYDRA_ENGRAM_N_COLUMNS", "1024"))
ENGRAM_KEY_DIM = 64  # fixed; no env override
ENGRAM_LAYER_IDX = int(os.getenv("HYDRA_ENGRAM_LAYER_IDX", "1"))

# Optimization
DEVICE_BATCH_SIZE = int(os.getenv("HYDRA_BATCH_SIZE", "1"))
TOTAL_BATCH_SIZE = int(os.getenv("HYDRA_TOTAL_BATCH", "32768"))
MATRIX_LR = float(os.getenv("HYDRA_MATRIX_LR", "0.12"))
EMBEDDING_LR = float(os.getenv("HYDRA_EMBED_LR", "1.0"))
UNEMBEDDING_LR = float(os.getenv("HYDRA_UNEMBED_LR", "0.005"))
# Scalar/vector params cover Hyena implicit-filter vectors, norms, gate/bias
# terms, and SDR delta_u/delta_v. They are AdamW-scaled by d_model and can be
# the hidden instability path when the high-throughput HF recipe pushes a
# large device batch for hours. The historical default is kept, but launch
# scripts can lower it so cloud jobs cool scalars without editing code.
SCALAR_LR = float(os.getenv("HYDRA_SCALAR_LR", "0.5"))
WEIGHT_DECAY = float(os.getenv("HYDRA_WEIGHT_DECAY", "0.01"))
ADAM_BETAS = (0.9, 0.95)  # fixed; no env override
WARMUP_RATIO = float(os.getenv("HYDRA_WARMUP_RATIO", "0.0"))
WARMDOWN_RATIO = 0.5  # fixed; no env override
FINAL_LR_FRAC = float(os.getenv("HYDRA_LR_MIN_MULT", "0.0"))

# Runtime
SEED = int(os.getenv("HYDRA_SEED", "42"))
# The env var is given in BF16 TFLOPS (RTX 3060=25.5, A100 SXM4=312,
# H100 SXM5=989) and converted to FLOPS here.
GPU_BF16_PEAK_FLOPS = float(os.getenv("HYDRA_GPU_BF16_TFLOPS", "25.5")) * 1e12

# Loss / inference knobs read by the model
CE_CHUNK = int(os.getenv("HYDRA_CE_CHUNK", "1024"))
DROPOUT = float(os.getenv("HYDRA_DROPOUT", "0.2"))
FUSED_ADAMW = os.getenv("HYDRA_FUSED_ADAMW", "1") == "1"  # only "1" enables
|
| 161 |
+
|
| 162 |
+
# ---------------------------------------------------------------------------
|
| 163 |
+
# Learnability knobs (all OFF by default — zero behavior change unless set)
|
| 164 |
+
# ---------------------------------------------------------------------------
|
| 165 |
+
# 1) Multi-Token Prediction (Llama-3 style). K=1 turns it off (next-1 only).
#    K=4 adds 3 extra weight-tied heads; loss = mean of K position-shifted CEs.
MTP_K = int(os.getenv("HYDRA_MTP_K", "1"))
# 2) Exponential Moving Average of the model weights (default decay 0.999).
#    Saves an additional latest_ema.pt at the end of training.
USE_EMA = os.getenv("HYDRA_USE_EMA", "0") == "1"
EMA_DECAY = float(os.getenv("HYDRA_EMA_DECAY", "0.999"))
# 3) Gradient checkpointing on the Mamba3 block forward. Trades ~30% compute
#    for ~40% activation memory savings — lets you push B upward on a 3060.
GRAD_CKPT = os.getenv("HYDRA_GRAD_CKPT", "0") == "1"
# 4) Doc-separator masking in packed sequences: at every packed-BOS position
#    in the targets tensor the loss is masked (ignore_index=-1), so the model
#    is not forced to predict doc B from doc A's context.
DOC_SEP_MASK = os.getenv("HYDRA_DOC_SEP_MASK", "0") == "1"
# 5) Stop-gradient on HTM state. Belt-and-braces: htm_rust already runs under
#    torch.no_grad(), so the returned tensor has requires_grad=False; this
#    just detaches explicitly to harden graph hygiene against future refactors.
HTM_STOP_GRAD = os.getenv("HYDRA_HTM_STOP_GRAD", "0") == "1"
# 6) Output entropy penalty: loss += -lambda * H(softmax(logits)). Negative
#    entropy penalizes peaked distributions and breaks repetition loops.
ENTROPY_PENALTY = float(os.getenv("HYDRA_ENTROPY_PENALTY", "0.0"))
# 7) Curriculum: the first N optimizer steps use a short seq_len, then training
#    switches to the full length. 0 disables the curriculum.
CURRICULUM_SHORT_STEPS = int(os.getenv("HYDRA_CURRICULUM_SHORT_STEPS", "0"))
CURRICULUM_SHORT_SEQ_LEN = int(os.getenv("HYDRA_CURRICULUM_SHORT_SEQ_LEN", "256"))
|
| 190 |
+
|
| 191 |
+
# ---------------------------------------------------------------------------
|
| 192 |
+
# Hyena supplement (additional block type for selected layer indices).
|
| 193 |
+
# Hyena replaces Mamba3 at the specified layer indices while all other layers
|
| 194 |
+
# remain Mamba3. Empty string (default) → no Hyena layers, byte-identical to
|
| 195 |
+
# pre-port behavior.
|
| 196 |
+
# HYDRA_HYENA_LAYERS "3,7" — comma-separated 0-indexed layer ids
|
| 197 |
+
# HYDRA_HYENA_ORDER 2 — Hyena recurrence order (>= 2)
|
| 198 |
+
# HYDRA_HYENA_FILTER_DIM 64 — implicit-filter MLP hidden width
|
| 199 |
+
# Hyena reference: https://arxiv.org/pdf/2302.10866.pdf (HazyResearch/safari).
|
| 200 |
+
# ---------------------------------------------------------------------------
|
| 201 |
+
# Raw, unparsed env string; an empty string means no Hyena layers.
HYENA_LAYERS = os.getenv("HYDRA_HYENA_LAYERS", "")
HYENA_ORDER = int(os.getenv("HYDRA_HYENA_ORDER", "2"))
HYENA_FILTER_DIM = int(os.getenv("HYDRA_HYENA_FILTER_DIM", "64"))
# Filter-rfft cache modes (see subsystems/hyena_pure.py), both off by default:
#   HYDRA_HYENA_FILTER_CACHE=1 — eval-only cache. Safe under torch.no_grad(),
#     where PyTorch never saves intermediate tensors.
#   HYDRA_HYENA_TRAIN_CACHE=1 — training-safe cache using a deferred gradient
#     pattern. Cuts the implicit filter MLP forward to ONCE per optimizer step
#     regardless of the grad-accumulation factor. Requires the training loop
#     (see hydra/lightning_module.py::optimizer_step) to call
#     `model.flush_hyena_pending_grads()` before optimizer.step().
HYENA_FILTER_CACHE = os.getenv("HYDRA_HYENA_FILTER_CACHE", "0") == "1"
HYENA_TRAIN_CACHE = os.getenv("HYDRA_HYENA_TRAIN_CACHE", "0") == "1"
|
| 215 |
+
|
| 216 |
+
# Factual eval knobs
FACTUAL_SAMPLES = int(os.getenv("HYDRA_FACTUAL_SAMPLES", "3"))
FACTUAL_BATCH = int(os.getenv("HYDRA_FACTUAL_BATCH", "32"))
# F6 (partial): full incremental SSM decode integration is deferred — it would
# require threading mamba_ssm InferenceParams through PostSemClawModel.forward
# and all auxiliary subsystems (HTM, SDR, Engram), which currently run
# full-sequence on each call. Every generated token therefore does a full
# re-encode of ctx+k tokens, so as a stopgap the default is kept small
# (2 here; the earlier default was 16) to bound the per-prompt cost. Set
# HYDRA_FACTUAL_GEN_TOKENS to restore prior behavior.
# See docs/OPTIMIZATION_PLAN.md.
FACTUAL_GEN_TOKENS = int(os.getenv("HYDRA_FACTUAL_GEN_TOKENS", "2"))
|
overlay/hydra/data_module.py
CHANGED
|
@@ -1,288 +1,288 @@
|
|
| 1 |
-
"""Lightning DataModule + IterableDataset for HYDRA pretraining.
|
| 2 |
-
|
| 3 |
-
Replaces the custom threading/queue pipeline in prepare_nemotron.make_dataloader
|
| 4 |
-
with a standard multiprocessing DataLoader approach.
|
| 5 |
-
|
| 6 |
-
Design:
|
| 7 |
-
• IterableStreamDataset: each worker opens its own HF streams for the 7-way
|
| 8 |
-
blend, tokenizes with rustbpe, packs into (T+1,) rows via best-fit, and
|
| 9 |
-
yields one row per __next__.
|
| 10 |
-
• HydraDataModule: wraps the dataset with a standard DataLoader using
|
| 11 |
-
num_workers>=1, prefetch_factor=4, pin_memory=True. Lightning handles
|
| 12 |
-
device transfer.
|
| 13 |
-
• Val stream: deterministic seed 12345; uses the full blend when HYDRA_USE_FULL_BLEND=1, otherwise only the Nemotron-Pretraining-Multiple-Choice config.
|
| 14 |
-
|
| 15 |
-
The worker RNG is seeded per-worker so the weighted-sampling schedule is
|
| 16 |
-
independent across workers (else all workers request the same config at
|
| 17 |
-
the same step and prefetching serializes).
|
| 18 |
-
|
| 19 |
-
Env vars (all preserved from prepare_nemotron):
|
| 20 |
-
HYDRA_SEQ_LEN — sequence length T (default 512)
|
| 21 |
-
HYDRA_BATCH_SIZE — batch size B (default 1) — passed through
|
| 22 |
-
to DataLoader
|
| 23 |
-
HYDRA_STREAM_SHUFFLE_BUFFER — HF shuffle buffer (default 2048)
|
| 24 |
-
HYDRA_USE_FULL_BLEND — 7-way blend vs 5-way Nemotron phase
|
| 25 |
-
HYDRA_USE_NEMOTRON — enables streaming path (else shard path)
|
| 26 |
-
HYDRA_FACTUAL_INJECT_RATE — factual doc injection cadence
|
| 27 |
-
HYDRA_NEMOTRON_PHASE — phase1|phase2 (when not full blend)
|
| 28 |
-
HYDRA_DATA_NUM_WORKERS — DataLoader num_workers (default 2)
|
| 29 |
-
HYDRA_DATA_PREFETCH — DataLoader prefetch_factor (default 4)
|
| 30 |
-
HYDRA_DATA_BUFFER — doc_buffer size for best-fit packing
|
| 31 |
-
(default 1000)
|
| 32 |
-
"""
|
| 33 |
-
from __future__ import annotations
|
| 34 |
-
|
| 35 |
-
import os
|
| 36 |
-
import random
|
| 37 |
-
from typing import Iterator
|
| 38 |
-
|
| 39 |
-
import numpy as np
|
| 40 |
-
import torch
|
| 41 |
-
import lightning as L
|
| 42 |
-
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
| 43 |
-
|
| 44 |
-
import prepare as _prepare
|
| 45 |
-
import prepare_nemotron as _p_nemo
|
| 46 |
-
from prepare_nemotron import (
|
| 47 |
-
FULL_BLEND_WEIGHTS,
|
| 48 |
-
PHASE1_WEIGHTS,
|
| 49 |
-
PHASE2_WEIGHTS,
|
| 50 |
-
_BLEND_REGISTRY,
|
| 51 |
-
_extract_text,
|
| 52 |
-
_open_stream,
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
# ---------------------------------------------------------------------------
|
| 57 |
-
# Worker-local weighted stream. A stripped version of prepare_nemotron's
|
| 58 |
-
# _WeightedStream that is constructed inside each worker. Adds worker sharding:
|
| 59 |
-
# when num_workers > 1 the RNG is seeded per-worker, so different workers
|
| 60 |
-
# sample different config sequences and pull disjoint shard assignments from
|
| 61 |
-
# HF's shuffle buffer.
|
| 62 |
-
# ---------------------------------------------------------------------------
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
class _WorkerWeightedStream:
    """Endless per-worker iterator over a weighted mix of text streams.

    ``next()`` returns ``(text, epoch)``. Normally a config name is drawn with
    the worker's private ``random.Random`` according to ``weights`` and the
    next row of that config's stream is converted with ``_extract_text``.
    When factual injection is active, every ``HYDRA_FACTUAL_INJECT_RATE``-th
    call returns the next line of ``data/factual/facts.txt`` instead, cycling
    through the file's lines in order.

    ``epoch`` starts at 1 and is incremented each time an exhausted stream is
    reopened.

    NOTE(review): streams are always opened with split ``"train"``, even when
    the owning dataset was created for validation — confirm this is intended.
    """

    def __init__(self, weights: dict[str, float], base_seed: int, worker_id: int):
        """Open one stream per config and seed the per-worker RNG.

        Args:
            weights: Mapping of config name to sampling weight.
            base_seed: Seed shared by all workers of one loader.
            worker_id: DataLoader worker index; offsets the RNG seed so each
                worker follows its own config-choice sequence.
        """
        self.configs = list(weights)
        self.weights = [weights[c] for c in self.configs]
        self.base_seed = base_seed
        self.worker_id = worker_id
        # One stream per config, opened by this worker.
        self.streams = {c: _open_stream(c, "train") for c in self.configs}
        # Distinct seed per worker (7919 is just a fixed multiplier).
        self.rng = random.Random(base_seed + worker_id * 7919)
        self.epoch = 1

        # Factual-injection state. `_factual_docs` stays None when injection
        # is disabled or the facts file is missing.
        self._factual_docs: list[str] | None = None
        self._factual_idx = 0
        self._inject_counter = 0
        self._inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
        if self._inject_rate > 0:
            self._factual_docs = self._load_factual_docs()

    @staticmethod
    def _load_factual_docs() -> list[str] | None:
        """Read the non-blank lines of facts.txt, or return None if it is absent.

        The file is looked up at ``data/factual/facts.txt`` relative to the
        directory containing the ``prepare_nemotron`` module.
        """
        factual_path = os.path.join(
            os.path.dirname(os.path.abspath(_p_nemo.__file__)),
            "data", "factual", "facts.txt",
        )
        try:
            # Explicit encoding: the default depends on the platform locale.
            with open(factual_path, encoding="utf-8") as fh:
                content = fh.read()
        except FileNotFoundError:
            return None
        # Drop blank lines so an empty file (or stray empty lines) never
        # produces empty documents that would waste an injection slot.
        return [line for line in content.splitlines() if line.strip()]

    def _reopen(self, config: str) -> None:
        """Replace the exhausted stream for ``config`` and bump the epoch."""
        self.streams[config] = _open_stream(config, "train")
        self.epoch += 1

    def __iter__(self):
        return self

    def __next__(self) -> tuple[str, int]:
        """Return the next ``(text, epoch)`` pair.

        Raises:
            RuntimeError: If a stream is still empty right after being
                reopened. Raised explicitly because a bare ``StopIteration``
                escaping into a consuming generator is converted into an
                uninformative ``RuntimeError`` anyway.
        """
        # Factual injection: every `_inject_rate`-th call returns a fact.
        if self._inject_rate > 0 and self._factual_docs:
            self._inject_counter += 1
            if self._inject_counter >= self._inject_rate:
                self._inject_counter = 0
                doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
                self._factual_idx += 1
                return doc, self.epoch

        config = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
        try:
            row = next(self.streams[config])
        except StopIteration:
            self._reopen(config)
            try:
                row = next(self.streams[config])
            except StopIteration:
                raise RuntimeError(
                    f"stream {config!r} yielded no rows after being reopened"
                ) from None
        return _extract_text(row), self.epoch
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
# ---------------------------------------------------------------------------
|
| 121 |
-
# IterableStreamDataset — yields (T+1,) packed rows. No threads. No queues.
|
| 122 |
-
# Lives inside each DataLoader worker. DataLoader's own multiprocessing stacks
|
| 123 |
-
# rows into batches of shape (B, T+1) and sends them to the main process.
|
| 124 |
-
# ---------------------------------------------------------------------------
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
class IterableStreamDataset(IterableDataset):
    """Stream documents, tokenize them and pack the tokens into fixed-size rows.

    Iterating yields an endless sequence of 1-D ``torch.long`` tensors of
    length ``seq_len + 1``. All per-iteration state (tokenizer, weighted
    stream, document buffer) is created inside ``__iter__``, so each
    DataLoader worker that iterates the dataset builds its own.
    """

    def __init__(
        self,
        split: str,
        seq_len: int,
        *,
        base_seed: int = 0,
        doc_buffer_size: int = 1000,
        tokenizer_batch: int = 128,
    ):
        """Store the packing configuration.

        Args:
            split: ``"train"`` or ``"val"``.
            seq_len: Sequence length T; rows hold ``T + 1`` tokens.
            base_seed: Seed for the training stream (validation always uses
                12345, see ``__iter__``).
            doc_buffer_size: Minimum number of tokenized documents kept
                buffered for best-fit packing.
            tokenizer_batch: Number of documents pulled from the stream per
                tokenizer call.

        Raises:
            ValueError: If ``split`` is unknown, or ``doc_buffer_size`` /
                ``tokenizer_batch`` is smaller than 1.
        """
        super().__init__()
        # Explicit raises instead of `assert`, which is stripped under -O.
        if split not in ("train", "val"):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        if doc_buffer_size < 1:
            # With an empty buffer the crop branch would call min() on nothing.
            raise ValueError(f"doc_buffer_size must be >= 1, got {doc_buffer_size}")
        if tokenizer_batch < 1:
            # A zero-sized batch would make the refill loop spin forever.
            raise ValueError(f"tokenizer_batch must be >= 1, got {tokenizer_batch}")
        self.split = split
        self.seq_len = seq_len
        self.row_capacity = seq_len + 1
        self.base_seed = base_seed
        self.doc_buffer_size = doc_buffer_size
        self.tokenizer_batch = tokenizer_batch

    def _pick_weights(self) -> dict[str, float]:
        """Choose the config-weight table from the split and environment.

        ``HYDRA_USE_FULL_BLEND=1`` selects the full blend for both splits.
        Otherwise validation uses only the Multiple-Choice config, and
        training uses the phase table named by ``HYDRA_NEMOTRON_PHASE``
        (anything other than ``phase2`` means phase 1).
        """
        if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
            return FULL_BLEND_WEIGHTS
        if self.split == "val":
            return {"Nemotron-Pretraining-Multiple-Choice": 1.0}
        phase = os.environ.get("HYDRA_NEMOTRON_PHASE", "phase1").strip().lower()
        return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Yield packed ``(seq_len + 1,)`` rows forever.

        Each row is filled greedily: the longest buffered document that still
        fits is appended; when none fits, the shortest buffered document is
        cropped to the remaining space (its tail is discarded).
        """
        info = get_worker_info()
        worker_id = 0 if info is None else info.id

        # Built here so every iterating worker owns its tokenizer.
        tokenizer = _prepare.Tokenizer.from_directory()
        bos = tokenizer.get_bos_token_id()

        # Validation ignores base_seed and always uses a fixed seed.
        seed = 12345 if self.split == "val" else self.base_seed
        stream = _WorkerWeightedStream(
            self._pick_weights(), base_seed=seed, worker_id=worker_id,
        )

        row_capacity = self.row_capacity
        min_buffered = self.doc_buffer_size
        doc_batch_size = self.tokenizer_batch
        doc_buffer: list[list[int]] = []

        def refill_buffer() -> None:
            # Pull one batch of non-empty texts and tokenize them together.
            texts: list[str] = []
            for _ in range(doc_batch_size):
                text, _epoch = next(stream)
                if text:
                    texts.append(text)
            if texts:
                token_lists = tokenizer.encode(texts, prepend=bos)
                # Skip zero-length token lists: they can never be packed and
                # would break the crop branch if selected as "shortest".
                doc_buffer.extend(toks for toks in token_lists if len(toks) > 0)

        while True:
            row = torch.empty(row_capacity, dtype=torch.long)
            pos = 0
            while pos < row_capacity:
                while len(doc_buffer) < min_buffered:
                    refill_buffer()

                remaining = row_capacity - pos
                fitting = [
                    i for i, toks in enumerate(doc_buffer) if len(toks) <= remaining
                ]
                if fitting:
                    # Best fit: the longest document that fits entirely
                    # (first one wins on ties).
                    best_idx = max(fitting, key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(best_idx)
                else:
                    # Nothing fits: crop the shortest document to fill the row.
                    shortest_idx = min(
                        range(len(doc_buffer)),
                        key=lambda i: len(doc_buffer[i]),
                    )
                    doc = doc_buffer.pop(shortest_idx)[:remaining]
                row[pos : pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                pos += len(doc)

            yield row
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
# ---------------------------------------------------------------------------
|
| 235 |
-
# LightningDataModule
|
| 236 |
-
# ---------------------------------------------------------------------------
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
class HydraDataModule(L.LightningDataModule):
    """LightningDataModule that serves ``IterableStreamDataset`` loaders.

    Constructor arguments left as ``None`` fall back to environment
    variables: ``HYDRA_BATCH_SIZE`` (1), ``HYDRA_SEQ_LEN`` (512),
    ``HYDRA_DATA_NUM_WORKERS`` (2) and ``HYDRA_DATA_PREFETCH`` (4). The
    packing buffer size always comes from ``HYDRA_DATA_BUFFER`` (1000).
    """

    def __init__(
        self,
        batch_size: int | None = None,
        seq_len: int | None = None,
        num_workers: int | None = None,
        prefetch_factor: int | None = None,
    ):
        super().__init__()
        # `or` (not `is None`): a falsy batch_size/seq_len such as 0 also
        # falls back to the environment default.
        self.batch_size = batch_size or int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        self.seq_len = seq_len or int(os.environ.get("HYDRA_SEQ_LEN", "512"))

        # For these two an explicit 0 is honoured; only None triggers the
        # environment lookup.
        if num_workers is None:
            num_workers = int(os.environ.get("HYDRA_DATA_NUM_WORKERS", "2"))
        self.num_workers = num_workers

        if prefetch_factor is None:
            prefetch_factor = int(os.environ.get("HYDRA_DATA_PREFETCH", "4"))
        self.prefetch_factor = prefetch_factor

        self.doc_buffer = int(os.environ.get("HYDRA_DATA_BUFFER", "1000"))

    def _make_loader(self, split: str, seed: int) -> DataLoader:
        """Build a DataLoader over a fresh dataset for ``split``.

        ``seed`` is forwarded to the dataset as ``base_seed``.
        """
        loader_kwargs: dict = {
            "dataset": IterableStreamDataset(
                split=split,
                seq_len=self.seq_len,
                base_seed=seed,
                doc_buffer_size=self.doc_buffer,
            ),
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": True,
            "drop_last": True,
        }
        # These two options are only passed when worker processes are used;
        # with num_workers == 0 iteration happens in the main process.
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
            loader_kwargs["persistent_workers"] = True
        return DataLoader(**loader_kwargs)

    def train_dataloader(self) -> DataLoader:
        """Return the training loader (base seed 0)."""
        return self._make_loader("train", seed=0)

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (seed 12345)."""
        return self._make_loader("val", seed=12345)
|
|
|
|
| 1 |
+
"""Lightning DataModule + IterableDataset for HYDRA pretraining.
|
| 2 |
+
|
| 3 |
+
Replaces the custom threading/queue pipeline in prepare_nemotron.make_dataloader
|
| 4 |
+
with a standard multiprocessing DataLoader approach.
|
| 5 |
+
|
| 6 |
+
Design:
|
| 7 |
+
• IterableStreamDataset: each worker opens its own HF streams for the 7-way
|
| 8 |
+
blend, tokenizes with rustbpe, packs into (T+1,) rows via best-fit, and
|
| 9 |
+
yields one row per __next__.
|
| 10 |
+
• HydraDataModule: wraps the dataset with a standard DataLoader using
|
| 11 |
+
num_workers>=1, prefetch_factor=4, pin_memory=True. Lightning handles
|
| 12 |
+
device transfer.
|
| 13 |
+
• Val stream: deterministic seed 12345; uses the full blend when HYDRA_USE_FULL_BLEND=1, otherwise only the Nemotron-Pretraining-Multiple-Choice config.
|
| 14 |
+
|
| 15 |
+
The worker RNG is seeded per-worker so the weighted-sampling schedule is
|
| 16 |
+
independent across workers (else all workers request the same config at
|
| 17 |
+
the same step and prefetching serializes).
|
| 18 |
+
|
| 19 |
+
Env vars (all preserved from prepare_nemotron):
|
| 20 |
+
HYDRA_SEQ_LEN — sequence length T (default 512)
|
| 21 |
+
HYDRA_BATCH_SIZE — batch size B (default 1) — passed through
|
| 22 |
+
to DataLoader
|
| 23 |
+
HYDRA_STREAM_SHUFFLE_BUFFER — HF shuffle buffer (default 2048)
|
| 24 |
+
HYDRA_USE_FULL_BLEND — 7-way blend vs 5-way Nemotron phase
|
| 25 |
+
HYDRA_USE_NEMOTRON — enables streaming path (else shard path)
|
| 26 |
+
HYDRA_FACTUAL_INJECT_RATE — factual doc injection cadence
|
| 27 |
+
HYDRA_NEMOTRON_PHASE — phase1|phase2 (when not full blend)
|
| 28 |
+
HYDRA_DATA_NUM_WORKERS — DataLoader num_workers (default 2)
|
| 29 |
+
HYDRA_DATA_PREFETCH — DataLoader prefetch_factor (default 4)
|
| 30 |
+
HYDRA_DATA_BUFFER — doc_buffer size for best-fit packing
|
| 31 |
+
(default 1000)
|
| 32 |
+
"""
|
| 33 |
+
from __future__ import annotations
|
| 34 |
+
|
| 35 |
+
import os
|
| 36 |
+
import random
|
| 37 |
+
from typing import Iterator
|
| 38 |
+
|
| 39 |
+
import numpy as np
|
| 40 |
+
import torch
|
| 41 |
+
import lightning as L
|
| 42 |
+
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
| 43 |
+
|
| 44 |
+
import prepare as _prepare
|
| 45 |
+
import prepare_nemotron as _p_nemo
|
| 46 |
+
from prepare_nemotron import (
|
| 47 |
+
FULL_BLEND_WEIGHTS,
|
| 48 |
+
PHASE1_WEIGHTS,
|
| 49 |
+
PHASE2_WEIGHTS,
|
| 50 |
+
_BLEND_REGISTRY,
|
| 51 |
+
_extract_text,
|
| 52 |
+
_open_stream,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
# Worker-local weighted stream. A stripped version of prepare_nemotron's
|
| 58 |
+
# _WeightedStream that is constructed inside each worker. Adds worker sharding:
|
| 59 |
+
# when num_workers > 1 the RNG is seeded per-worker, so different workers
|
| 60 |
+
# sample different config sequences and pull disjoint shard assignments from
|
| 61 |
+
# HF's shuffle buffer.
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class _WorkerWeightedStream:
    """Endless per-worker iterator over a weighted mix of text streams.

    ``next()`` returns ``(text, epoch)``. Normally a config name is drawn with
    the worker's private ``random.Random`` according to ``weights`` and the
    next row of that config's stream is converted with ``_extract_text``.
    When factual injection is active, every ``HYDRA_FACTUAL_INJECT_RATE``-th
    call returns the next line of ``data/factual/facts.txt`` instead, cycling
    through the file's lines in order.

    ``epoch`` starts at 1 and is incremented each time an exhausted stream is
    reopened.

    NOTE(review): streams are always opened with split ``"train"``, even when
    the owning dataset was created for validation — confirm this is intended.
    """

    def __init__(self, weights: dict[str, float], base_seed: int, worker_id: int):
        """Open one stream per config and seed the per-worker RNG.

        Args:
            weights: Mapping of config name to sampling weight.
            base_seed: Seed shared by all workers of one loader.
            worker_id: DataLoader worker index; offsets the RNG seed so each
                worker follows its own config-choice sequence.
        """
        self.configs = list(weights)
        self.weights = [weights[c] for c in self.configs]
        self.base_seed = base_seed
        self.worker_id = worker_id
        # One stream per config, opened by this worker.
        self.streams = {c: _open_stream(c, "train") for c in self.configs}
        # Distinct seed per worker (7919 is just a fixed multiplier).
        self.rng = random.Random(base_seed + worker_id * 7919)
        self.epoch = 1

        # Factual-injection state. `_factual_docs` stays None when injection
        # is disabled or the facts file is missing.
        self._factual_docs: list[str] | None = None
        self._factual_idx = 0
        self._inject_counter = 0
        self._inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
        if self._inject_rate > 0:
            self._factual_docs = self._load_factual_docs()

    @staticmethod
    def _load_factual_docs() -> list[str] | None:
        """Read the non-blank lines of facts.txt, or return None if it is absent.

        The file is looked up at ``data/factual/facts.txt`` relative to the
        directory containing the ``prepare_nemotron`` module.
        """
        factual_path = os.path.join(
            os.path.dirname(os.path.abspath(_p_nemo.__file__)),
            "data", "factual", "facts.txt",
        )
        try:
            # Explicit encoding: the default depends on the platform locale.
            with open(factual_path, encoding="utf-8") as fh:
                content = fh.read()
        except FileNotFoundError:
            return None
        # Drop blank lines so an empty file (or stray empty lines) never
        # produces empty documents that would waste an injection slot.
        return [line for line in content.splitlines() if line.strip()]

    def _reopen(self, config: str) -> None:
        """Replace the exhausted stream for ``config`` and bump the epoch."""
        self.streams[config] = _open_stream(config, "train")
        self.epoch += 1

    def __iter__(self):
        return self

    def __next__(self) -> tuple[str, int]:
        """Return the next ``(text, epoch)`` pair.

        Raises:
            RuntimeError: If a stream is still empty right after being
                reopened. Raised explicitly because a bare ``StopIteration``
                escaping into a consuming generator is converted into an
                uninformative ``RuntimeError`` anyway.
        """
        # Factual injection: every `_inject_rate`-th call returns a fact.
        if self._inject_rate > 0 and self._factual_docs:
            self._inject_counter += 1
            if self._inject_counter >= self._inject_rate:
                self._inject_counter = 0
                doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
                self._factual_idx += 1
                return doc, self.epoch

        config = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
        try:
            row = next(self.streams[config])
        except StopIteration:
            self._reopen(config)
            try:
                row = next(self.streams[config])
            except StopIteration:
                raise RuntimeError(
                    f"stream {config!r} yielded no rows after being reopened"
                ) from None
        return _extract_text(row), self.epoch
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# ---------------------------------------------------------------------------
|
| 121 |
+
# IterableStreamDataset — yields (T+1,) packed rows. No threads. No queues.
|
| 122 |
+
# Lives inside each DataLoader worker. DataLoader's own multiprocessing stacks
|
| 123 |
+
# rows into batches of shape (B, T+1) and sends them to the main process.
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class IterableStreamDataset(IterableDataset):
    """Stream documents, tokenize them and pack the tokens into fixed-size rows.

    Iterating yields an endless sequence of 1-D ``torch.long`` tensors of
    length ``seq_len + 1``. All per-iteration state (tokenizer, weighted
    stream, document buffer) is created inside ``__iter__``, so each
    DataLoader worker that iterates the dataset builds its own.
    """

    def __init__(
        self,
        split: str,
        seq_len: int,
        *,
        base_seed: int = 0,
        doc_buffer_size: int = 1000,
        tokenizer_batch: int = 128,
    ):
        """Store the packing configuration.

        Args:
            split: ``"train"`` or ``"val"``.
            seq_len: Sequence length T; rows hold ``T + 1`` tokens.
            base_seed: Seed for the training stream (validation always uses
                12345, see ``__iter__``).
            doc_buffer_size: Minimum number of tokenized documents kept
                buffered for best-fit packing.
            tokenizer_batch: Number of documents pulled from the stream per
                tokenizer call.

        Raises:
            ValueError: If ``split`` is unknown, or ``doc_buffer_size`` /
                ``tokenizer_batch`` is smaller than 1.
        """
        super().__init__()
        # Explicit raises instead of `assert`, which is stripped under -O.
        if split not in ("train", "val"):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        if doc_buffer_size < 1:
            # With an empty buffer the crop branch would call min() on nothing.
            raise ValueError(f"doc_buffer_size must be >= 1, got {doc_buffer_size}")
        if tokenizer_batch < 1:
            # A zero-sized batch would make the refill loop spin forever.
            raise ValueError(f"tokenizer_batch must be >= 1, got {tokenizer_batch}")
        self.split = split
        self.seq_len = seq_len
        self.row_capacity = seq_len + 1
        self.base_seed = base_seed
        self.doc_buffer_size = doc_buffer_size
        self.tokenizer_batch = tokenizer_batch

    def _pick_weights(self) -> dict[str, float]:
        """Choose the config-weight table from the split and environment.

        ``HYDRA_USE_FULL_BLEND=1`` selects the full blend for both splits.
        Otherwise validation uses only the Multiple-Choice config, and
        training uses the phase table named by ``HYDRA_NEMOTRON_PHASE``
        (anything other than ``phase2`` means phase 1).
        """
        if os.environ.get("HYDRA_USE_FULL_BLEND", "0") == "1":
            return FULL_BLEND_WEIGHTS
        if self.split == "val":
            return {"Nemotron-Pretraining-Multiple-Choice": 1.0}
        phase = os.environ.get("HYDRA_NEMOTRON_PHASE", "phase1").strip().lower()
        return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Yield packed ``(seq_len + 1,)`` rows forever.

        Each row is filled greedily: the longest buffered document that still
        fits is appended; when none fits, the shortest buffered document is
        cropped to the remaining space (its tail is discarded).
        """
        info = get_worker_info()
        worker_id = 0 if info is None else info.id

        # Built here so every iterating worker owns its tokenizer.
        tokenizer = _prepare.Tokenizer.from_directory()
        bos = tokenizer.get_bos_token_id()

        # Validation ignores base_seed and always uses a fixed seed.
        seed = 12345 if self.split == "val" else self.base_seed
        stream = _WorkerWeightedStream(
            self._pick_weights(), base_seed=seed, worker_id=worker_id,
        )

        row_capacity = self.row_capacity
        min_buffered = self.doc_buffer_size
        doc_batch_size = self.tokenizer_batch
        doc_buffer: list[list[int]] = []

        def refill_buffer() -> None:
            # Pull one batch of non-empty texts and tokenize them together.
            texts: list[str] = []
            for _ in range(doc_batch_size):
                text, _epoch = next(stream)
                if text:
                    texts.append(text)
            if texts:
                token_lists = tokenizer.encode(texts, prepend=bos)
                # Skip zero-length token lists: they can never be packed and
                # would break the crop branch if selected as "shortest".
                doc_buffer.extend(toks for toks in token_lists if len(toks) > 0)

        while True:
            row = torch.empty(row_capacity, dtype=torch.long)
            pos = 0
            while pos < row_capacity:
                while len(doc_buffer) < min_buffered:
                    refill_buffer()

                remaining = row_capacity - pos
                fitting = [
                    i for i, toks in enumerate(doc_buffer) if len(toks) <= remaining
                ]
                if fitting:
                    # Best fit: the longest document that fits entirely
                    # (first one wins on ties).
                    best_idx = max(fitting, key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(best_idx)
                else:
                    # Nothing fits: crop the shortest document to fill the row.
                    shortest_idx = min(
                        range(len(doc_buffer)),
                        key=lambda i: len(doc_buffer[i]),
                    )
                    doc = doc_buffer.pop(shortest_idx)[:remaining]
                row[pos : pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                pos += len(doc)

            yield row
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ---------------------------------------------------------------------------
|
| 235 |
+
# LightningDataModule
|
| 236 |
+
# ---------------------------------------------------------------------------
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class HydraDataModule(L.LightningDataModule):
    """LightningDataModule that serves ``IterableStreamDataset`` loaders.

    Constructor arguments left as ``None`` fall back to environment
    variables: ``HYDRA_BATCH_SIZE`` (1), ``HYDRA_SEQ_LEN`` (512),
    ``HYDRA_DATA_NUM_WORKERS`` (2) and ``HYDRA_DATA_PREFETCH`` (4). The
    packing buffer size always comes from ``HYDRA_DATA_BUFFER`` (1000).
    """

    def __init__(
        self,
        batch_size: int | None = None,
        seq_len: int | None = None,
        num_workers: int | None = None,
        prefetch_factor: int | None = None,
    ):
        super().__init__()
        # `or` (not `is None`): a falsy batch_size/seq_len such as 0 also
        # falls back to the environment default.
        self.batch_size = batch_size or int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        self.seq_len = seq_len or int(os.environ.get("HYDRA_SEQ_LEN", "512"))

        # For these two an explicit 0 is honoured; only None triggers the
        # environment lookup.
        if num_workers is None:
            num_workers = int(os.environ.get("HYDRA_DATA_NUM_WORKERS", "2"))
        self.num_workers = num_workers

        if prefetch_factor is None:
            prefetch_factor = int(os.environ.get("HYDRA_DATA_PREFETCH", "4"))
        self.prefetch_factor = prefetch_factor

        self.doc_buffer = int(os.environ.get("HYDRA_DATA_BUFFER", "1000"))

    def _make_loader(self, split: str, seed: int) -> DataLoader:
        """Build a DataLoader over a fresh dataset for ``split``.

        ``seed`` is forwarded to the dataset as ``base_seed``.
        """
        loader_kwargs: dict = {
            "dataset": IterableStreamDataset(
                split=split,
                seq_len=self.seq_len,
                base_seed=seed,
                doc_buffer_size=self.doc_buffer,
            ),
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": True,
            "drop_last": True,
        }
        # These two options are only passed when worker processes are used;
        # with num_workers == 0 iteration happens in the main process.
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
            loader_kwargs["persistent_workers"] = True
        return DataLoader(**loader_kwargs)

    def train_dataloader(self) -> DataLoader:
        """Return the training loader (base seed 0)."""
        return self._make_loader("train", seed=0)

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (seed 12345)."""
        return self._make_loader("val", seed=12345)
|
overlay/hydra/diffusion_loss.py
CHANGED
|
@@ -1,236 +1,236 @@
|
|
| 1 |
-
"""MDLM Rao-Blackwellized Masked Diffusion Loss.
|
| 2 |
-
|
| 3 |
-
Implements the masked-diffusion ELBO from:
|
| 4 |
-
Sahoo et al., "Simple and Effective Masked Diffusion Language Models" (MDLM),
|
| 5 |
-
NeurIPS 2024, arXiv:2406.07524.
|
| 6 |
-
|
| 7 |
-
Equations referenced:
|
| 8 |
-
- Forward process: eq. 2 (per-token Bernoulli masking at rate 1 - alpha_t)
|
| 9 |
-
- Log-linear schedule: alpha_t = 1 - t, t ~ Uniform(0, 1)
|
| 10 |
-
- RB-ELBO: eq. 7-8 L_RB = E_t E_q [ (1/alpha_t) * CE(x_theta(x_t), x_0) ]
|
| 11 |
-
where the expectation is taken over masked positions.
|
| 12 |
-
|
| 13 |
-
Key insight: the Rao-Blackwellized estimate replaces an average over all masks
|
| 14 |
-
(exponential) by a closed-form weighted CE that applies weight 1/alpha_t only
|
| 15 |
-
on the positions that were masked, and 0 on unmasked positions. This gives an
|
| 16 |
-
unbiased estimator with lower variance than a naive Monte Carlo over mask
|
| 17 |
-
patterns.
|
| 18 |
-
|
| 19 |
-
Reference implementation cross-checked against:
|
| 20 |
-
https://github.com/kuleshov-group/mdlm (diffusion.py::DiffusionModel._loss)
|
| 21 |
-
"""
|
| 22 |
-
|
| 23 |
-
from __future__ import annotations
|
| 24 |
-
|
| 25 |
-
from typing import Literal
|
| 26 |
-
|
| 27 |
-
import torch
|
| 28 |
-
import torch.nn.functional as F
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# Clamping weight keeps gradients finite while still up-weighting high-noise
|
| 32 |
-
# positions. Historical value 1/eps=1000 blew up HYDRA training on a 12h v2
|
| 33 |
-
# launch (2026-04-22): loss 26 → 42 → NaN in 13 steps under Muon lr=7e-3
|
| 34 |
-
# because per-token CE × 1000 saturated the 100-unit FAIL guard. The MDLM
|
| 35 |
-
# paper reports stable training at Adam lr=1e-4; HYDRA uses Muon at 7e-3
|
| 36 |
-
# (70× larger), so the weight clamp needs to compensate.
|
| 37 |
-
#
|
| 38 |
-
# Tunable via HYDRA_MDLM_MAX_WEIGHT (default 5.0). Set =1.0 to disable
|
| 39 |
-
# weighting entirely (flat masked-LM CE, no RB reweighting — simpler and
|
| 40 |
-
# more stable, sacrifices the theoretical ELBO property).
|
| 41 |
-
import os as _os
|
| 42 |
-
_MAX_WEIGHT: float = float(_os.environ.get("HYDRA_MDLM_MAX_WEIGHT", "5.0"))
|
| 43 |
-
_MIN_ALPHA: float = 1.0 / _MAX_WEIGHT # so clamp(alpha, min=_MIN_ALPHA) gives 1/alpha <= _MAX_WEIGHT
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
# ---------------------------------------------------------------------------
|
| 47 |
-
# Public API
|
| 48 |
-
# ---------------------------------------------------------------------------
|
| 49 |
-
|
| 50 |
-
def mdlm_masked_forward_process(
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Noise a clean token batch by random masking and return the RB weights.

    Args:
        targets: (B, T) token ids — the clean sequence x_0.
        mask_token_id: Token id written into masked positions.
        t: Optional (B,) noise levels in [0, 1]; converted to float32 on the
            device of ``targets``. When None, one level per batch element is
            drawn with ``torch.rand`` (i.e. from [0, 1)). Larger ``t`` means
            more masking.
        alpha_schedule: ``"linear"`` or ``"loglinear"``; both select
            ``alpha_t = 1 - t``, the per-token probability of staying
            unmasked.

    Returns:
        x_t: (B, T) copy of ``targets`` with masked positions replaced by
            ``mask_token_id``.
        mask_positions: (B, T) bool, True where a token was masked.
        loss_weights: (B, T) float32, ``1 / max(alpha_t, _MIN_ALPHA)`` on
            masked positions and 0.0 elsewhere.

    Raises:
        ValueError: If ``t`` has the wrong shape or values outside [0, 1],
            or ``alpha_schedule`` is not recognised.
    """
    batch, seq = targets.shape
    dev = targets.device
    f32 = torch.float32

    # --- noise level per batch element ---
    if t is not None:
        t = t.to(device=dev, dtype=f32)
        if t.shape != (batch,):
            raise ValueError(f"t must be shape (B,)={(batch,)}, got {t.shape}")
        if (t < 0).any() or (t > 1).any():
            raise ValueError("t must be in [0, 1]")
    else:
        t = torch.rand(batch, device=dev, dtype=f32)

    # --- schedule: both accepted names share the same formula ---
    if alpha_schedule not in ("linear", "loglinear"):
        raise ValueError(f"Unknown alpha_schedule: {alpha_schedule!r}. Use 'linear' or 'loglinear'.")
    keep_prob = 1.0 - t  # (B,)

    # --- independent Bernoulli draw per token: masked when the uniform
    # sample exceeds the keep probability ---
    uniform = torch.rand(batch, seq, device=dev, dtype=f32)
    mask_positions = uniform > keep_prob[:, None]  # (B, T) bool

    # --- noised sequence (out-of-place; `targets` is left untouched) ---
    x_t = targets.masked_fill(mask_positions, mask_token_id)

    # --- RB weights: clamp alpha so 1/alpha stays bounded as t -> 1, then
    # zero the weight on unmasked positions via broadcasting ---
    per_sample_weight = 1.0 / keep_prob.clamp(min=_MIN_ALPHA)  # (B,)
    loss_weights = per_sample_weight[:, None] * mask_positions.float()  # (B, T)

    return x_t, mask_positions, loss_weights
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
def mdlm_rb_loss(
|
| 129 |
-
logits: torch.Tensor,
|
| 130 |
-
targets: torch.Tensor,
|
| 131 |
-
mask_positions: torch.Tensor,
|
| 132 |
-
loss_weights: torch.Tensor,
|
| 133 |
-
ignore_index: int = -100,
|
| 134 |
-
) -> torch.Tensor:
|
| 135 |
-
"""Rao-Blackwellized negative ELBO.
|
| 136 |
-
|
| 137 |
-
Applies the MDLM loss: cross-entropy on masked positions only, weighted
|
| 138 |
-
per-token by loss_weights, averaged over the batch.
|
| 139 |
-
|
| 140 |
-
The formula (eq. 7-8 of arXiv:2406.07524):
|
| 141 |
-
L_RB = mean_B [ sum_T (weight_t * CE(logits_i, target_i) * mask_i)
|
| 142 |
-
/ max(sum_T(mask_i), 1) ]
|
| 143 |
-
|
| 144 |
-
Args:
|
| 145 |
-
logits : (B, T, V) raw logits. May be bf16; internally cast to
|
| 146 |
-
float32 for CE computation.
|
| 147 |
-
targets : (B, T) int64 true token ids (x_0).
|
| 148 |
-
mask_positions: (B, T) bool — True = masked position.
|
| 149 |
-
loss_weights : (B, T) float32 — 1/alpha_t on masked positions, 0 elsewhere.
|
| 150 |
-
ignore_index : Passed to F.cross_entropy; positions with this label
|
| 151 |
-
are excluded from the loss.
|
| 152 |
-
|
| 153 |
-
Returns:
|
| 154 |
-
Scalar float32 loss. Returns 0.0 tensor if no positions are masked.
|
| 155 |
-
"""
|
| 156 |
-
B, T, V = logits.shape
|
| 157 |
-
|
| 158 |
-
# Ensure float32 for numerical stability; F.cross_entropy accepts fp16/bf16
|
| 159 |
-
# logits but accumulates in float internally anyway. Being explicit avoids
|
| 160 |
-
# silent precision surprises.
|
| 161 |
-
logits_f = logits.float() # (B, T, V)
|
| 162 |
-
|
| 163 |
-
# Build targets with ignore_index on UNmasked positions so CE only fires
|
| 164 |
-
# where mask_positions is True. We also honour any pre-existing -100 values
|
| 165 |
-
# (e.g. doc-separator masking upstream).
|
| 166 |
-
targets_masked = torch.where(
|
| 167 |
-
mask_positions & (targets != ignore_index),
|
| 168 |
-
targets,
|
| 169 |
-
torch.full_like(targets, ignore_index),
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
# Per-token CE; shape (B, T). Positions with ignore_index → 0 from CE.
|
| 173 |
-
per_tok_ce = F.cross_entropy(
|
| 174 |
-
logits_f.reshape(B * T, V),
|
| 175 |
-
targets_masked.reshape(B * T),
|
| 176 |
-
ignore_index=ignore_index,
|
| 177 |
-
reduction="none",
|
| 178 |
-
).reshape(B, T) # (B, T) float32
|
| 179 |
-
|
| 180 |
-
# Apply RB weight. loss_weights already has 0 on unmasked positions.
|
| 181 |
-
weighted = per_tok_ce * loss_weights # (B, T)
|
| 182 |
-
|
| 183 |
-
# Per-sample mean over masked positions, then average over batch.
|
| 184 |
-
mask_f = mask_positions.float() # (B, T)
|
| 185 |
-
per_sample_mask_count = mask_f.sum(dim=1).clamp(min=1) # (B,)
|
| 186 |
-
per_sample_loss = weighted.sum(dim=1) / per_sample_mask_count # (B,)
|
| 187 |
-
|
| 188 |
-
return per_sample_loss.mean() # scalar float32
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
def mdlm_loss(
|
| 192 |
-
logits: torch.Tensor,
|
| 193 |
-
targets: torch.Tensor,
|
| 194 |
-
mask_token_id: int,
|
| 195 |
-
t: torch.Tensor | None = None,
|
| 196 |
-
alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
|
| 197 |
-
ignore_index: int = -100,
|
| 198 |
-
) -> torch.Tensor:
|
| 199 |
-
"""Convenience wrapper: forward process + RB-ELBO in one call.
|
| 200 |
-
|
| 201 |
-
Suitable for the common case where the caller has full-vocab logits and
|
| 202 |
-
wants a drop-in replacement for a standard masked-LM CE loss.
|
| 203 |
-
|
| 204 |
-
Args:
|
| 205 |
-
logits : (B, T, V) raw logits.
|
| 206 |
-
targets : (B, T) int64 clean token ids.
|
| 207 |
-
mask_token_id : The MASK token id used to corrupt the input.
|
| 208 |
-
t : Optional (B,) timestep in (0, 1). Sampled if None.
|
| 209 |
-
alpha_schedule: "loglinear" (default) or "linear".
|
| 210 |
-
ignore_index : Token id to ignore in the loss (e.g. padding).
|
| 211 |
-
|
| 212 |
-
Returns:
|
| 213 |
-
Scalar float32 MDLM RB-ELBO loss.
|
| 214 |
-
|
| 215 |
-
Note on sampled-softmax / partial logits:
|
| 216 |
-
If your model only computes logits for a subset of vocab positions
|
| 217 |
-
(e.g. HYDRA's sampled-softmax head), call mdlm_masked_forward_process
|
| 218 |
-
and mdlm_rb_loss separately. mdlm_rb_loss expects full-vocab logits.
|
| 219 |
-
"""
|
| 220 |
-
x_t, mask_positions, loss_weights = mdlm_masked_forward_process(
|
| 221 |
-
targets=targets,
|
| 222 |
-
mask_token_id=mask_token_id,
|
| 223 |
-
t=t,
|
| 224 |
-
alpha_schedule=alpha_schedule,
|
| 225 |
-
)
|
| 226 |
-
# x_t is produced for the model's input (not used by this convenience
|
| 227 |
-
# wrapper since logits are already provided by the caller). In a real
|
| 228 |
-
# training loop the caller feeds x_t into the model to get logits, THEN
|
| 229 |
-
# calls this function. See the orchestrator wiring note in training.py.
|
| 230 |
-
return mdlm_rb_loss(
|
| 231 |
-
logits=logits,
|
| 232 |
-
targets=targets,
|
| 233 |
-
mask_positions=mask_positions,
|
| 234 |
-
loss_weights=loss_weights,
|
| 235 |
-
ignore_index=ignore_index,
|
| 236 |
-
)
|
|
|
|
| 1 |
+
"""MDLM Rao-Blackwellized Masked Diffusion Loss.
|
| 2 |
+
|
| 3 |
+
Implements the masked-diffusion ELBO from:
|
| 4 |
+
Sahoo et al., "Simple and Effective Masked Diffusion Language Models" (MDLM),
|
| 5 |
+
NeurIPS 2024, arXiv:2406.07524.
|
| 6 |
+
|
| 7 |
+
Equations referenced:
|
| 8 |
+
- Forward process: eq. 2 (per-token Bernoulli masking at rate 1 - alpha_t)
|
| 9 |
+
- Log-linear schedule: alpha_t = 1 - t, t ~ Uniform(0, 1)
|
| 10 |
+
- RB-ELBO: eq. 7-8 L_RB = E_t E_q [ (1/alpha_t) * CE(x_theta(x_t), x_0) ]
|
| 11 |
+
where the expectation is taken over masked positions.
|
| 12 |
+
|
| 13 |
+
Key insight: the Rao-Blackwellized estimate replaces an average over all masks
|
| 14 |
+
(exponential) by a closed-form weighted CE that applies weight 1/alpha_t only
|
| 15 |
+
on the positions that were masked, and 0 on unmasked positions. This gives an
|
| 16 |
+
unbiased estimator with lower variance than a naive Monte Carlo over mask
|
| 17 |
+
patterns.
|
| 18 |
+
|
| 19 |
+
Reference implementation cross-checked against:
|
| 20 |
+
https://github.com/kuleshov-group/mdlm (diffusion.py::DiffusionModel._loss)
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
from typing import Literal
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Clamping weight keeps gradients finite while still up-weighting high-noise
|
| 32 |
+
# positions. Historical value 1/eps=1000 blew up HYDRA training on a 12h v2
|
| 33 |
+
# launch (2026-04-22): loss 26 → 42 → NaN in 13 steps under Muon lr=7e-3
|
| 34 |
+
# because per-token CE × 1000 saturated the 100-unit FAIL guard. The MDLM
|
| 35 |
+
# paper reports stable training at Adam lr=1e-4; HYDRA uses Muon at 7e-3
|
| 36 |
+
# (70× larger), so the weight clamp needs to compensate.
|
| 37 |
+
#
|
| 38 |
+
# Tunable via HYDRA_MDLM_MAX_WEIGHT (default 5.0). Set =1.0 to disable
|
| 39 |
+
# weighting entirely (flat masked-LM CE, no RB reweighting — simpler and
|
| 40 |
+
# more stable, sacrifices the theoretical ELBO property).
|
| 41 |
+
import os as _os
|
| 42 |
+
def _read_max_weight(raw: str) -> float:
    """Parse and validate the HYDRA_MDLM_MAX_WEIGHT setting.

    Args:
        raw: Textual value taken from the environment (or the default).

    Returns:
        The weight cap as a float strictly greater than zero.

    Raises:
        ValueError: If ``raw`` is not a number, is NaN, or is <= 0. A zero cap
            would divide by zero when ``_MIN_ALPHA`` is derived, and a negative
            cap would make ``_MIN_ALPHA`` negative, so that
            ``clamp(alpha, min=_MIN_ALPHA)`` no longer bounds ``1/alpha``.
    """
    try:
        value = float(raw)
    except ValueError as err:
        raise ValueError(
            f"HYDRA_MDLM_MAX_WEIGHT must be a number, got {raw!r}"
        ) from err
    # `not (value > 0.0)` also rejects NaN, which compares False to everything.
    if not (value > 0.0):
        raise ValueError(f"HYDRA_MDLM_MAX_WEIGHT must be > 0, got {raw!r}")
    return value


_MAX_WEIGHT: float = _read_max_weight(_os.environ.get("HYDRA_MDLM_MAX_WEIGHT", "5.0"))
_MIN_ALPHA: float = 1.0 / _MAX_WEIGHT  # so clamp(alpha, min=_MIN_ALPHA) gives 1/alpha <= _MAX_WEIGHT
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# Public API
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
def mdlm_masked_forward_process(
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """MDLM forward (noising) process: mask tokens and compute RB weights.

    Args:
        targets: (B, T) token ids — the clean sequence x_0. Not modified.
        mask_token_id: The special token id written into masked positions.
        t: (B,) timestep in [0, 1]. If None, one value per batch element is
            drawn with ``torch.rand`` (i.e. from [0, 1)). t=0 leaves the
            sequence clean; t=1 masks every token whose uniform draw is > 0.
        alpha_schedule: Noise schedule name. "loglinear" and "linear" both
            map to alpha_t = 1 - t here; any other value is rejected.

    Returns:
        x_t           : (B, T), same dtype as ``targets`` — masked positions
                        hold mask_token_id, the rest equal ``targets``.
        mask_positions: (B, T) bool — True where the token was masked.
        loss_weights  : (B, T) float32 — 1/alpha_t on masked positions, with
                        alpha_t clamped from below by _MIN_ALPHA (so the
                        weight is at most _MAX_WEIGHT); 0.0 elsewhere.

    Raises:
        ValueError: If ``t`` has the wrong shape, contains a value outside
            [0, 1] (NaN included), or ``alpha_schedule`` is unknown.
    """
    B, T = targets.shape
    device = targets.device
    dtype = torch.float32

    # --- sample or validate t ---
    if t is None:
        # torch.rand samples the half-open interval [0, 1): t == 0 can occur
        # (it simply masks nothing), t == 1 cannot.
        t = torch.rand(B, device=device, dtype=dtype)
    else:
        t = t.to(device=device, dtype=dtype)
        if t.shape != (B,):
            raise ValueError(f"t must be shape (B,)={(B,)}, got {t.shape}")
        # Written as a positive range test so NaN fails it too; a NaN t would
        # otherwise slip through and turn every loss weight into NaN.
        if not ((t >= 0) & (t <= 1)).all():
            raise ValueError("t must be in [0, 1]")

    # --- noise schedule: alpha_t = probability that a token is NOT masked ---
    if alpha_schedule not in ("linear", "loglinear"):
        raise ValueError(f"Unknown alpha_schedule: {alpha_schedule!r}. Use 'linear' or 'loglinear'.")
    alpha_t = 1.0 - t  # (B,) float32, in [0, 1]

    # --- per-token Bernoulli mask ---
    # P(rand > alpha_t) = 1 - alpha_t, drawn independently per token.
    rand = torch.rand(B, T, device=device, dtype=dtype)
    mask_positions = rand > alpha_t[:, None]  # (B, T) bool, True = masked

    # --- build x_t ---
    # Out-of-place fill: returns a new tensor and leaves `targets` untouched.
    x_t = targets.masked_fill(mask_positions, mask_token_id)

    # --- RB loss weights: 1/alpha_t on masked positions, 0 elsewhere ---
    # Clamp alpha_t so weights stay finite near t -> 1.
    weight_per_sample = 1.0 / alpha_t.clamp(min=_MIN_ALPHA)  # (B,)
    loss_weights = weight_per_sample[:, None] * mask_positions.to(dtype)  # (B, T)

    return x_t, mask_positions, loss_weights
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def mdlm_rb_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_positions: torch.Tensor,
    loss_weights: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Rao-Blackwellized negative ELBO.

    Cross-entropy is evaluated only on *scored* positions — those that are
    masked AND whose target is not ``ignore_index`` — weighted per token by
    ``loss_weights``, averaged per sample over the scored positions, then
    averaged over the batch:

        L_RB = mean_B [ sum_T (weight_i * CE(logits_i, target_i) * scored_i)
                        / max(sum_T(scored_i), 1) ]

    Ignored positions are left out of the denominator as well as the
    numerator; counting them would shrink the loss of padded sequences in
    proportion to how much padding happened to be masked.

    Args:
        logits        : (B, T, V) raw logits. May be bf16/fp16; cast to
                        float32 before the CE computation.
        targets       : (B, T) int64 true token ids (x_0).
        mask_positions: (B, T) bool — True = masked position.
        loss_weights  : (B, T) float — per-token weight (1/alpha_t on masked
                        positions, 0 elsewhere when produced by
                        mdlm_masked_forward_process).
        ignore_index  : Label value excluded from the loss.

    Returns:
        Scalar float32 loss. A sample with no scored position contributes 0.
    """
    B, T, V = logits.shape

    # Explicit float32 so the result does not depend on the autocast dtype.
    logits_f = logits.float()  # (B, T, V)

    # Positions that actually enter the loss.
    scored = mask_positions & (targets != ignore_index)  # (B, T) bool

    # Everything else is relabelled ignore_index so CE returns 0 there.
    targets_masked = targets.masked_fill(~scored, ignore_index)

    per_tok_ce = F.cross_entropy(
        logits_f.reshape(B * T, V),
        targets_masked.reshape(B * T),
        ignore_index=ignore_index,
        reduction="none",
    ).reshape(B, T)  # (B, T) float32

    weighted = per_tok_ce * loss_weights  # (B, T)

    # Per-sample mean over scored positions; clamp avoids 0/0 for samples
    # with nothing to score (their numerator is already 0).
    scored_count = scored.sum(dim=1).clamp(min=1).to(weighted.dtype)  # (B,)
    per_sample_loss = weighted.sum(dim=1) / scored_count  # (B,)

    return per_sample_loss.mean()  # scalar float32
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def mdlm_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
    ignore_index: int = -100,
) -> torch.Tensor:
    """One-call MDLM loss: run the forward process on ``targets``, then score ``logits``.

    The mask and the RB weights are drawn inside this call by
    mdlm_masked_forward_process; the noised sequence it also returns is
    discarded because ``logits`` are supplied by the caller.

    Args:
        logits        : (B, T, V) raw logits.
        targets       : (B, T) int64 clean token ids.
        mask_token_id : The MASK token id used to corrupt the input.
        t             : Optional (B,) timestep; sampled when None.
        alpha_schedule: "loglinear" (default) or "linear".
        ignore_index  : Label value to ignore in the loss (e.g. padding).

    Returns:
        Scalar float32 MDLM RB-ELBO loss.

    Note on sampled-softmax / partial logits:
        mdlm_rb_loss expects full-vocab logits. If the model only produces
        logits for a subset of the vocabulary (e.g. HYDRA's sampled-softmax
        head), call mdlm_masked_forward_process and mdlm_rb_loss separately.
    """
    forward_out = mdlm_masked_forward_process(
        targets=targets,
        mask_token_id=mask_token_id,
        t=t,
        alpha_schedule=alpha_schedule,
    )
    # forward_out[0] is the noised sequence x_t. It is the model *input* in a
    # training loop and has no role here, where logits already exist.
    masked = forward_out[1]
    weights = forward_out[2]
    return mdlm_rb_loss(
        logits=logits,
        targets=targets,
        mask_positions=masked,
        loss_weights=weights,
        ignore_index=ignore_index,
    )
|
overlay/hydra/engram.py
CHANGED
|
@@ -1,175 +1,160 @@
|
|
| 1 |
-
"""GPU Engram — Top-k Sparse Hopfield retrieval
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
#
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
self,
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
) ->
|
| 87 |
-
|
| 88 |
-
self.
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
if self.training and self.hebbian_boost:
|
| 162 |
-
with torch.no_grad():
|
| 163 |
-
indices = self._hash(token_ids)
|
| 164 |
-
flat_idx = indices.reshape(-1) # (B*T,)
|
| 165 |
-
flat_x = x.detach().reshape(-1, D) # (B*T, d_model)
|
| 166 |
-
mem_dtype = self.memory.data.dtype
|
| 167 |
-
updates = (
|
| 168 |
-
self.hebbian_lr * flat_x
|
| 169 |
-
- self.hebbian_lr * self.memory.data[flat_idx]
|
| 170 |
-
).to(mem_dtype)
|
| 171 |
-
self.memory.data.index_add_(0, flat_idx, updates)
|
| 172 |
-
|
| 173 |
-
# ---- 7. Residual + hit_rate -------------------------------------
|
| 174 |
-
hit_rate = (alpha.detach() > 0.1).float().mean()
|
| 175 |
-
return x + alpha * retrieved, hit_rate
|
|
|
|
| 1 |
+
"""GPU Engram — Top-k Sparse Hopfield retrieval with optional Cantor/SDR nerve constraint."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
_ENGRAM_TOPK = int(os.environ.get("HYDRA_ENGRAM_TOPK", "64"))
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GPUEngram(nn.Module):
    """GPU Engram: top-k sparse Hopfield retrieval over a learned memory table.

    Routing is chosen by the ``HYDRA_ENGRAM_ROUTING`` environment variable,
    read once in ``__init__`` (default ``"auto"``):

    * ``"cantor_sdr"`` — when ``forward`` receives all three routing inputs
      (``sdr_active_indices``, ``cantor_leaf_ids``, ``cantor_n_leaves``), only
      the columns addressed by the Cantor leaf shard and the SDR active
      offsets are scored; otherwise the flat path is used.
    * ``"auto"`` — same as ``cantor_sdr``, but only while the number of active
      indices per token is smaller than ``n_columns``.
    * any other value (e.g. ``"flat"``) — score against the full memory.
    """

    def __init__(
        self,
        d_model: int,
        n_columns: int = 1024,
        max_ngram: int = 3,
        hebbian_boost: bool = False,
    ) -> None:
        super().__init__()
        self.n_columns = n_columns
        # Stored only; `_hash` always mixes up to three tokens.
        self.max_ngram = max_ngram
        self.hebbian_boost = hebbian_boost
        self.memory = nn.Parameter(torch.randn(n_columns, d_model) * 0.01)
        self.gate = nn.Linear(d_model, 1, bias=True)
        nn.init.constant_(self.gate.bias, 0.0)
        # Never ask topk for more entries than there are columns.
        self.topk_k = min(_ENGRAM_TOPK, n_columns)
        self.primes = [2654435761, 2246822519, 3266489917]
        self.hebbian_lr = 0.01
        self.routing_mode = os.environ.get("HYDRA_ENGRAM_ROUTING", "auto").lower()

    def _hash(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Hash each token with up to two predecessors into a column index.

        Args:
            token_ids: (B, T) integer token ids.

        Returns:
            (B, T) tensor of indices in ``[0, n_columns)``.
        """
        T = token_ids.shape[1]
        h = token_ids * self.primes[0]
        for shift in (1, 2):
            if T <= shift:
                break
            shifted = torch.roll(token_ids, shift, dims=1)
            # roll wraps around; zero the slots that would otherwise see
            # tokens from the end of the sequence.
            shifted[:, :shift] = 0
            h = h ^ (shifted * self.primes[shift])
        return h % self.n_columns

    def _validate_active_indices(self, sdr_active_indices: torch.Tensor, x: torch.Tensor) -> None:
        """Reject inputs that look like a dense SDR mask rather than indices.

        ``x`` is accepted for signature compatibility and is not inspected.

        Raises:
            ValueError: For float or bool tensors, for ranks other than 2 or
                3, or when the last dimension exceeds 1024 or ``n_columns``.
        """
        dense_msg = "Engram Cantor/SDR routing expects compact active indices, not a dense SDR mask"
        if torch.is_floating_point(sdr_active_indices) or sdr_active_indices.dtype == torch.bool:
            raise ValueError(dense_msg)
        if sdr_active_indices.dim() not in (2, 3):
            raise ValueError("compact active indices must have shape (B,T,K) or (B*T,K)")
        # Dense SDR masks arrive with K ~= n_bits; compact buffers are small.
        # Refuse obviously dense masks so forced cantor_sdr cannot silently
        # route 0/1 values as offsets.
        width = sdr_active_indices.shape[-1]
        if width > 1024 or width > self.n_columns:
            raise ValueError(dense_msg)

    def _cantor_sdr_candidates(
        self,
        sdr_active_indices: torch.Tensor,
        cantor_leaf_ids: torch.Tensor,
        n_leaves: int,
    ) -> torch.Tensor:
        """Map SDR active offsets into each Cantor leaf's Engram column shard.

        Args:
            sdr_active_indices: (B, T, K) or (B*T, K) integer active indices.
            cantor_leaf_ids: (B, T) leaf id per token.
            n_leaves: Number of leaves the columns are sharded across.

        Returns:
            (B, T, K) int64 column indices in ``[0, n_columns)``.

        Raises:
            ValueError: If the indices fail ``_validate_active_indices`` or a
                flattened buffer does not have exactly B*T rows.
        """
        self._validate_active_indices(sdr_active_indices, cantor_leaf_ids)
        if sdr_active_indices.dim() == 2:
            B, T = cantor_leaf_ids.shape
            rows, width = sdr_active_indices.shape
            # Without this check a wrong-sized buffer whose element count
            # happens to divide by B*T would be silently reinterpreted.
            if rows != B * T:
                raise ValueError(
                    f"flattened active indices have {rows} rows, expected B*T={B * T}"
                )
            sdr_active_indices = sdr_active_indices.reshape(B, T, width)
        sdr = sdr_active_indices.to(device=cantor_leaf_ids.device, dtype=torch.long)
        leaves = cantor_leaf_ids.to(dtype=torch.long).clamp(min=0, max=max(0, n_leaves - 1))
        cols_per_leaf = max(1, self.n_columns // max(1, n_leaves))
        offsets = sdr.remainder(cols_per_leaf)
        base = leaves.unsqueeze(-1) * cols_per_leaf
        # Clamp covers n_leaves > n_columns, where base alone can overshoot.
        return (base + offsets).clamp(max=self.n_columns - 1)

    def _flat_retrieve(self, x: torch.Tensor) -> torch.Tensor:
        """Softmax-weighted sum of the top-k memory rows scored against all columns."""
        scores = x @ self.memory.T
        topk_vals, topk_idx = scores.topk(self.topk_k, dim=-1)
        topk_w = torch.softmax(topk_vals, dim=-1)
        selected_mem = self.memory[topk_idx]
        return torch.einsum('btk,btkd->btd', topk_w, selected_mem)

    def _cantor_sdr_retrieve(
        self,
        x: torch.Tensor,
        sdr_active_indices: torch.Tensor,
        cantor_leaf_ids: torch.Tensor,
        cantor_n_leaves: int,
    ) -> torch.Tensor:
        """Same retrieval as ``_flat_retrieve`` but over the Cantor/SDR candidates only."""
        candidates = self._cantor_sdr_candidates(
            sdr_active_indices,
            cantor_leaf_ids,
            n_leaves=cantor_n_leaves,
        )
        cand_mem = self.memory[candidates]
        scores = torch.einsum('btd,btkd->btk', x, cand_mem)
        # There may be fewer candidates than the configured top-k.
        k = min(self.topk_k, scores.shape[-1])
        topk_vals, local_idx = scores.topk(k, dim=-1)
        topk_w = torch.softmax(topk_vals, dim=-1)
        global_idx = candidates.gather(-1, local_idx)
        selected_mem = self.memory[global_idx]
        return torch.einsum('btk,btkd->btd', topk_w, selected_mem)

    def forward(
        self,
        x: torch.Tensor,
        token_ids: torch.Tensor,
        sdr_active_indices: torch.Tensor | None = None,
        cantor_leaf_ids: torch.Tensor | None = None,
        cantor_n_leaves: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Add a gated memory read-out to ``x``.

        Args:
            x: (B, T, D) hidden states.
            token_ids: (B, T) token ids; used only for the Hebbian update.
            sdr_active_indices: Optional compact active indices for routing.
            cantor_leaf_ids: Optional (B, T) Cantor leaf id per token.
            cantor_n_leaves: Optional number of Cantor leaves.

        Returns:
            ``(x + alpha * retrieved, hit_rate)`` where ``alpha`` is the
            sigmoid gate and ``hit_rate`` is the fraction of gates above 0.1.
        """
        B, T, D = x.shape
        mode = self.routing_mode
        have_routing = (
            sdr_active_indices is not None
            and cantor_leaf_ids is not None
            and cantor_n_leaves is not None
        )
        use_cantor = have_routing and mode in {"cantor_sdr", "auto"}
        if use_cantor and mode == "auto":
            # Compare candidate count with the full-memory scan. An earlier
            # `(k_active * D) < n_columns` test mixed candidate count with
            # feature dimension and fell back to flat retrieval far too often.
            use_cantor = sdr_active_indices.shape[-1] < self.n_columns

        if use_cantor:
            retrieved = self._cantor_sdr_retrieve(x, sdr_active_indices, cantor_leaf_ids, cantor_n_leaves)
        else:
            retrieved = self._flat_retrieve(x)

        alpha = torch.sigmoid(self.gate(x))

        if self.training and self.hebbian_boost:
            # NOTE(review): this writes through `.data` after `self.memory` was
            # already read for retrieval in this forward, so autograd does not
            # see the write — confirm the effect on gradients is acceptable.
            with torch.no_grad():
                indices = self._hash(token_ids)
                flat_idx = indices.reshape(-1)
                flat_x = x.detach().reshape(-1, D)
                mem_dtype = self.memory.data.dtype
                # Move each hashed row a step of hebbian_lr towards its input.
                updates = (
                    self.hebbian_lr * flat_x
                    - self.hebbian_lr * self.memory.data[flat_idx]
                ).to(mem_dtype)
                self.memory.data.index_add_(0, flat_idx, updates)

        hit_rate = (alpha.detach() > 0.1).float().mean()
        return x + alpha * retrieved, hit_rate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
overlay/hydra/eval.py
CHANGED
|
@@ -1,217 +1,210 @@
|
|
| 1 |
-
"""Evaluation: factual probes + sampled factual English scoring.
|
| 2 |
-
|
| 3 |
-
Extracted from train.py (W1 modularization). Semantics unchanged.
|
| 4 |
-
|
| 5 |
-
Perf optimizations (eval_perf_fix):
|
| 6 |
-
- Probe mode: single forward per prompt instead of autoregressive gen
|
| 7 |
-
- Batch decode: all GPU work first, all CPU decode after
|
| 8 |
-
- Batched factual probes: single padded forward instead of N sequential
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
from __future__ import annotations
|
| 12 |
-
|
| 13 |
-
import os
|
| 14 |
-
import re as _re
|
| 15 |
-
|
| 16 |
-
import torch
|
| 17 |
-
|
| 18 |
-
from hydra.config import FACTUAL_SAMPLES, FACTUAL_BATCH, FACTUAL_GEN_TOKENS
|
| 19 |
-
|
| 20 |
-
# Default to probe mode (1 forward per prompt); set HYDRA_FACTUAL_MODE=gen for
|
| 21 |
-
# the original autoregressive generation path.
|
| 22 |
-
FACTUAL_MODE = os.environ.get("HYDRA_FACTUAL_MODE", "probe")
|
| 23 |
-
|
| 24 |
-
FACTUAL_EVAL = [
|
| 25 |
-
# Hard factual recall — requires specific knowledge memorization
|
| 26 |
-
("The capital of France is", ["Paris", "paris"]),
|
| 27 |
-
("Water boils at", ["100", "boiling"]),
|
| 28 |
-
("The largest planet in our solar system is", ["Jupiter", "jupiter"]),
|
| 29 |
-
# Easier completions — common collocations / patterns the model may pick up
|
| 30 |
-
("Once upon a", ["time"]),
|
| 31 |
-
("Hello, my name", ["is", "'s"]),
|
| 32 |
-
("The cat sat on the", ["mat", "floor", "rug", "table", "couch", "chair", "ground"]),
|
| 33 |
-
("She opened the door and", ["walked", "saw", "found", "stepped", "looked", "went", "ran"]),
|
| 34 |
-
# Original hard ones kept for completeness
|
| 35 |
-
("The speed of light is approximately", ["299", "300", "186,000", "light speed"]),
|
| 36 |
-
("Two plus two equals", ["4", "four"]),
|
| 37 |
-
]
|
| 38 |
-
|
| 39 |
-
_FACTUAL_PROBES = [
|
| 40 |
-
"The capital of France is",
|
| 41 |
-
"Water boils at",
|
| 42 |
-
"The largest planet in our solar system is",
|
| 43 |
-
"The speed of light is approximately",
|
| 44 |
-
"Shakespeare wrote",
|
| 45 |
-
]
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def run_factual_probes(model, tokenizer, device, autocast_ctx) -> None:
|
| 49 |
-
"""Top-5 next-token predictions for canonical factual prompts.
|
| 50 |
-
|
| 51 |
-
Batched: pads all prompts into a single forward pass instead of N
|
| 52 |
-
sequential passes.
|
| 53 |
-
"""
|
| 54 |
-
print("\n--- Factual Probes ---")
|
| 55 |
-
model.eval()
|
| 56 |
-
|
| 57 |
-
# Process probes one at a time to avoid cooperative launch limit
|
| 58 |
-
# (batched forward with B=len(probes) can exceed SM residency cap).
|
| 59 |
-
for prompt_text in _FACTUAL_PROBES:
|
| 60 |
-
ids = tokenizer.encode(prompt_text)
|
| 61 |
-
x = torch.tensor([ids], device=device)
|
| 62 |
-
with torch.no_grad(), autocast_ctx:
|
| 63 |
-
logits = model(x)
|
| 64 |
-
probs = torch.softmax(logits[0, -1].float(), dim=-1)
|
| 65 |
-
top5 = torch.topk(probs, 5)
|
| 66 |
-
completions = [tokenizer.decode([idx.item()]) for idx in top5.indices]
|
| 67 |
-
probs_list = [f"{p:.4f}" for p in top5.values[:3].tolist()]
|
| 68 |
-
print(f' "{prompt_text}" -> {completions[:3]} (p={probs_list})')
|
| 69 |
-
print("--- End Factual Probes ---\n")
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
# ---------------------------------------------------------------------------
|
| 73 |
-
# Probe mode: single forward per prompt (Fix D)
|
| 74 |
-
# ---------------------------------------------------------------------------
|
| 75 |
-
|
| 76 |
-
def _run_factual_english_probe(model, tokenizer, max_seq_len: int):
|
| 77 |
-
"""Fast probe mode: for each (prompt, answers), encode prompt + each answer
|
| 78 |
-
candidate as a single sequence, do ONE forward pass, and check if the model's
|
| 79 |
-
argmax at the last prompt token matches the first answer token.
|
| 80 |
-
|
| 81 |
-
Falls back to checking top-K predictions to be generous (same as gen mode
|
| 82 |
-
which samples multiple temperatures).
|
| 83 |
-
"""
|
| 84 |
-
print("---")
|
| 85 |
-
print("factual_english_samples: (probe mode)")
|
| 86 |
-
model.eval()
|
| 87 |
-
hits = 0
|
| 88 |
-
|
| 89 |
-
with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 90 |
-
for prompt, answers in FACTUAL_EVAL:
|
| 91 |
-
prompt_ids = tokenizer.encode(prompt)
|
| 92 |
-
prompt_len = len(prompt_ids)
|
| 93 |
-
x = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
|
| 94 |
-
logits = model(x, targets=None)
|
| 95 |
-
# logits shape: [1, seq_len, vocab] or [1, vocab]
|
| 96 |
-
if logits.dim() == 3:
|
| 97 |
-
last_logits = logits[0, -1, :]
|
| 98 |
-
else:
|
| 99 |
-
last_logits = logits[0]
|
| 100 |
-
|
| 101 |
-
probs = torch.softmax(last_logits.float(), dim=-1)
|
| 102 |
-
# Check top-K predictions (generous: K=20 to match multi-sample gen)
|
| 103 |
-
top_k = min(20, probs.shape[-1])
|
| 104 |
-
top_ids = torch.topk(probs, top_k).indices.tolist()
|
| 105 |
-
top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]
|
| 106 |
-
|
| 107 |
-
answers_lower = [a.lower() for a in answers]
|
| 108 |
-
any_hit = any(
|
| 109 |
-
any(a in tok for a in answers_lower)
|
| 110 |
-
for tok in top_tokens
|
| 111 |
-
)
|
| 112 |
-
if any_hit:
|
| 113 |
-
hits += 1
|
| 114 |
-
|
| 115 |
-
best_completion = tokenizer.decode([top_ids[0]])
|
| 116 |
-
print(f" prompt: {prompt!r}")
|
| 117 |
-
print(f" output: {(prompt + best_completion).replace(chr(10), ' ')!r}")
|
| 118 |
-
print(f" hit: {any_hit} (probe top-{top_k})")
|
| 119 |
-
|
| 120 |
-
score = hits / len(FACTUAL_EVAL)
|
| 121 |
-
print("---")
|
| 122 |
-
print(f"factual_english_score: {score:.4f}")
|
| 123 |
-
print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
|
| 124 |
-
return score, hits, len(FACTUAL_EVAL)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
# ---------------------------------------------------------------------------
|
| 128 |
-
# Gen mode: original autoregressive path (Fix F: batch decode)
|
| 129 |
-
# ---------------------------------------------------------------------------
|
| 130 |
-
|
| 131 |
-
def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
|
| 132 |
-
"""Original autoregressive generation path with batch decode optimization:
|
| 133 |
-
all GPU work runs first, then all CPU decoding happens after."""
|
| 134 |
-
print("---")
|
| 135 |
-
print("factual_english_samples: (gen mode)")
|
| 136 |
-
model.eval()
|
| 137 |
-
|
| 138 |
-
num_samples = FACTUAL_SAMPLES
|
| 139 |
-
batch = FACTUAL_BATCH
|
| 140 |
-
gen_tokens = FACTUAL_GEN_TOKENS
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
"""Dispatch to probe (fast, default) or gen (original) mode.
|
| 212 |
-
|
| 213 |
-
Set HYDRA_FACTUAL_MODE=gen to use the autoregressive path.
|
| 214 |
-
"""
|
| 215 |
-
if FACTUAL_MODE == "gen":
|
| 216 |
-
return _run_factual_english_gen(model, tokenizer, max_seq_len)
|
| 217 |
-
return _run_factual_english_probe(model, tokenizer, max_seq_len)
|
|
|
|
| 1 |
+
"""Evaluation: factual probes + sampled factual English scoring.
|
| 2 |
+
|
| 3 |
+
Extracted from train.py (W1 modularization). Semantics unchanged.
|
| 4 |
+
|
| 5 |
+
Perf optimizations (eval_perf_fix):
|
| 6 |
+
- Probe mode: single forward per prompt instead of autoregressive gen
|
| 7 |
+
- Batch decode: all GPU work first, all CPU decode after
|
| 8 |
+
- Factual probes: one forward per prompt (a batched forward can exceed the cooperative launch limit)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import re as _re
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from hydra.config import FACTUAL_SAMPLES, FACTUAL_BATCH, FACTUAL_GEN_TOKENS
|
| 19 |
+
|
| 20 |
+
# Scoring mode used by run_factual_english; read from the environment once,
# when this module is imported.
# Default to probe mode (1 forward per prompt); set HYDRA_FACTUAL_MODE=gen for
# the original autoregressive generation path.
FACTUAL_MODE = os.environ.get("HYDRA_FACTUAL_MODE", "probe")

# (prompt, accepted answers) pairs scored by run_factual_english.
# Both scorers lower-case the answers before comparing, so the case variants
# listed below are equivalent.
FACTUAL_EVAL = [
    # Hard factual recall — requires specific knowledge memorization
    ("The capital of France is", ["Paris", "paris"]),
    ("Water boils at", ["100", "boiling"]),
    ("The largest planet in our solar system is", ["Jupiter", "jupiter"]),
    # Easier completions — common collocations / patterns the model may pick up
    ("Once upon a", ["time"]),
    ("Hello, my name", ["is", "'s"]),
    ("The cat sat on the", ["mat", "floor", "rug", "table", "couch", "chair", "ground"]),
    ("She opened the door and", ["walked", "saw", "found", "stepped", "looked", "went", "ran"]),
    # Original hard ones kept for completeness
    ("The speed of light is approximately", ["299", "300", "186,000", "light speed"]),
    ("Two plus two equals", ["4", "four"]),
]

# Prompts whose top next-token predictions are printed by run_factual_probes
# (diagnostic output only; nothing is scored against them).
_FACTUAL_PROBES = [
    "The capital of France is",
    "Water boils at",
    "The largest planet in our solar system is",
    "The speed of light is approximately",
    "Shakespeare wrote",
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def run_factual_probes(model, tokenizer, device, autocast_ctx) -> None:
    """Print the top next-token predictions for each prompt in _FACTUAL_PROBES.

    Prompts are run one at a time (batch size 1); see the note in the body
    for why they are not batched.  The five most likely next tokens are
    computed per prompt (fewer if the vocabulary is smaller) and the best
    three are printed together with their probabilities.

    Args:
        model: Called as ``model(x)``; expected to return next-token logits
            laid out as ``[1, seq_len, vocab]`` or ``[1, vocab]``.
        tokenizer: Provides ``encode(str)`` and ``decode(list_of_ids)``.
        device: Device on which each prompt tensor is created.
        autocast_ctx: Context manager entered around every forward pass; it
            is entered once per prompt, so it must be re-enterable.
    """
    print("\n--- Factual Probes ---")
    model.eval()

    # Process probes one at a time to avoid cooperative launch limit
    # (batched forward with B=len(probes) can exceed SM residency cap).
    for prompt_text in _FACTUAL_PROBES:
        ids = tokenizer.encode(prompt_text)
        x = torch.tensor([ids], device=device, dtype=torch.long)
        with torch.no_grad(), autocast_ctx:
            logits = model(x)

        # Accept both logits layouts, same as the factual-English scorers.
        last_logits = logits[0, -1] if logits.dim() == 3 else logits[0]
        probs = torch.softmax(last_logits.float(), dim=-1)
        top = torch.topk(probs, min(5, probs.shape[-1]))

        # One host transfer per tensor instead of one .item() sync per token.
        top_ids = top.indices.tolist()
        top_probs = top.values.tolist()

        # Only the three best candidates are shown.
        completions = [tokenizer.decode([tid]) for tid in top_ids[:3]]
        probs_list = [f"{p:.4f}" for p in top_probs[:3]]
        print(f' "{prompt_text}" -> {completions} (p={probs_list})')
    print("--- End Factual Probes ---\n")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
# Probe mode: single forward per prompt (Fix D)
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
|
| 76 |
+
def _run_factual_english_probe(model, tokenizer, max_seq_len: int):
    """Score FACTUAL_EVAL with a single forward pass per prompt (no generation).

    For every (prompt, answers) pair only the prompt is encoded and run
    through the model once.  The 20 most likely next tokens (fewer when the
    vocabulary is smaller) are decoded, stripped and lower-cased, and the
    prompt counts as a hit when any lower-cased answer occurs as a substring
    of any of those decoded tokens.

    Args:
        model: Called as ``model(x, targets=None)``; expected to return
            logits laid out as ``[1, seq_len, vocab]`` or ``[1, vocab]``.
        tokenizer: Provides ``encode(str)`` and ``decode(list_of_ids)``.
        max_seq_len: Not used in probe mode; accepted so both scoring modes
            share one signature.

    Returns:
        Tuple ``(score, hits, total)`` where ``score == hits / total`` and
        ``total == len(FACTUAL_EVAL)``.
    """
    print("---")
    print("factual_english_samples: (probe mode)")
    model.eval()
    hits = 0

    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        for prompt, answers in FACTUAL_EVAL:
            prompt_ids = tokenizer.encode(prompt)
            x = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
            logits = model(x, targets=None)
            # logits shape: [1, seq_len, vocab] or [1, vocab]
            if logits.dim() == 3:
                last_logits = logits[0, -1, :]
            else:
                last_logits = logits[0]

            probs = torch.softmax(last_logits.float(), dim=-1)
            # Check top-K predictions (generous: K=20 to match multi-sample gen)
            top_k = min(20, probs.shape[-1])
            top_ids = torch.topk(probs, top_k).indices.tolist()
            top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]

            answers_lower = [a.lower() for a in answers]
            # NOTE(review): this is a substring test on single decoded tokens,
            # so a short answer such as "is" also matches a token like "this",
            # and an answer the tokenizer splits into several tokens
            # presumably cannot match at all — confirm this is intended.
            any_hit = any(
                any(a in tok for a in answers_lower)
                for tok in top_tokens
            )
            if any_hit:
                hits += 1

            # The sample line shows only the single most likely token.
            best_completion = tokenizer.decode([top_ids[0]])
            print(f" prompt: {prompt!r}")
            print(f" output: {(prompt + best_completion).replace(chr(10), ' ')!r}")
            print(f" hit: {any_hit} (probe top-{top_k})")

    score = hits / len(FACTUAL_EVAL)
    print("---")
    print(f"factual_english_score: {score:.4f}")
    print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
    return score, hits, len(FACTUAL_EVAL)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ---------------------------------------------------------------------------
|
| 128 |
+
# Gen mode: original autoregressive path (Fix F: batch decode)
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
# Word pattern used to split a generated continuation into candidate words.
_GEN_WORD_RE = _re.compile(r"\b[\w'-]+\b")


def _gen_answer_hit(continuation: str, answers_lower) -> bool:
    """Return True when any accepted answer appears in *continuation*.

    An answer hits when it equals one of the words extracted by
    ``_GEN_WORD_RE`` (compared lower-cased).  Answers that the word pattern
    can never yield -- phrases such as "light speed", or strings such as
    "186,000" and "'s" -- are instead searched for as a phrase that is not
    embedded in a longer word; without this they could never score.
    """
    words = {w.lower() for w in _GEN_WORD_RE.findall(continuation)}
    lowered = continuation.lower()
    for answer in answers_lower:
        if answer in words:
            return True
        if answer and _GEN_WORD_RE.fullmatch(answer) is None:
            phrase = r"(?<!\w)" + _re.escape(answer) + r"(?!\w)"
            if _re.search(phrase, lowered):
                return True
    return False


def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
    """Score FACTUAL_EVAL by sampling autoregressive continuations.

    For each prompt, FACTUAL_SAMPLES continuations are sampled in batches of
    at most FACTUAL_BATCH rows; the sampling temperature cycles through
    ``temps`` once per batch.  Each batch generates up to FACTUAL_GEN_TOKENS
    tokens and stops early once the context reaches ``max_seq_len``.  All
    token rows are collected first and decoded afterwards; a prompt is a hit
    when any sample's continuation contains an accepted answer.

    Args:
        model: Called as ``model(ctx, targets=None)``; expected to return
            logits laid out as ``[B, seq_len, vocab]`` or ``[B, vocab]``.
        tokenizer: Provides ``encode(str)`` and ``decode(list_of_ids)``.
        max_seq_len: Context length at which generation stops.

    Returns:
        Tuple ``(score, hits, total)`` where ``score == hits / total`` and
        ``total == len(FACTUAL_EVAL)``.
    """
    print("---")
    print("factual_english_samples: (gen mode)")
    model.eval()

    num_samples = FACTUAL_SAMPLES
    # A non-positive batch size would never advance samples_done below.
    batch = max(1, FACTUAL_BATCH)
    gen_tokens = FACTUAL_GEN_TOKENS
    temps = [0.7, 0.9, 1.1]
    hits = 0

    with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        for prompt, answers in FACTUAL_EVAL:
            ids = tokenizer.encode(prompt)
            answers_lower = [a.lower() for a in answers]
            # Collect all generated token sequences on GPU first
            all_rows: list[list[int]] = []
            samples_done = 0
            batch_idx = 0
            while samples_done < num_samples:
                b = min(batch, num_samples - samples_done)
                temp = temps[batch_idx % len(temps)]
                batch_idx += 1
                ctx = torch.tensor([ids] * b, device="cuda", dtype=torch.long)
                for _ in range(gen_tokens):
                    logits = model(ctx, targets=None)
                    next_logits = logits[:, -1, :] if logits.dim() == 3 else logits
                    probs = torch.softmax(next_logits.float() / temp, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)
                    ctx = torch.cat([ctx, next_id], dim=1)
                    if ctx.size(1) >= max_seq_len:
                        break
                # Transfer to CPU in one shot, no per-row sync
                all_rows.extend(ctx.cpu().tolist())
                samples_done += b

            # CPU-side batch decode — no GPU sync between decodes
            any_hit = False
            first_gen = None
            hit_gen = None
            for row in all_rows:
                generated = tokenizer.decode(row)
                if first_gen is None:
                    first_gen = generated
                # NOTE(review): slicing by len(prompt) assumes the decoded text
                # starts with the prompt verbatim — confirm for this tokenizer.
                continuation = generated[len(prompt):].strip()
                if _gen_answer_hit(continuation, answers_lower):
                    any_hit = True
                    hit_gen = generated
                    # Later rows cannot change the outcome or the printed samples.
                    break
            if any_hit:
                hits += 1
            print(f" prompt: {prompt!r}")
            print(f" output: {(first_gen or '').replace(chr(10), ' ')!r}")
            print(f" hit: {any_hit} (any of {num_samples} samples, temps={temps}, gen={gen_tokens}tok)")
            if hit_gen is not None and hit_gen != first_gen:
                print(f" hit_sample: {hit_gen.replace(chr(10), ' ')!r}")

    score = hits / len(FACTUAL_EVAL)
    print("---")
    print(f"factual_english_score: {score:.4f}")
    print(f"factual_english_hits: {hits}/{len(FACTUAL_EVAL)}")
    return score, hits, len(FACTUAL_EVAL)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ---------------------------------------------------------------------------
|
| 200 |
+
# Public entry point
|
| 201 |
+
# ---------------------------------------------------------------------------
|
| 202 |
+
|
| 203 |
+
def run_factual_english(model, tokenizer, max_seq_len: int):
    """Run the factual-English evaluation in the configured mode.

    The autoregressive scorer is used when FACTUAL_MODE equals "gen"
    (set via HYDRA_FACTUAL_MODE=gen); any other value selects the
    single-forward probe scorer.  The chosen scorer's result is returned
    unchanged.
    """
    scorer = (
        _run_factual_english_gen
        if FACTUAL_MODE == "gen"
        else _run_factual_english_probe
    )
    return scorer(model, tokenizer, max_seq_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
overlay/hydra/gdn_block.py
CHANGED
|
@@ -1,126 +1,126 @@
|
|
| 1 |
-
"""GDNBlock — Gated Delta Net block, drop-in shape-compatible with Mamba3Block and HyenaBlock.
|
| 2 |
-
|
| 3 |
-
GatedDeltaNet (GDN) reference: arXiv:2412.06464 (ICLR 2025, NVLabs).
|
| 4 |
-
Implementation: flash-linear-attention (fla) library, Triton kernels, sm86-compatible.
|
| 5 |
-
|
| 6 |
-
Interface contract (MUST match how Mamba3/Hyena are called in hydra/model.py):
|
| 7 |
-
block = GDNBlock(d_model, ...)
|
| 8 |
-
y = block(x) # x: [B, T, d_model] -> y: [B, T, d_model]
|
| 9 |
-
|
| 10 |
-
The surrounding mHC layer does NOT pre-norm before calling this block (the
|
| 11 |
-
raw hidden state is passed in); the block itself applies no input normalization,
|
| 12 |
-
same as HyenaBlock. We return the raw operator output; the mHC layer adds it
|
| 13 |
-
as a residual stream contribution.
|
| 14 |
-
|
| 15 |
-
NO attention, NO softmax-over-sequence-dim. No state is carried between
|
| 16 |
-
.forward() calls by default (use_cache=False, past_key_values=None).
|
| 17 |
-
"""
|
| 18 |
-
|
| 19 |
-
from __future__ import annotations
|
| 20 |
-
|
| 21 |
-
try:
|
| 22 |
-
from fla.layers.gated_deltanet import GatedDeltaNet as _GatedDeltaNet
|
| 23 |
-
except ImportError as _fla_err:
|
| 24 |
-
raise ImportError(
|
| 25 |
-
"flash-linear-attention (fla) is required for GDNBlock but could not be imported. "
|
| 26 |
-
"Install it with:\n"
|
| 27 |
-
" pip install flash-linear-attention\n"
|
| 28 |
-
"or from source:\n"
|
| 29 |
-
" pip install git+https://github.com/fla-org/flash-linear-attention.git\n"
|
| 30 |
-
f"Original error: {_fla_err}"
|
| 31 |
-
) from _fla_err
|
| 32 |
-
|
| 33 |
-
import torch
|
| 34 |
-
import torch.nn as nn
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class GDNBlock(nn.Module):
    """Adapter exposing fla's GatedDeltaNet as a tensor-in / tensor-out block.

    The wrapped layer is built with gating enabled and no layer index, and is
    always called without a cache, so ``forward`` maps
    ``x: [B, T, d_model]`` to ``y: [B, T, d_model]`` and returns only the
    hidden states from the layer's 3-tuple result.

    Args:
        d_model: Hidden size; must be divisible by ``n_heads``.
        n_heads: Number of heads; the per-head q/k size is
            ``d_model // n_heads``.
        mode: Kernel mode forwarded to GatedDeltaNet (the original author
            notes 'chunk' for training, 'fused_recurrent' for inference).
        expand_v: Value-projection expansion factor forwarded to the layer.
        use_short_conv: Forwarded to the layer unchanged.
        conv_size: Forwarded to the layer unchanged.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int = 6,
        mode: str = "chunk",
        expand_v: float = 2.0,
        use_short_conv: bool = True,
        conv_size: int = 4,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.mode = mode

        # q/k heads must tile d_model exactly: n_heads * head_dim == d_model.
        per_head_dim, leftover = divmod(d_model, n_heads)
        if leftover != 0:
            raise ValueError(
                f"d_model={d_model} must be divisible by n_heads={n_heads} "
                "so that head_dim = d_model // n_heads is an integer."
            )

        layer_settings = {
            "hidden_size": d_model,
            "expand_v": expand_v,
            "head_dim": per_head_dim,
            "num_heads": n_heads,
            "mode": mode,
            "use_gate": True,  # gating is always switched on for this block
            "use_short_conv": use_short_conv,
            "conv_size": conv_size,
            "layer_idx": None,  # no cache layer indexing is configured
        }
        self.gdn = _GatedDeltaNet(**layer_settings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped layer: ``[B, T, d_model] -> [B, T, d_model]``.

        The layer is invoked with ``use_cache=False`` and no past state, so
        each call is independent of earlier ones.
        """
        result = self.gdn(
            hidden_states=x,
            attention_mask=None,
            past_key_values=None,
            use_cache=False,
            output_attentions=False,
        )
        # The layer returns (hidden_states, attn_weights, past_key_values).
        hidden, _attn_weights, _past = result
        return hidden

    def invalidate_caches(self) -> None:
        """Do nothing; this block keeps no cache of its own.

        Present so callers that invalidate block caches can call it on a
        GDNBlock without special-casing.
        """
        return None
|
|
|
|
| 1 |
+
"""GDNBlock — Gated Delta Net block, drop-in shape-compatible with Mamba3Block and HyenaBlock.
|
| 2 |
+
|
| 3 |
+
GatedDeltaNet (GDN) reference: arXiv:2412.06464 (ICLR 2025, NVLabs).
|
| 4 |
+
Implementation: flash-linear-attention (fla) library, Triton kernels, sm86-compatible.
|
| 5 |
+
|
| 6 |
+
Interface contract (MUST match how Mamba3/Hyena are called in hydra/model.py):
|
| 7 |
+
block = GDNBlock(d_model, ...)
|
| 8 |
+
y = block(x) # x: [B, T, d_model] -> y: [B, T, d_model]
|
| 9 |
+
|
| 10 |
+
The surrounding mHC layer does NOT pre-norm before calling this block (the
|
| 11 |
+
raw hidden state is passed in); the block itself applies no input normalization,
|
| 12 |
+
same as HyenaBlock. We return the raw operator output; the mHC layer adds it
|
| 13 |
+
as a residual stream contribution.
|
| 14 |
+
|
| 15 |
+
NO attention, NO softmax-over-sequence-dim. No state is carried between
|
| 16 |
+
.forward() calls by default (use_cache=False, past_key_values=None).
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
from fla.layers.gated_deltanet import GatedDeltaNet as _GatedDeltaNet
|
| 23 |
+
except ImportError as _fla_err:
|
| 24 |
+
raise ImportError(
|
| 25 |
+
"flash-linear-attention (fla) is required for GDNBlock but could not be imported. "
|
| 26 |
+
"Install it with:\n"
|
| 27 |
+
" pip install flash-linear-attention\n"
|
| 28 |
+
"or from source:\n"
|
| 29 |
+
" pip install git+https://github.com/fla-org/flash-linear-attention.git\n"
|
| 30 |
+
f"Original error: {_fla_err}"
|
| 31 |
+
) from _fla_err
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class GDNBlock(nn.Module):
    """Adapter exposing fla's GatedDeltaNet as a tensor-in / tensor-out block.

    The wrapped layer is built with gating enabled and no layer index, and is
    always called without a cache, so ``forward`` maps
    ``x: [B, T, d_model]`` to ``y: [B, T, d_model]`` and returns only the
    hidden states from the layer's 3-tuple result.

    Args:
        d_model: Hidden size; must be divisible by ``n_heads``.
        n_heads: Number of heads; the per-head q/k size is
            ``d_model // n_heads``.
        mode: Kernel mode forwarded to GatedDeltaNet (the original author
            notes 'chunk' for training, 'fused_recurrent' for inference).
        expand_v: Value-projection expansion factor forwarded to the layer.
        use_short_conv: Forwarded to the layer unchanged.
        conv_size: Forwarded to the layer unchanged.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int = 6,
        mode: str = "chunk",
        expand_v: float = 2.0,
        use_short_conv: bool = True,
        conv_size: int = 4,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.mode = mode

        # q/k heads must tile d_model exactly: n_heads * head_dim == d_model.
        per_head_dim, leftover = divmod(d_model, n_heads)
        if leftover != 0:
            raise ValueError(
                f"d_model={d_model} must be divisible by n_heads={n_heads} "
                "so that head_dim = d_model // n_heads is an integer."
            )

        layer_settings = {
            "hidden_size": d_model,
            "expand_v": expand_v,
            "head_dim": per_head_dim,
            "num_heads": n_heads,
            "mode": mode,
            "use_gate": True,  # gating is always switched on for this block
            "use_short_conv": use_short_conv,
            "conv_size": conv_size,
            "layer_idx": None,  # no cache layer indexing is configured
        }
        self.gdn = _GatedDeltaNet(**layer_settings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped layer: ``[B, T, d_model] -> [B, T, d_model]``.

        The layer is invoked with ``use_cache=False`` and no past state, so
        each call is independent of earlier ones.
        """
        result = self.gdn(
            hidden_states=x,
            attention_mask=None,
            past_key_values=None,
            use_cache=False,
            output_attentions=False,
        )
        # The layer returns (hidden_states, attn_weights, past_key_values).
        hidden, _attn_weights, _past = result
        return hidden

    def invalidate_caches(self) -> None:
        """Do nothing; this block keeps no cache of its own.

        Present so callers that invalidate block caches can call it on a
        GDNBlock without special-casing.
        """
        return None
|
overlay/hydra/hyena_block.py
CHANGED
|
@@ -1,68 +1,68 @@
|
|
| 1 |
-
"""HyenaBlock — drop-in block for HYDRA, supplement to Mamba3.
|
| 2 |
-
|
| 3 |
-
Wraps `subsystems.hyena_pure.HyenaOperator` with a pre-norm + residual scheme
|
| 4 |
-
consistent with how the mHC stack wraps Mamba3 in `hydra/model.py`.
|
| 5 |
-
|
| 6 |
-
Interface contract (MUST match how Mamba3 is called in model.py):
|
| 7 |
-
block = HyenaBlock(d_model, seq_len)
|
| 8 |
-
y = block(x) # x: [B, T, d_model] -> y: [B, T, d_model]
|
| 9 |
-
|
| 10 |
-
The surrounding mHC layer does the pre-norm (`norm(h)`) BEFORE calling the
|
| 11 |
-
block, so the block itself should NOT re-normalize at input — same as Mamba3
|
| 12 |
-
in the current model. We return the raw operator output; the mHC layer then
|
| 13 |
-
adds it as a residual stream contribution.
|
| 14 |
-
|
| 15 |
-
NO attention, NO softmax-over-sequence-dim, NO KV-cache. All forbidden
|
| 16 |
-
imports enumerated in tests/test_hyena.py (test #7) are absent.
|
| 17 |
-
"""
|
| 18 |
-
|
| 19 |
-
from __future__ import annotations
|
| 20 |
-
|
| 21 |
-
import os
|
| 22 |
-
|
| 23 |
-
import torch
|
| 24 |
-
import torch.nn as nn
|
| 25 |
-
|
| 26 |
-
from subsystems.hyena_pure import HyenaOperator
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class HyenaBlock(nn.Module):
    """Thin nn.Module wrapper around a single ``HyenaOperator``.

    ``forward`` maps ``x: [B, T, d_model]`` to ``y: [B, T, d_model]`` by
    delegating straight to the operator; nothing else is applied here.

    Args:
        d_model: Model width, forwarded to the operator.
        seq_len: Forwarded to the operator as ``l_max``.
        order: Hyena order; when None it is read from the
            ``HYDRA_HYENA_ORDER`` environment variable (default "2").
        filter_order: Filter width; when None it is read from the
            ``HYDRA_HYENA_FILTER_DIM`` environment variable (default "64").
        dropout: Forwarded to the operator unchanged.
        filter_dropout: Forwarded to the operator unchanged.
        short_filter_order: Forwarded to the operator unchanged.
        activation: Forwarded to the operator unchanged.
    """

    def __init__(
        self,
        d_model: int,
        seq_len: int,
        order: int | None = None,
        filter_order: int | None = None,
        dropout: float = 0.0,
        filter_dropout: float = 0.0,
        short_filter_order: int = 3,
        activation: str = "id",
    ):
        super().__init__()

        # Explicit arguments win; otherwise fall back to the environment
        # overrides and then to the built-in defaults.
        if order is None:
            order = int(os.environ.get("HYDRA_HYENA_ORDER", "2"))
        if filter_order is None:
            filter_order = int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64"))

        self.d_model, self.seq_len = d_model, seq_len
        self.order, self.filter_order = order, filter_order

        operator_settings = {
            "d_model": d_model,
            "l_max": seq_len,
            "order": order,
            "filter_order": filter_order,
            "dropout": dropout,
            "filter_dropout": filter_dropout,
            "short_filter_order": short_filter_order,
            "activation": activation,
        }
        self.operator = HyenaOperator(**operator_settings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the operator output for ``x`` (``[B, T, d_model]`` in and out)."""
        mixed = self.operator(x)
        return mixed
|
|
|
|
| 1 |
+
"""HyenaBlock — drop-in block for HYDRA, supplement to Mamba3.
|
| 2 |
+
|
| 3 |
+
Wraps `subsystems.hyena_pure.HyenaOperator` with a pre-norm + residual scheme
|
| 4 |
+
consistent with how the mHC stack wraps Mamba3 in `hydra/model.py`.
|
| 5 |
+
|
| 6 |
+
Interface contract (MUST match how Mamba3 is called in model.py):
|
| 7 |
+
block = HyenaBlock(d_model, seq_len)
|
| 8 |
+
y = block(x) # x: [B, T, d_model] -> y: [B, T, d_model]
|
| 9 |
+
|
| 10 |
+
The surrounding mHC layer does the pre-norm (`norm(h)`) BEFORE calling the
|
| 11 |
+
block, so the block itself should NOT re-normalize at input — same as Mamba3
|
| 12 |
+
in the current model. We return the raw operator output; the mHC layer then
|
| 13 |
+
adds it as a residual stream contribution.
|
| 14 |
+
|
| 15 |
+
NO attention, NO softmax-over-sequence-dim, NO KV-cache. All forbidden
|
| 16 |
+
imports enumerated in tests/test_hyena.py (test #7) are absent.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
|
| 26 |
+
from subsystems.hyena_pure import HyenaOperator
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class HyenaBlock(nn.Module):
    """Thin nn.Module wrapper around a single ``HyenaOperator``.

    ``forward`` maps ``x: [B, T, d_model]`` to ``y: [B, T, d_model]`` by
    delegating straight to the operator; nothing else is applied here.

    Args:
        d_model: Model width, forwarded to the operator.
        seq_len: Forwarded to the operator as ``l_max``.
        order: Hyena order; when None it is read from the
            ``HYDRA_HYENA_ORDER`` environment variable (default "2").
        filter_order: Filter width; when None it is read from the
            ``HYDRA_HYENA_FILTER_DIM`` environment variable (default "64").
        dropout: Forwarded to the operator unchanged.
        filter_dropout: Forwarded to the operator unchanged.
        short_filter_order: Forwarded to the operator unchanged.
        activation: Forwarded to the operator unchanged.
    """

    def __init__(
        self,
        d_model: int,
        seq_len: int,
        order: int | None = None,
        filter_order: int | None = None,
        dropout: float = 0.0,
        filter_dropout: float = 0.0,
        short_filter_order: int = 3,
        activation: str = "id",
    ):
        super().__init__()

        # Explicit arguments win; otherwise fall back to the environment
        # overrides and then to the built-in defaults.
        if order is None:
            order = int(os.environ.get("HYDRA_HYENA_ORDER", "2"))
        if filter_order is None:
            filter_order = int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64"))

        self.d_model, self.seq_len = d_model, seq_len
        self.order, self.filter_order = order, filter_order

        operator_settings = {
            "d_model": d_model,
            "l_max": seq_len,
            "order": order,
            "filter_order": filter_order,
            "dropout": dropout,
            "filter_dropout": filter_dropout,
            "short_filter_order": short_filter_order,
            "activation": activation,
        }
        self.operator = HyenaOperator(**operator_settings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the operator output for ``x`` (``[B, T, d_model]`` in and out)."""
        mixed = self.operator(x)
        return mixed
|
overlay/hydra/lightning_module.py
CHANGED
|
@@ -1,326 +1,326 @@
|
|
| 1 |
-
"""LightningModule wrapping PostSemClawModel.
|
| 2 |
-
|
| 3 |
-
Thin adapter. The model and the MuonAdamW optimizer are unchanged. This
|
| 4 |
-
module implements:
|
| 5 |
-
|
| 6 |
-
• configure_optimizers — returns the existing MuonAdamW (subclass of
|
| 7 |
-
torch.optim.Optimizer) built by model.setup_optimizer. Lightning accepts
|
| 8 |
-
this directly.
|
| 9 |
-
• training_step — splits (B, T+1) batches into (x, y), forwards through
|
| 10 |
-
the model, logs loss / bpb / tps / mfu / vram. Preserves the
|
| 11 |
-
sampled-softmax path inside PostSemClawModel (no changes there).
|
| 12 |
-
• optimizer_step — before each step we update LR + muon momentum + WD
|
| 13 |
-
using the same time-progress schedule as hydra/training.py
|
| 14 |
-
(get_lr_multiplier / get_muon_momentum / get_weight_decay). Lightning
|
| 15 |
-
handles grad accumulation via Trainer(accumulate_grad_batches=N).
|
| 16 |
-
|
| 17 |
-
The SDR SOM update and Hestia QAT snap are called at the same cadence as
|
| 18 |
-
the legacy loop, but inline on the main thread (Lightning provides its own
|
| 19 |
-
callbacks for async work if we need to extract them later — keeping it
|
| 20 |
-
simple for now).
|
| 21 |
-
|
| 22 |
-
Env vars respected:
|
| 23 |
-
HYDRA_TIME_BUDGET — wall-clock budget (s) used for LR schedule
|
| 24 |
-
and as Trainer max_time
|
| 25 |
-
HYDRA_HESTIA_INTERVAL — steps between Hestia snaps (default 100)
|
| 26 |
-
HYDRA_BATCH_SIZE — device batch size (for throughput calc)
|
| 27 |
-
HYDRA_SEQ_LEN — sequence length (for throughput calc)
|
| 28 |
-
"""
|
| 29 |
-
from __future__ import annotations
|
| 30 |
-
|
| 31 |
-
import math
|
| 32 |
-
import os
|
| 33 |
-
import time
|
| 34 |
-
|
| 35 |
-
import torch
|
| 36 |
-
import lightning as L
|
| 37 |
-
|
| 38 |
-
from hydra.config import (
|
| 39 |
-
ADAM_BETAS,
|
| 40 |
-
EMBEDDING_LR,
|
| 41 |
-
FINAL_LR_FRAC,
|
| 42 |
-
GPU_BF16_PEAK_FLOPS,
|
| 43 |
-
MATRIX_LR,
|
| 44 |
-
SCALAR_LR,
|
| 45 |
-
UNEMBEDDING_LR,
|
| 46 |
-
WARMUP_RATIO,
|
| 47 |
-
WEIGHT_DECAY,
|
| 48 |
-
PostSemClawConfig,
|
| 49 |
-
)
|
| 50 |
-
from hydra.model import PostSemClawModel
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# ---------------------------------------------------------------------------
|
| 54 |
-
# LR / momentum / wd schedules — verbatim copy of hydra/training.py so the
|
| 55 |
-
# curves match exactly. Kept here to avoid import cycles.
|
| 56 |
-
# ---------------------------------------------------------------------------
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def _lr_multiplier(progress: float) -> float:
    """Return the learning-rate multiplier for a given training progress.

    While ``progress`` is below WARMUP_RATIO the multiplier ramps linearly
    as ``progress / WARMUP_RATIO`` (or is 1.0 when WARMUP_RATIO is not
    positive).  Afterwards it follows a half-cosine from 1.0 towards
    FINAL_LR_FRAC over the remaining progress.
    """
    in_warmup = progress < WARMUP_RATIO
    if not in_warmup:
        # Guard the denominator for a warmup ratio of (almost) 1.
        decay_span = max(1.0 - WARMUP_RATIO, 1e-9)
        decay_progress = (progress - WARMUP_RATIO) / decay_span
        cosine_term = 1 + math.cos(math.pi * decay_progress)
        return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * cosine_term

    if WARMUP_RATIO > 0:
        return progress / WARMUP_RATIO
    return 1.0
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def _muon_momentum(step: int) -> float:
|
| 69 |
-
frac = min(step / 300.0, 1.0)
|
| 70 |
-
return (1 - frac) * 0.85 + frac * 0.95
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def _weight_decay(progress: float) -> float:
    """Return WEIGHT_DECAY scaled by the remaining progress ``1 - progress``."""
    remaining = 1 - progress
    return WEIGHT_DECAY * remaining
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
# ---------------------------------------------------------------------------
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class HydraLightningModule(L.LightningModule):
    """Lightning wrapper around ``PostSemClawModel``.

    Public attrs: ``self.model``, ``self.config``.

    The module drives LR / Muon momentum / weight decay from accumulated
    wall-clock training time (``HYDRA_TIME_BUDGET``), performs the optimizer
    step itself (``optimizer_step``) so that the Hyena gradient flush and
    gradient clipping happen between backward and the parameter update, and
    prints a one-line ``step=NNNNN ...`` log per training step.
    """

    def __init__(self, config: PostSemClawConfig):
        """Build the wrapped model and initialise schedule / logging state.

        Args:
            config: model configuration forwarded to ``PostSemClawModel``.
        """
        super().__init__()
        self.config = config
        self.model = PostSemClawModel(config)
        # Model weights init must be deferred to the correct device; done by
        # caller after construction (to match the meta-device + to_empty()
        # pattern used in the legacy loop).

        # Time-based progress tracks the legacy loop's semantics: LR cosine
        # is driven by wall-clock, not step count. We capture training start
        # in on_train_start and TIME_BUDGET from env.
        self.time_budget = float(
            int(os.environ.get("HYDRA_TIME_BUDGET", "300"))
        )
        self._train_start_time: float | None = None
        self._total_training_time = 0.0
        self._last_step_end: float | None = None
        self._hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
        self._flops_per_token = 0
        self._tokens_per_step = 0

        # Smoothed loss for the header-line log (matches legacy format).
        self._ema_beta = 0.9
        self._smooth_loss = 0.0
        self._bpt_ema = 0.0
        self._token_bytes: torch.Tensor | None = None

        # Input ids of the most recent training_step; consumed by the SDR SOM
        # update in optimizer_step.
        self._last_x: torch.Tensor | None = None

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def on_train_start(self) -> None:
        """Record the training start time and cache throughput constants."""
        self._train_start_time = time.time()
        self._last_step_end = self._train_start_time
        self._flops_per_token = self.model.estimate_flops()
        # Tokens processed per optimizer step (pre-accum).
        B = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        T = int(os.environ.get("HYDRA_SEQ_LEN", "512"))
        self._tokens_per_step = B * T

        # Build/cache token_bytes LUT (for bits-per-byte live metric).
        import prepare as _p

        self._token_bytes = _p.get_token_bytes(device=self.device)

    def configure_optimizers(self):
        """Return the optimizer built by ``model.setup_optimizer``."""
        optimizer = self.model.setup_optimizer(
            unembedding_lr=UNEMBEDDING_LR,
            embedding_lr=EMBEDDING_LR,
            scalar_lr=SCALAR_LR,
            adam_betas=ADAM_BETAS,
            matrix_lr=MATRIX_LR,
            weight_decay=WEIGHT_DECAY,
        )
        return optimizer

    # ------------------------------------------------------------------
    # Training step. Lightning auto-handles: autocast (via precision flag
    # on Trainer), backward, grad-accum, zero_grad. We only:
    #   - split batch into (x, y)
    #   - forward through model (autocast is established by Trainer)
    #   - return loss (grads flow from return)
    # ------------------------------------------------------------------

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        """Split a ``(B, T+1)`` batch into inputs/targets and return the loss.

        Raises:
            RuntimeError: if ``batch`` is not 2-dimensional.
        """
        # DataLoader produces (B, T+1) rows; split into input/target.
        if batch.dim() != 2:
            raise RuntimeError(f"Expected (B, T+1) batch, got shape {tuple(batch.shape)}")
        x = batch[:, :-1].contiguous()
        y = batch[:, 1:].contiguous()

        loss = self.model(x, y)
        # Keep the inputs that produced the model's stashed `_last_sdr` so
        # optimizer_step can hand the matching (x, sdr) pair to the SOM
        # update. Previously this attribute was only ever reset to None, so
        # the SOM update could never run.
        self._last_x = x
        # Lightning applies the grad-accum divisor automatically; we just
        # return the raw loss. loss.detach() is used for logging.
        self._log_step(loss.detach(), y)
        return loss

    # ------------------------------------------------------------------
    # Optimizer step hook: update LR / momentum / WD using time-progress.
    # Runs once per optimizer step (on the last accum micro-batch).
    # ------------------------------------------------------------------

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        """Apply schedules, run the optimizer step, then post-step hooks.

        Order of operations:
          1. Set lr (all groups) and momentum / weight decay (muon groups)
             from the wall-clock progress accumulated so far.
          2. ``optimizer.step`` with a wrapped closure that runs Lightning's
             closure, flushes pending Hyena grads, and clips the grad norm.
          3. Zero grads, invalidate Hyena caches, anneal / apply Hestia,
             run the SDR SOM update, and advance the wall-clock counter.
        """
        # Update schedules from wall-clock progress.
        now = time.time()
        if self._train_start_time is None:
            self._train_start_time = now
            self._last_step_end = now
        progress = min(self._total_training_time / max(self.time_budget, 1.0), 1.0)

        step = self.global_step
        lrm = _lr_multiplier(progress)
        mom = _muon_momentum(step)
        wd = _weight_decay(progress)
        for group in optimizer.param_groups:
            group["lr"] = group["initial_lr"] * lrm
            if group.get("kind") == "muon":
                group["momentum"] = mom
                group["weight_decay"] = wd

        # Lightning runs forward + backward for the current micro-batch
        # INSIDE the closure passed to optimizer.step() (earlier accumulation
        # micro-batches have already deposited their grads). Anything that
        # must see the final gradients therefore has to run inside the
        # closure, after the original closure returns:
        #   1. optimizer_closure() — forward + backward of this micro-batch.
        #      Each Hyena backward accumulates into _k_leaf.grad.
        #   2. flush_hyena_pending_grads() — one-shot
        #      torch.autograd.backward(_k_graph, _k_leaf.grad) per HyenaFilter,
        #      so filter MLP / pos_emb / bias params get their grads.
        #      Skipped when the model does not expose the method.
        #   3. Grad clip (matches legacy loop, max_norm=1.0). Clipping before
        #      optimizer.step(), as this hook used to do, ran before the
        #      backward and so never saw the complete gradients.
        flush_hyena = getattr(self.model, "flush_hyena_pending_grads", None)

        def _closure_then_finalize_grads():
            result = optimizer_closure()
            if flush_hyena is not None:
                flush_hyena()
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
            return result

        # Run the step (this is what Lightning would have done for us).
        optimizer.step(closure=_closure_then_finalize_grads)
        self.model.zero_grad(set_to_none=True)

        # Hyena filter-rfft cache invalidation. No-op if:
        #   (a) no Hyena layers are in the model, or
        #   (b) HYDRA_HYENA_FILTER_CACHE=0 and HYDRA_HYENA_TRAIN_CACHE=0
        #       (the operators never populated either cache)
        # In either case this is a handful of Python attribute resets.
        if hasattr(self.model, "invalidate_hyena_caches"):
            self.model.invalidate_hyena_caches()

        # Hestia QAT snap every N steps. Temperature anneals every step.
        progress_now = (now - self._train_start_time) / max(self.time_budget, 1.0)
        self.model.hestia.anneal_temperature(progress_now)
        if self._hestia_interval > 0 and step % self._hestia_interval == 0:
            self.model.hestia.apply_to(self.model)

        # SDR SOM update when the model stashed an sdr in the last forward.
        # The legacy loop passed (x, _last_sdr); x comes from training_step
        # via self._last_x.
        _last_sdr = getattr(self.model, "_last_sdr", None)
        if _last_sdr is not None and hasattr(self.model.sdr_semantic, "maybe_som_update"):
            if self._last_x is not None:
                self.model.sdr_semantic.maybe_som_update(self._last_x, _last_sdr)

        # Advance the wall-clock counter for LR schedule; the first steps
        # (step <= 10) are excluded from the budget.
        dt = now - (self._last_step_end or now)
        self._last_step_end = now
        if step > 10:
            self._total_training_time += dt

    # ------------------------------------------------------------------
    # Logging — mirrors the step=NNNNN line format of the legacy loop so
    # grep/tee pipelines keep working.
    # ------------------------------------------------------------------

    def _log_step(self, loss: torch.Tensor, y: torch.Tensor) -> None:
        """Log smoothed loss, bpb/bpt, throughput, MFU and VRAM for one step.

        Args:
            loss: detached scalar loss of the current micro-batch.
            y: target token ids, used to look up bytes-per-token.
        """
        loss_f = float(loss.item())
        if not math.isfinite(loss_f) or loss_f > 100:
            # Let Lightning raise / the trainer callbacks handle this.
            self.log("train_loss_nan", 1.0)
            return

        step = self.global_step
        self._smooth_loss = (
            self._ema_beta * self._smooth_loss + (1 - self._ema_beta) * loss_f
        )
        debiased = self._smooth_loss / max(1 - self._ema_beta ** (step + 1), 1e-9)
        # dt is floored at 1 microsecond, so the divisions below are safe.
        dt = max(time.time() - (self._last_step_end or time.time()), 1e-6)
        tps = int(self._tokens_per_step / dt)
        mfu = (
            100.0
            * self._flops_per_token
            * self._tokens_per_step
            / dt
            / GPU_BF16_PEAK_FLOPS
        )

        # bpb live: y flat -> token_bytes LUT -> avg bytes/token
        bpt = debiased / math.log(2)
        if self._token_bytes is not None:
            with torch.no_grad():
                y_flat = y.reshape(-1)
                nbytes = self._token_bytes[y_flat]
                mask = nbytes > 0
                denom = mask.sum().clamp(min=1).float()
                avg_bpt = (nbytes.float() * mask.float()).sum() / denom
                bpt_batch = float(avg_bpt.item())
            if step == 0 or self._bpt_ema <= 0.0:
                self._bpt_ema = bpt_batch
            else:
                self._bpt_ema = 0.98 * self._bpt_ema + 0.02 * bpt_batch
        # Computed unconditionally so `bpb` is always bound, even when the
        # token-bytes LUT has not been built yet.
        bpb = bpt / max(self._bpt_ema, 1e-6)
        vram = (
            torch.cuda.memory_allocated() / 1024 / 1024
            if torch.cuda.is_available()
            else 0.0
        )

        self.log_dict(
            {
                "train/loss": debiased,
                "train/bpb": bpb,
                "train/bpt": bpt,
                "train/tps": float(tps),
                "train/mfu": float(mfu),
                "train/vram_mib": float(vram),
            },
            prog_bar=False,
            on_step=True,
            on_epoch=False,
        )

        # Match legacy one-line format: "step=NNNNN loss=x bpb=y tps=z ..."
        print(
            f"step={step:05d} loss={debiased:.4f} bpb={bpb:.4f} "
            f"bpt={bpt:.3f} bpt_div={self._bpt_ema:.2f} "
            f"tps={tps} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
            f"vram={vram:.0f}MiB",
            flush=True,
        )
|
|
|
|
| 1 |
+
"""LightningModule wrapping PostSemClawModel.
|
| 2 |
+
|
| 3 |
+
Thin adapter. The model and the MuonAdamW optimizer are unchanged. This
|
| 4 |
+
module implements:
|
| 5 |
+
|
| 6 |
+
• configure_optimizers — returns the existing MuonAdamW (subclass of
|
| 7 |
+
torch.optim.Optimizer) built by model.setup_optimizer. Lightning accepts
|
| 8 |
+
this directly.
|
| 9 |
+
• training_step — splits (B, T+1) batches into (x, y), forwards through
|
| 10 |
+
the model, logs loss / bpb / tps / mfu / vram. Preserves the
|
| 11 |
+
sampled-softmax path inside PostSemClawModel (no changes there).
|
| 12 |
+
• optimizer_step — before each step we update LR + muon momentum + WD
|
| 13 |
+
using the same time-progress schedule as hydra/training.py
|
| 14 |
+
(get_lr_multiplier / get_muon_momentum / get_weight_decay). Lightning
|
| 15 |
+
handles grad accumulation via Trainer(accumulate_grad_batches=N).
|
| 16 |
+
|
| 17 |
+
The SDR SOM update and Hestia QAT snap are called at the same cadence as
|
| 18 |
+
the legacy loop, but inline on the main thread (Lightning provides its own
|
| 19 |
+
callbacks for async work if we need to extract them later — keeping it
|
| 20 |
+
simple for now).
|
| 21 |
+
|
| 22 |
+
Env vars respected:
|
| 23 |
+
HYDRA_TIME_BUDGET — wall-clock budget (s) used for LR schedule
|
| 24 |
+
and as Trainer max_time
|
| 25 |
+
HYDRA_HESTIA_INTERVAL — steps between Hestia snaps (default 100)
|
| 26 |
+
HYDRA_BATCH_SIZE — device batch size (for throughput calc)
|
| 27 |
+
HYDRA_SEQ_LEN — sequence length (for throughput calc)
|
| 28 |
+
"""
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import math
|
| 32 |
+
import os
|
| 33 |
+
import time
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import lightning as L
|
| 37 |
+
|
| 38 |
+
from hydra.config import (
|
| 39 |
+
ADAM_BETAS,
|
| 40 |
+
EMBEDDING_LR,
|
| 41 |
+
FINAL_LR_FRAC,
|
| 42 |
+
GPU_BF16_PEAK_FLOPS,
|
| 43 |
+
MATRIX_LR,
|
| 44 |
+
SCALAR_LR,
|
| 45 |
+
UNEMBEDDING_LR,
|
| 46 |
+
WARMUP_RATIO,
|
| 47 |
+
WEIGHT_DECAY,
|
| 48 |
+
PostSemClawConfig,
|
| 49 |
+
)
|
| 50 |
+
from hydra.model import PostSemClawModel
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# LR / momentum / wd schedules — verbatim copy of hydra/training.py so the
|
| 55 |
+
# curves match exactly. Kept here to avoid import cycles.
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _lr_multiplier(progress: float) -> float:
    """Return the LR scale factor for a given training-progress fraction.

    Below ``WARMUP_RATIO`` the factor ramps linearly from 0 towards 1; from
    there on it follows a half-cosine that ends at ``FINAL_LR_FRAC`` when
    ``progress`` reaches 1.0.
    """
    in_warmup = progress < WARMUP_RATIO
    if in_warmup:
        # Guard against a zero warm-up ratio (only reachable with a negative
        # progress value) so we never divide by zero.
        if WARMUP_RATIO > 0:
            return progress / WARMUP_RATIO
        return 1.0

    decay_span = max(1.0 - WARMUP_RATIO, 1e-9)
    decay_progress = (progress - WARMUP_RATIO) / decay_span
    cosine_term = 1 + math.cos(math.pi * decay_progress)
    return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * cosine_term
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _muon_momentum(step: int) -> float:
|
| 69 |
+
frac = min(step / 300.0, 1.0)
|
| 70 |
+
return (1 - frac) * 0.85 + frac * 0.95
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _weight_decay(progress: float) -> float:
    """Return the weight decay at ``progress``: ``WEIGHT_DECAY`` scaled by the remaining fraction."""
    remaining = 1 - progress
    return WEIGHT_DECAY * remaining
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ---------------------------------------------------------------------------
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class HydraLightningModule(L.LightningModule):
    """Lightning wrapper around ``PostSemClawModel``.

    Public attrs: ``self.model``, ``self.config``.

    The module drives LR / Muon momentum / weight decay from accumulated
    wall-clock training time (``HYDRA_TIME_BUDGET``), performs the optimizer
    step itself (``optimizer_step``) so that the Hyena gradient flush and
    gradient clipping happen between backward and the parameter update, and
    prints a one-line ``step=NNNNN ...`` log per training step.
    """

    def __init__(self, config: PostSemClawConfig):
        """Build the wrapped model and initialise schedule / logging state.

        Args:
            config: model configuration forwarded to ``PostSemClawModel``.
        """
        super().__init__()
        self.config = config
        self.model = PostSemClawModel(config)
        # Model weights init must be deferred to the correct device; done by
        # caller after construction (to match the meta-device + to_empty()
        # pattern used in the legacy loop).

        # Time-based progress tracks the legacy loop's semantics: LR cosine
        # is driven by wall-clock, not step count. We capture training start
        # in on_train_start and TIME_BUDGET from env.
        self.time_budget = float(
            int(os.environ.get("HYDRA_TIME_BUDGET", "300"))
        )
        self._train_start_time: float | None = None
        self._total_training_time = 0.0
        self._last_step_end: float | None = None
        self._hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
        self._flops_per_token = 0
        self._tokens_per_step = 0

        # Smoothed loss for the header-line log (matches legacy format).
        self._ema_beta = 0.9
        self._smooth_loss = 0.0
        self._bpt_ema = 0.0
        self._token_bytes: torch.Tensor | None = None

        # Input ids of the most recent training_step; consumed by the SDR SOM
        # update in optimizer_step.
        self._last_x: torch.Tensor | None = None

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def on_train_start(self) -> None:
        """Record the training start time and cache throughput constants."""
        self._train_start_time = time.time()
        self._last_step_end = self._train_start_time
        self._flops_per_token = self.model.estimate_flops()
        # Tokens processed per optimizer step (pre-accum).
        B = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        T = int(os.environ.get("HYDRA_SEQ_LEN", "512"))
        self._tokens_per_step = B * T

        # Build/cache token_bytes LUT (for bits-per-byte live metric).
        import prepare as _p

        self._token_bytes = _p.get_token_bytes(device=self.device)

    def configure_optimizers(self):
        """Return the optimizer built by ``model.setup_optimizer``."""
        optimizer = self.model.setup_optimizer(
            unembedding_lr=UNEMBEDDING_LR,
            embedding_lr=EMBEDDING_LR,
            scalar_lr=SCALAR_LR,
            adam_betas=ADAM_BETAS,
            matrix_lr=MATRIX_LR,
            weight_decay=WEIGHT_DECAY,
        )
        return optimizer

    # ------------------------------------------------------------------
    # Training step. Lightning auto-handles: autocast (via precision flag
    # on Trainer), backward, grad-accum, zero_grad. We only:
    #   - split batch into (x, y)
    #   - forward through model (autocast is established by Trainer)
    #   - return loss (grads flow from return)
    # ------------------------------------------------------------------

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        """Split a ``(B, T+1)`` batch into inputs/targets and return the loss.

        Raises:
            RuntimeError: if ``batch`` is not 2-dimensional.
        """
        # DataLoader produces (B, T+1) rows; split into input/target.
        if batch.dim() != 2:
            raise RuntimeError(f"Expected (B, T+1) batch, got shape {tuple(batch.shape)}")
        x = batch[:, :-1].contiguous()
        y = batch[:, 1:].contiguous()

        loss = self.model(x, y)
        # Keep the inputs that produced the model's stashed `_last_sdr` so
        # optimizer_step can hand the matching (x, sdr) pair to the SOM
        # update. Previously this attribute was only ever reset to None, so
        # the SOM update could never run.
        self._last_x = x
        # Lightning applies the grad-accum divisor automatically; we just
        # return the raw loss. loss.detach() is used for logging.
        self._log_step(loss.detach(), y)
        return loss

    # ------------------------------------------------------------------
    # Optimizer step hook: update LR / momentum / WD using time-progress.
    # Runs once per optimizer step (on the last accum micro-batch).
    # ------------------------------------------------------------------

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        """Apply schedules, run the optimizer step, then post-step hooks.

        Order of operations:
          1. Set lr (all groups) and momentum / weight decay (muon groups)
             from the wall-clock progress accumulated so far.
          2. ``optimizer.step`` with a wrapped closure that runs Lightning's
             closure, flushes pending Hyena grads, and clips the grad norm.
          3. Zero grads, invalidate Hyena caches, anneal / apply Hestia,
             run the SDR SOM update, and advance the wall-clock counter.
        """
        # Update schedules from wall-clock progress.
        now = time.time()
        if self._train_start_time is None:
            self._train_start_time = now
            self._last_step_end = now
        progress = min(self._total_training_time / max(self.time_budget, 1.0), 1.0)

        step = self.global_step
        lrm = _lr_multiplier(progress)
        mom = _muon_momentum(step)
        wd = _weight_decay(progress)
        for group in optimizer.param_groups:
            group["lr"] = group["initial_lr"] * lrm
            if group.get("kind") == "muon":
                group["momentum"] = mom
                group["weight_decay"] = wd

        # Lightning runs forward + backward for the current micro-batch
        # INSIDE the closure passed to optimizer.step() (earlier accumulation
        # micro-batches have already deposited their grads). Anything that
        # must see the final gradients therefore has to run inside the
        # closure, after the original closure returns:
        #   1. optimizer_closure() — forward + backward of this micro-batch.
        #      Each Hyena backward accumulates into _k_leaf.grad.
        #   2. flush_hyena_pending_grads() — one-shot
        #      torch.autograd.backward(_k_graph, _k_leaf.grad) per HyenaFilter,
        #      so filter MLP / pos_emb / bias params get their grads.
        #      Skipped when the model does not expose the method.
        #   3. Grad clip (matches legacy loop, max_norm=1.0). Clipping before
        #      optimizer.step(), as this hook used to do, ran before the
        #      backward and so never saw the complete gradients.
        flush_hyena = getattr(self.model, "flush_hyena_pending_grads", None)

        def _closure_then_finalize_grads():
            result = optimizer_closure()
            if flush_hyena is not None:
                flush_hyena()
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
            return result

        # Run the step (this is what Lightning would have done for us).
        optimizer.step(closure=_closure_then_finalize_grads)
        self.model.zero_grad(set_to_none=True)

        # Hyena filter-rfft cache invalidation. No-op if:
        #   (a) no Hyena layers are in the model, or
        #   (b) HYDRA_HYENA_FILTER_CACHE=0 and HYDRA_HYENA_TRAIN_CACHE=0
        #       (the operators never populated either cache)
        # In either case this is a handful of Python attribute resets.
        if hasattr(self.model, "invalidate_hyena_caches"):
            self.model.invalidate_hyena_caches()

        # Hestia QAT snap every N steps. Temperature anneals every step.
        progress_now = (now - self._train_start_time) / max(self.time_budget, 1.0)
        self.model.hestia.anneal_temperature(progress_now)
        if self._hestia_interval > 0 and step % self._hestia_interval == 0:
            self.model.hestia.apply_to(self.model)

        # SDR SOM update when the model stashed an sdr in the last forward.
        # The legacy loop passed (x, _last_sdr); x comes from training_step
        # via self._last_x.
        _last_sdr = getattr(self.model, "_last_sdr", None)
        if _last_sdr is not None and hasattr(self.model.sdr_semantic, "maybe_som_update"):
            if self._last_x is not None:
                self.model.sdr_semantic.maybe_som_update(self._last_x, _last_sdr)

        # Advance the wall-clock counter for LR schedule; the first steps
        # (step <= 10) are excluded from the budget.
        dt = now - (self._last_step_end or now)
        self._last_step_end = now
        if step > 10:
            self._total_training_time += dt

    # ------------------------------------------------------------------
    # Logging — mirrors the step=NNNNN line format of the legacy loop so
    # grep/tee pipelines keep working.
    # ------------------------------------------------------------------

    def _log_step(self, loss: torch.Tensor, y: torch.Tensor) -> None:
        """Log smoothed loss, bpb/bpt, throughput, MFU and VRAM for one step.

        Args:
            loss: detached scalar loss of the current micro-batch.
            y: target token ids, used to look up bytes-per-token.
        """
        loss_f = float(loss.item())
        if not math.isfinite(loss_f) or loss_f > 100:
            # Let Lightning raise / the trainer callbacks handle this.
            self.log("train_loss_nan", 1.0)
            return

        step = self.global_step
        self._smooth_loss = (
            self._ema_beta * self._smooth_loss + (1 - self._ema_beta) * loss_f
        )
        debiased = self._smooth_loss / max(1 - self._ema_beta ** (step + 1), 1e-9)
        # dt is floored at 1 microsecond, so the divisions below are safe.
        dt = max(time.time() - (self._last_step_end or time.time()), 1e-6)
        tps = int(self._tokens_per_step / dt)
        mfu = (
            100.0
            * self._flops_per_token
            * self._tokens_per_step
            / dt
            / GPU_BF16_PEAK_FLOPS
        )

        # bpb live: y flat -> token_bytes LUT -> avg bytes/token
        bpt = debiased / math.log(2)
        if self._token_bytes is not None:
            with torch.no_grad():
                y_flat = y.reshape(-1)
                nbytes = self._token_bytes[y_flat]
                mask = nbytes > 0
                denom = mask.sum().clamp(min=1).float()
                avg_bpt = (nbytes.float() * mask.float()).sum() / denom
                bpt_batch = float(avg_bpt.item())
            if step == 0 or self._bpt_ema <= 0.0:
                self._bpt_ema = bpt_batch
            else:
                self._bpt_ema = 0.98 * self._bpt_ema + 0.02 * bpt_batch
        # Computed unconditionally so `bpb` is always bound, even when the
        # token-bytes LUT has not been built yet.
        bpb = bpt / max(self._bpt_ema, 1e-6)
        vram = (
            torch.cuda.memory_allocated() / 1024 / 1024
            if torch.cuda.is_available()
            else 0.0
        )

        self.log_dict(
            {
                "train/loss": debiased,
                "train/bpb": bpb,
                "train/bpt": bpt,
                "train/tps": float(tps),
                "train/mfu": float(mfu),
                "train/vram_mib": float(vram),
            },
            prog_bar=False,
            on_step=True,
            on_epoch=False,
        )

        # Match legacy one-line format: "step=NNNNN loss=x bpb=y tps=z ..."
        print(
            f"step={step:05d} loss={debiased:.4f} bpb={bpb:.4f} "
            f"bpt={bpt:.3f} bpt_div={self._bpt_ema:.2f} "
            f"tps={tps} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
            f"vram={vram:.0f}MiB",
            flush=True,
        )
|
overlay/hydra/model.py
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
overlay/hydra/optimizer.py
CHANGED
|
@@ -1,252 +1,252 @@
|
|
| 1 |
-
"""MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).
|
| 2 |
-
|
| 3 |
-
Extracted verbatim from train.py (W1 modularization). Semantics unchanged.
|
| 4 |
-
|
| 5 |
-
F1-F15 state preserved:
|
| 6 |
-
- F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
|
| 7 |
-
step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
|
| 8 |
-
fresh. Persistent copies of param storage would be mutated between forward
|
| 9 |
-
passes (via lerp_/sub_ on stacked tensors that share storage with params),
|
| 10 |
-
triggering "modified in-place" errors on grad_accum=2 backwards.
|
| 11 |
-
- F11/F15: `@torch.compile` on `adamw_step_fused` / `muon_step_fused` intact.
|
| 12 |
-
- F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
|
| 13 |
-
dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
|
| 14 |
-
stream-capture deadlock. See .omc/muon_compile_bug.md for the full
|
| 15 |
-
investigation.
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
from __future__ import annotations
|
| 19 |
-
|
| 20 |
-
import os
|
| 21 |
-
|
| 22 |
-
import torch
|
| 23 |
-
|
| 24 |
-
# HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel.
|
| 25 |
-
_HYDRA_FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
|
| 26 |
-
_HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
polar_express_coeffs = [
|
| 30 |
-
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
| 31 |
-
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
| 32 |
-
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
| 33 |
-
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
| 34 |
-
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
| 35 |
-
]
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply one AdamW update to a single parameter, in place.

    Per-param fallback path; mutates ``p``, ``exp_avg`` and ``exp_avg_sq``
    and returns nothing.

    Args:
        p: parameter tensor to update.
        grad: gradient for ``p``; cast to ``p.dtype`` before use.
        exp_avg: running first-moment estimate.
        exp_avg_sq: running second-moment estimate.
        step_t: step count used for bias correction.
        lr_t: learning rate.
        beta1_t: first-moment decay rate.
        beta2_t: second-moment decay rate.
        eps_t: epsilon added to the denominator.
        wd_t: decoupled weight-decay coefficient.
    """
    # Handle mixed bf16/fp32 gradients produced under autocast.
    g = grad.to(p.dtype)

    # Decoupled weight decay is applied to the weights before the Adam update.
    p.mul_(1 - lr_t * wd_t)

    # Exponential moving averages of the gradient and its square.
    exp_avg.lerp_(g, 1 - beta1_t)
    exp_avg_sq.lerp_(g.square(), 1 - beta2_t)

    # Bias corrections for the zero-initialised moment estimates.
    bias_correction1 = 1 - beta1_t ** step_t
    bias_correction2 = 1 - beta2_t ** step_t

    denom = (exp_avg_sq / bias_correction2).sqrt() + eps_t
    step_size = lr_t / bias_correction1
    update = exp_avg / denom
    p.add_(update, alpha=-step_size)
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
# ---------------------------------------------------------------------------
|
| 53 |
-
# F15 muon_step_fused compile strategy.
|
| 54 |
-
#
|
| 55 |
-
# HYDRA_MUON_COMPILE env gate:
|
| 56 |
-
# "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
|
| 57 |
-
# Dynamic=True collapses the per-shape specialization cache so that N
|
| 58 |
-
# Muon param-groups with N distinct shapes trigger 1 compile, not N.
|
| 59 |
-
# mode="default" keeps the inductor codegen but disables cudagraphs,
|
| 60 |
-
# which is what caused the step-17→18 silent deadlock observed under
|
| 61 |
-
# the original dynamic=False configuration: cudagraph stream capture
|
| 62 |
-
# can deadlock against HTM's CUDA kernels running on the default
|
| 63 |
-
# stream, and the failure mode at capture-time is a silent hang
|
| 64 |
-
# (100% GPU util, no log output, process state R).
|
| 65 |
-
# "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
|
| 66 |
-
# Keeps an escape hatch in case a future torch/inductor regression
|
| 67 |
-
# reintroduces a deadlock.
|
| 68 |
-
#
|
| 69 |
-
# Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
|
| 70 |
-
# alias-analysis edge case where inductor sees `g is stacked_grads` and
|
| 71 |
-
# subsequent `stacked_grads.square()` operating on the post-lerp storage.
|
| 72 |
-
# ---------------------------------------------------------------------------
|
| 73 |
-
_MUON_COMPILE = os.environ.get("HYDRA_MUON_COMPILE", "1") == "1"
|
| 74 |
-
|
| 75 |
-
def _maybe_compile(fn):
    """Return ``fn`` wrapped with ``torch.compile`` when ``_MUON_COMPILE`` is set, else ``fn`` unchanged."""
    if not _MUON_COMPILE:
        return fn
    # mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
    # would enable) to avoid stream-capture deadlocks against HTM's CUDA
    # kernels. dynamic=True minimizes recompile count across param-group
    # shapes.
    return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
|
| 83 |
-
|
| 84 |
-
@_maybe_compile
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """Apply one Muon update, in place, to a stack of same-shape matrices.

    Args:
        stacked_grads: gradients stacked along dim 0 (one slice per param).
        stacked_params: parameters stacked the same way; updated in place.
        momentum_buffer: first-moment state, updated in place.
        second_momentum_buffer: NorMuon second-moment state (size 1 along
            ``red_dim``), updated in place.
        momentum_t, lr_t, wd_t, beta2_t: 0-D hyperparameter tensors. The
            caller keeps them on the CPU, so each one is moved to the device
            of the tensor it is combined with before use.
        ns_steps: number of polar-express iterations to run.
        red_dim: -1 or -2, the dim reduced for the variance statistics.
    """
    # Cast grads to param dtype AND clone defensively to break any alias
    # between the (freshly-stacked) input and the in-place lerp_ below.
    # Without this, inductor's alias analysis can emit code that reads from
    # post-mutation storage when computing `v_mean = g.square().mean(...)`.
    stacked_grads = stacked_grads.to(momentum_buffer.dtype).clone()
    # Nesterov momentum. lerp_ takes a tensor weight, so the CPU scalar must
    # live on the same device (and dtype) as the buffer it blends.
    momentum = momentum_t.to(device=momentum_buffer.device, dtype=stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization
    X = g.bfloat16()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        # Tall matrices: iterate on the smaller (cols x cols) Gram matrix.
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X.mT @ X
            B = b * A + c * (A @ A)
            X = a * X + X @ B
    else:
        # Wide or square matrices: iterate on the (rows x rows) Gram matrix.
        for a, b, c in polar_express_coeffs[:ns_steps]:
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X
    g = X
    # NorMuon variance reduction
    # Keep beta2 in the state-buffer dtype, not g.dtype, so lerp_ on the
    # second_momentum_buffer doesn't hit a dtype mismatch on h200.
    beta2 = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
    v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
    red_dim_size = g.size(red_dim)
    v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
    v_norm = v_norm_sq.sqrt()
    second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    # Rescale so the per-matrix norm of the update is unchanged.
    final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
    g = g * final_scale.to(g.dtype)
    # Cautious weight decay + parameter update
    lr = lr_t.to(device=stacked_params.device, dtype=g.dtype)
    wd = wd_t.to(device=stacked_params.device, dtype=g.dtype)
    # Decay only where the update and the parameter agree in sign.
    mask = (g * stacked_params) >= 0
    stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Every param group must carry a ``kind`` key:

    * ``"adamw"`` groups read ``lr``, ``betas``, ``eps`` and ``weight_decay``.
    * ``"muon"`` groups read ``lr``, ``momentum``, ``beta2``, ``weight_decay``
      and ``ns_steps``. All params of one Muon group are stacked together, so
      they must share one shape with at least two dimensions.
    """

    _KINDS = ('adamw', 'muon')

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        for name in (
            '_adamw_step_t', '_adamw_lr_t', '_adamw_beta1_t', '_adamw_beta2_t',
            '_adamw_eps_t', '_adamw_wd_t',
            '_muon_momentum_t', '_muon_lr_t', '_muon_wd_t', '_muon_beta2_t',
        ):
            setattr(self, name, torch.tensor(0.0, dtype=torch.float32, device="cpu"))

    def _step_adamw(self, group):
        """Run one AdamW update for every param in ``group`` that has a grad."""
        params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            if 'step_t' not in state:
                # _fused_adamw_ wants a per-param float step tensor on-device.
                state['step_t'] = torch.tensor(
                    float(state['step']), dtype=torch.float32, device=p.device
                )
            state['step'] += 1
            params.append(p)
            # .to() is a no-op when the grad already has the param dtype.
            grads.append(p.grad.to(p.dtype))
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            state_steps.append(state['step_t'])

        if not params:
            return

        if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
            # _fused_adamw_ needs uniform (device, dtype) within a call, so
            # group by (device, dtype) — same pattern as PyTorch's own
            # AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
            buckets = {}
            for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
                bucket = buckets.setdefault((p.device, p.dtype), ([], [], [], [], []))
                for column, item in zip(bucket, (p, g, ea, es, st)):
                    column.append(item)

            lr_f = float(group['lr'])
            b1_f = float(group['betas'][0])
            b2_f = float(group['betas'][1])
            wd_f = float(group['weight_decay'])
            eps_f = float(group['eps'])
            for b_p, b_g, b_ea, b_es, b_st in buckets.values():
                # The fused kernel reads the step from the tensors, so bump
                # them first.
                torch._foreach_add_(b_st, 1.0)
                torch._fused_adamw_(
                    b_p, b_g, b_ea, b_es,
                    [],  # max_exp_avg_sqs unused (amsgrad=False)
                    b_st,
                    amsgrad=False,
                    lr=lr_f, beta1=b1_f, beta2=b2_f,
                    weight_decay=wd_f, eps=eps_f,
                    maximize=False,
                    grad_scale=None, found_inf=None,
                )
            return

        # Fallback per-param path.
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p, grad, exp_avg, exp_avg_sq, step_t in zip(
            params, grads, exp_avgs, exp_avg_sqs, state_steps
        ):
            step = self.state[p]['step']
            # Keep the fused path's step tensor in sync with the int counter
            # so switching paths (e.g. after a resume) does not reuse a stale
            # bias-correction step.
            step_t.fill_(step)
            self._adamw_step_t.fill_(step)
            adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Run one Muon update over the stacked params of ``group``.

        The shared momentum buffers are stored in the state of the first
        param that has a gradient.
        """
        params = [p for p in group['params'] if p.grad is not None]
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        elif state["momentum_buffer"].shape[0] != num_params:
            # The buffers hold one slice per param; a different count cannot
            # be matched back to the stored momentum.
            raise RuntimeError(
                f"Muon state was built for {state['momentum_buffer'].shape[0]} params but "
                f"{num_params} have gradients this step; the set of params with "
                "gradients must not change between steps"
            )
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        if "second_momentum_buffer" not in state:
            # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
            state_shape = [num_params, *shape]
            state_shape[red_dim] = 1
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
        # This was the autograd-safety fix that unblocks grad_accum>=2.
        stacked_grads = torch.stack([q.grad for q in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Scale the learning rate up for tall matrices (rows > cols).
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # The update ran on a stacked copy; write the slices back to the params.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Update every param group according to its ``kind``.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss, as in ``torch.optim.Optimizer.step``.

        Returns:
            The closure's return value, or ``None`` when no closure is given.

        Raises:
            ValueError: if a param group has a ``kind`` other than
                ``"adamw"`` or ``"muon"``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Validate up front so a bad group never leaves a half-applied step.
        for group in self.param_groups:
            if group['kind'] not in self._KINDS:
                raise ValueError(
                    f"unknown param group kind {group['kind']!r}; expected one of {self._KINDS}"
                )
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                self._step_adamw(group)
            else:
                self._step_muon(group)
        return loss
|
|
|
|
| 1 |
+
"""MuonAdamW optimizer — combined Muon (2D matrices) + AdamW (everything else).
|
| 2 |
+
|
| 3 |
+
Extracted verbatim from train.py (W1 modularization). Semantics unchanged.
|
| 4 |
+
|
| 5 |
+
F1-F15 state preserved:
|
| 6 |
+
- F7 REVERTED: `stacked_params_buf` persistent across steps was REMOVED — each
|
| 7 |
+
step calls `torch.stack([p.grad for p in params])` / `torch.stack(params)`
|
| 8 |
+
fresh. Persistent copies of param storage would be mutated between forward
|
| 9 |
+
passes (via lerp_/sub_ on stacked tensors that share storage with params),
|
| 10 |
+
triggering "modified in-place" errors on grad_accum=2 backwards.
|
| 11 |
+
- F11/F15: `muon_step_fused` keeps its `torch.compile` wrapper (applied via `_maybe_compile`); `adamw_step_fused` runs eagerly as the per-param fallback.
|
| 12 |
+
- F15 compile is default-ON (HYDRA_MUON_COMPILE=1), configured with
|
| 13 |
+
dynamic=True + mode="default" to avoid the step-17→18 cudagraphs
|
| 14 |
+
stream-capture deadlock. See .omc/muon_compile_bug.md for the full
|
| 15 |
+
investigation.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
# HYDRA_FUSED_ADAMW=1 (default) -> vectorized torch._fused_adamw_ kernel.
|
| 25 |
+
_HYDRA_FUSED_ADAMW = os.environ.get("HYDRA_FUSED_ADAMW", "1") == "1"
|
| 26 |
+
_HAS_FUSED_ADAMW = hasattr(torch, "_fused_adamw_")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
polar_express_coeffs = [
|
| 30 |
+
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
| 31 |
+
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
| 32 |
+
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
| 33 |
+
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
| 34 |
+
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
    """Apply one AdamW update to a single parameter, in place.

    Per-param AdamW fallback. Fast path is torch._fused_adamw_ (1 CUDA launch
    for the whole group) driven from MuonAdamW._step_adamw below.

    Args:
        p: parameter tensor, updated in place.
        grad: gradient for ``p``; cast to ``p.dtype`` before use.
        exp_avg, exp_avg_sq: first/second moment state, updated in place.
        step_t: 1-based step count used for bias correction.
        lr_t, beta1_t, beta2_t, eps_t, wd_t: hyperparameters; the caller
            passes 0-D float32 CPU tensors.
    """
    def _lerp_weight(weight, like):
        # lerp_ with a tensor weight needs it on the same device and dtype as
        # the state it blends; plain Python numbers pass through unchanged.
        if torch.is_tensor(weight):
            return weight.to(device=like.device, dtype=like.dtype)
        return weight

    grad = grad.to(p.dtype)  # handle mixed bf16/fp32 from autocast
    # Decoupled weight decay.
    p.mul_(1 - lr_t * wd_t)
    # Form (1 - beta) before casting so the subtraction is not done in a
    # low-precision param dtype.
    exp_avg.lerp_(grad, _lerp_weight(1 - beta1_t, exp_avg))
    exp_avg_sq.lerp_(grad.square(), _lerp_weight(1 - beta2_t, exp_avg_sq))
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# F15 muon_step_fused compile strategy.
|
| 54 |
+
#
|
| 55 |
+
# HYDRA_MUON_COMPILE env gate:
|
| 56 |
+
# "1" (default ON) — wrap with torch.compile(dynamic=True, mode="default").
|
| 57 |
+
# dynamic=True collapses the per-shape specialization cache so that N
|
| 58 |
+
# Muon param-groups with N distinct shapes trigger 1 compile, not N.
|
| 59 |
+
# mode="default" keeps the inductor codegen but disables cudagraphs,
|
| 60 |
+
# which is what caused the step-17→18 silent deadlock observed under
|
| 61 |
+
# the original dynamic=False configuration: cudagraph stream capture
|
| 62 |
+
# can deadlock against HTM's CUDA kernels running on the default
|
| 63 |
+
# stream, and the failure mode at capture-time is a silent hang
|
| 64 |
+
# (100% GPU util, no log output, process state R).
|
| 65 |
+
# "0" — fall back to eager Python (slower, ~43k tps vs ~63k compiled).
|
| 66 |
+
# Keeps an escape hatch in case a future torch/inductor regression
|
| 67 |
+
# reintroduces a deadlock.
|
| 68 |
+
#
|
| 69 |
+
# Defensive .clone() on stacked_grads before in-place lerp_ eliminates the
|
| 70 |
+
# alias-analysis edge case where inductor sees `g is stacked_grads` and
|
| 71 |
+
# a subsequent `stacked_grads.square()` would operate on the post-lerp storage.
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
_MUON_COMPILE = os.environ.get("HYDRA_MUON_COMPILE", "1") == "1"
|
| 74 |
+
|
| 75 |
+
def _maybe_compile(fn):
|
| 76 |
+
if _MUON_COMPILE:
|
| 77 |
+
# mode="default" explicitly opts OUT of cudagraphs (which reduce-overhead
|
| 78 |
+
# would enable) to avoid stream-capture deadlocks against HTM's CUDA
|
| 79 |
+
# kernels. dynamic=True minimizes recompile count across param-group
|
| 80 |
+
# shapes.
|
| 81 |
+
return torch.compile(fn, fullgraph=False, dynamic=True, mode="default")
|
| 82 |
+
return fn
|
| 83 |
+
|
| 84 |
+
@_maybe_compile
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
                    momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
    """Run one in-place Muon step on a stack of same-shape matrices.

    ``momentum_buffer``, ``second_momentum_buffer`` and ``stacked_params`` are
    mutated in place; ``stacked_grads`` is copied first and left untouched.
    The ``*_t`` arguments are 0-D hyperparameter tensors that are moved to the
    device of the tensor they are combined with.
    """
    # Cast to the state dtype and clone so the in-place lerp_ below can never
    # alias the caller's stacked input (see the inductor alias-analysis note
    # in the compile-strategy comment).
    grads = stacked_grads.to(momentum_buffer.dtype).clone()

    # --- Nesterov momentum -------------------------------------------------
    mu = momentum_t.to(device=momentum_buffer.device, dtype=grads.dtype)
    momentum_buffer.lerp_(grads, 1 - mu)
    update = grads.lerp_(momentum_buffer, mu)

    # --- Polar express orthogonalization -----------------------------------
    ortho = update.bfloat16()
    ortho = ortho / (ortho.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    is_tall = update.size(-2) > update.size(-1)
    for a, b, c in polar_express_coeffs[:ns_steps]:
        if is_tall:
            # Tall: work with the smaller right-hand Gram matrix.
            gram = ortho.mT @ ortho
            poly = b * gram + c * (gram @ gram)
            ortho = a * ortho + ortho @ poly
        else:
            gram = ortho @ ortho.mT
            poly = b * gram + c * (gram @ gram)
            ortho = a * ortho + poly @ ortho
    update = ortho

    # --- NorMuon variance reduction ----------------------------------------
    # beta2 follows the state-buffer dtype, not the update dtype, so lerp_ on
    # second_momentum_buffer does not hit a dtype mismatch.
    beta2 = beta2_t.to(device=second_momentum_buffer.device, dtype=second_momentum_buffer.dtype)
    mean_sq = update.float().square().mean(dim=red_dim, keepdim=True)
    reduced_len = update.size(red_dim)
    norm_before = (mean_sq.sum(dim=(-2, -1), keepdim=True) * reduced_len).sqrt()
    second_momentum_buffer.lerp_(mean_sq.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    inv_rms = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    rescaled_sq = (mean_sq * reduced_len) * inv_rms.float().square()
    norm_after = rescaled_sq.sum(dim=(-2, -1), keepdim=True).sqrt()
    scale = inv_rms * (norm_before / norm_after.clamp_min(1e-10))
    update = update * scale.to(update.dtype)

    # --- Cautious weight decay + parameter update --------------------------
    lr = lr_t.to(device=stacked_params.device, dtype=update.dtype)
    wd = wd_t.to(device=stacked_params.device, dtype=update.dtype)
    # Decay is applied only where update and parameter share a sign.
    agree = (update * stacked_params) >= 0
    stacked_params.sub_(lr * update + lr * wd * stacked_params * agree)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class MuonAdamW(torch.optim.Optimizer):
    """Combined optimizer: Muon for 2D matrix params, AdamW for others.

    Every param group must carry a ``kind`` key:

    * ``"adamw"`` groups read ``lr``, ``betas``, ``eps`` and ``weight_decay``.
    * ``"muon"`` groups read ``lr``, ``momentum``, ``beta2``, ``weight_decay``
      and ``ns_steps``. All params of one Muon group are stacked together, so
      they must share one shape with at least two dimensions.
    """

    _KINDS = ('adamw', 'muon')

    def __init__(self, param_groups):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        for name in (
            '_adamw_step_t', '_adamw_lr_t', '_adamw_beta1_t', '_adamw_beta2_t',
            '_adamw_eps_t', '_adamw_wd_t',
            '_muon_momentum_t', '_muon_lr_t', '_muon_wd_t', '_muon_beta2_t',
        ):
            setattr(self, name, torch.tensor(0.0, dtype=torch.float32, device="cpu"))

    def _step_adamw(self, group):
        """Run one AdamW update for every param in ``group`` that has a grad."""
        params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            if 'step_t' not in state:
                # _fused_adamw_ wants a per-param float step tensor on-device.
                state['step_t'] = torch.tensor(
                    float(state['step']), dtype=torch.float32, device=p.device
                )
            state['step'] += 1
            params.append(p)
            # .to() is a no-op when the grad already has the param dtype.
            grads.append(p.grad.to(p.dtype))
            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            state_steps.append(state['step_t'])

        if not params:
            return

        if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
            # _fused_adamw_ needs uniform (device, dtype) within a call, so
            # group by (device, dtype) — same pattern as PyTorch's own
            # AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
            buckets = {}
            for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
                bucket = buckets.setdefault((p.device, p.dtype), ([], [], [], [], []))
                for column, item in zip(bucket, (p, g, ea, es, st)):
                    column.append(item)

            lr_f = float(group['lr'])
            b1_f = float(group['betas'][0])
            b2_f = float(group['betas'][1])
            wd_f = float(group['weight_decay'])
            eps_f = float(group['eps'])
            for b_p, b_g, b_ea, b_es, b_st in buckets.values():
                # The fused kernel reads the step from the tensors, so bump
                # them first.
                torch._foreach_add_(b_st, 1.0)
                torch._fused_adamw_(
                    b_p, b_g, b_ea, b_es,
                    [],  # max_exp_avg_sqs unused (amsgrad=False)
                    b_st,
                    amsgrad=False,
                    lr=lr_f, beta1=b1_f, beta2=b2_f,
                    weight_decay=wd_f, eps=eps_f,
                    maximize=False,
                    grad_scale=None, found_inf=None,
                )
            return

        # Fallback per-param path.
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])
        for p, grad, exp_avg, exp_avg_sq, step_t in zip(
            params, grads, exp_avgs, exp_avg_sqs, state_steps
        ):
            step = self.state[p]['step']
            # Keep the fused path's step tensor in sync with the int counter
            # so switching paths (e.g. after a resume) does not reuse a stale
            # bias-correction step.
            step_t.fill_(step)
            self._adamw_step_t.fill_(step)
            adamw_step_fused(p, grad, exp_avg, exp_avg_sq,
                             self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                             self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)

    def _step_muon(self, group):
        """Run one Muon update over the stacked params of ``group``.

        The shared momentum buffers are stored in the state of the first
        param that has a gradient.
        """
        params = [p for p in group['params'] if p.grad is not None]
        if not params:
            return
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        elif state["momentum_buffer"].shape[0] != num_params:
            # The buffers hold one slice per param; a different count cannot
            # be matched back to the stored momentum.
            raise RuntimeError(
                f"Muon state was built for {state['momentum_buffer'].shape[0]} params but "
                f"{num_params} have gradients this step; the set of params with "
                "gradients must not change between steps"
            )
        red_dim = -1 if shape[-2] >= shape[-1] else -2
        if "second_momentum_buffer" not in state:
            # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
            state_shape = [num_params, *shape]
            state_shape[red_dim] = 1
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        # F7 REVERT: fresh stacks each step (no persistent stacked_params_buf).
        # This was the autograd-safety fix that unblocks grad_accum>=2.
        stacked_grads = torch.stack([q.grad for q in params])
        stacked_params = torch.stack(params)
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Scale the learning rate up for tall matrices (rows > cols).
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1]) ** 0.5)
        self._muon_wd_t.fill_(group["weight_decay"])
        muon_step_fused(stacked_grads, stacked_params,
                        state["momentum_buffer"], state["second_momentum_buffer"],
                        self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
                        self._muon_beta2_t, group["ns_steps"], red_dim)
        # The update ran on a stacked copy; write the slices back to the params.
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self, closure=None):
        """Update every param group according to its ``kind``.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss, as in ``torch.optim.Optimizer.step``.

        Returns:
            The closure's return value, or ``None`` when no closure is given.

        Raises:
            ValueError: if a param group has a ``kind`` other than
                ``"adamw"`` or ``"muon"``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Validate up front so a bad group never leaves a half-applied step.
        for group in self.param_groups:
            if group['kind'] not in self._KINDS:
                raise ValueError(
                    f"unknown param group kind {group['kind']!r}; expected one of {self._KINDS}"
                )
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                self._step_adamw(group)
            else:
                self._step_muon(group)
        return loss
|
overlay/hydra/reality_bridge.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass(frozen=True)
class RealityBridgeOutput:
    """Immutable bundle of the tensors returned by ``RealityPoincareBridge.forward``."""

    # Output of the ``to_reality`` linear projection.
    reality: torch.Tensor
    # ``poincare_expmap0`` applied to the 2-d tangent projection of ``reality``.
    poincare: torch.Tensor
    # int16 positions of the top-k entries of ``|reality|`` along the last dim.
    l0_indices: torch.Tensor
    # Matching top-k magnitudes (absolute values), cast to ``reality``'s dtype.
    l0_values: torch.Tensor
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class RealityPoincareBridge(nn.Module):
    """Default-off SEM-Claw continuous→discrete bridge.

    PyTorch GEMM creates a compact 133-d reality latent, then a differentiable
    Poincare-disk projection is kept for metrics/regularizers while a detached
    int16 L0/top-k index buffer feeds Engram/Cantor sparse retrieval. This is a
    production-shaped version of rs.md's Poincare/Reality Buffer without adding
    speculative E7 machinery to the hot path.

    Args:
        d_model: size of the last dim of the input.
        d_reality: width of the reality latent. Must be at most 32768 so every
            top-k index fits in the int16 ``l0_indices`` buffer.
        d_poincare: width of the Poincare projection; only 2 is supported.
        l0_k: number of top-magnitude entries to keep; capped at ``d_reality``.

    Raises:
        ValueError: if any size is non-positive, ``d_poincare != 2``, or
            ``d_reality`` exceeds the int16 index range.
    """

    def __init__(
        self,
        d_model: int,
        d_reality: int = 133,
        d_poincare: int = 2,
        l0_k: int = 64,
    ) -> None:
        super().__init__()
        if d_model <= 0:
            raise ValueError(f"d_model must be positive, got {d_model}")
        if d_reality <= 0:
            raise ValueError(f"d_reality must be positive, got {d_reality}")
        # forward() stores top-k positions as int16; the largest position is
        # d_reality - 1, which must not wrap around.
        max_index = torch.iinfo(torch.int16).max
        if d_reality - 1 > max_index:
            raise ValueError(
                f"d_reality must be at most {max_index + 1} so top-k indices fit in int16, "
                f"got {d_reality}"
            )
        if d_poincare != 2:
            raise ValueError("Poincare bridge currently expects d_poincare=2")
        if l0_k <= 0:
            raise ValueError(f"l0_k must be positive, got {l0_k}")
        self.d_model = int(d_model)
        self.d_reality = int(d_reality)
        self.d_poincare = int(d_poincare)
        # topk cannot return more entries than the latent has.
        self.l0_k = min(int(l0_k), self.d_reality)
        # Build the layers from the int-coerced sizes stored above.
        self.to_reality = nn.Linear(self.d_model, self.d_reality, bias=False)
        self.to_tangent2 = nn.Linear(self.d_reality, self.d_poincare, bias=False)
        nn.init.normal_(self.to_reality.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.to_tangent2.weight, mean=0.0, std=0.02)

    @staticmethod
    def poincare_expmap0(tangent2: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Map tangent vectors into the unit disk as ``tanh(|t|) * t / |t|``.

        The computation runs in float32 and the result is cast back to the
        input dtype. ``eps`` floors the norm so a zero vector maps to zero
        instead of dividing by zero.
        """
        t = tangent2.float()
        r = t.norm(dim=-1, keepdim=True).clamp_min(eps)
        # NOTE(review): tanh saturates to exactly 1.0 in float32 for large r,
        # which puts the point on the disk boundary; confirm downstream
        # hyperbolic metrics tolerate that or clip before using them.
        y = torch.tanh(r) * (t / r)
        return y.to(tangent2.dtype)

    def forward(self, x: torch.Tensor) -> RealityBridgeOutput:
        """Project ``x`` and return the latent, disk point and top-k buffers.

        Raises:
            ValueError: if the last dim of ``x`` is not ``d_model``.
        """
        if x.shape[-1] != self.d_model:
            raise ValueError(f"expected last dim {self.d_model}, got {x.shape[-1]}")
        reality = self.to_reality(x)
        tangent2 = self.to_tangent2(reality)
        poincare = self.poincare_expmap0(tangent2)
        # Rank by magnitude in float32; the returned values are absolute
        # values, so the sign of each selected entry is not kept.
        vals, idx = reality.float().abs().topk(self.l0_k, dim=-1)
        return RealityBridgeOutput(
            reality=reality,
            poincare=poincare,
            l0_indices=idx.to(torch.int16),
            l0_values=vals.to(reality.dtype),
        )
|
overlay/hydra/training.py
CHANGED
|
@@ -1,948 +1,967 @@
|
|
| 1 |
-
"""HYDRA training entry: setup, train loop, eval, summary.
|
| 2 |
-
|
| 3 |
-
Extracted from the monolithic train.py (W1 modularization). Semantics
|
| 4 |
-
preserved. Public entrypoint: `main()`.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from __future__ import annotations
|
| 8 |
-
|
| 9 |
-
import gc
|
| 10 |
-
import json
|
| 11 |
-
import math
|
| 12 |
-
import os
|
| 13 |
-
import sys
|
| 14 |
-
import threading
|
| 15 |
-
import time
|
| 16 |
-
from dataclasses import asdict
|
| 17 |
-
from pathlib import Path
|
| 18 |
-
|
| 19 |
-
import torch
|
| 20 |
-
|
| 21 |
-
# Line-buffered stdout so `python -u train.py | tee run.log | grep step` is
|
| 22 |
-
# live (no \r overwrite, no 4k block-buffered pipe stalls). Safe on Python
|
| 23 |
-
# 3.7+ where io.TextIOWrapper.reconfigure exists.
|
| 24 |
-
try:
|
| 25 |
-
sys.stdout.reconfigure(line_buffering=True) # type: ignore[attr-defined]
|
| 26 |
-
except Exception:
|
| 27 |
-
pass
|
| 28 |
-
|
| 29 |
-
from hydra.config import (
|
| 30 |
-
ADAM_BETAS, CURRICULUM_SHORT_SEQ_LEN, CURRICULUM_SHORT_STEPS,
|
| 31 |
-
D_MODEL, D_STATE, DEVICE_BATCH_SIZE, EMA_DECAY, EMBEDDING_LR,
|
| 32 |
-
ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND,
|
| 33 |
-
FINAL_LR_FRAC, GPU_BF16_PEAK_FLOPS, HEADDIM, MATRIX_LR, N_HEADS,
|
| 34 |
-
N_LAYER, PostSemClawConfig, SCALAR_LR, SEED, TOTAL_BATCH_SIZE,
|
| 35 |
-
UNEMBEDDING_LR, USE_EMA, WARMUP_RATIO, WEIGHT_DECAY,
|
| 36 |
-
)
|
| 37 |
-
from hydra.diffusion_loss import mdlm_masked_forward_process, mdlm_rb_loss
|
| 38 |
-
from hydra.eval import run_factual_english, run_factual_probes
|
| 39 |
-
from hydra.model import PostSemClawModel
|
| 40 |
-
|
| 41 |
-
import prepare as _prepare_mod
|
| 42 |
-
from prepare import MAX_SEQ_LEN, TIME_BUDGET as _TIME_BUDGET, Tokenizer, evaluate_bpb as _evaluate_bpb_shards, get_token_bytes, make_dataloader as _make_dataloader_shards
|
| 43 |
-
|
| 44 |
-
# Streaming Nemotron path (Super3 recipe). Opt-in via HYDRA_USE_NEMOTRON=1.
|
| 45 |
-
if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
|
| 46 |
-
import prepare_nemotron as _p_nemo
|
| 47 |
-
make_dataloader = _p_nemo.make_dataloader
|
| 48 |
-
evaluate_bpb = _p_nemo.evaluate_bpb
|
| 49 |
-
else:
|
| 50 |
-
make_dataloader = _make_dataloader_shards
|
| 51 |
-
evaluate_bpb = _evaluate_bpb_shards
|
| 52 |
-
|
| 53 |
-
TIME_BUDGET = int(os.environ.get("HYDRA_TIME_BUDGET", str(_TIME_BUDGET)))
|
| 54 |
-
_prepare_mod.TIME_BUDGET = TIME_BUDGET # sync for evaluate_bpb
|
| 55 |
-
|
| 56 |
-
CACHE_DIR = Path.home() / ".cache" / "autoresearch"
|
| 57 |
-
LATEST_CKPT = CACHE_DIR / "latest.pt"
|
| 58 |
-
PRETRAIN_FINAL_CKPT = CACHE_DIR / "pretrain_final.pt"
|
| 59 |
-
FAILED_CKPT = CACHE_DIR / "latest_failed.pt" # crash/FAIL path — never overwrites good
|
| 60 |
-
BEST_CKPT = CACHE_DIR / "best_bpb.pt" # lowest val_bpb seen
|
| 61 |
-
CKPT_INTERVAL = int(os.environ.get("HYDRA_CKPT_INTERVAL", "250"))
|
| 62 |
-
CKPT_ROTATIONS = int(os.environ.get("HYDRA_CKPT_ROTATIONS", "3")) # how many .N backups to keep
|
| 63 |
-
RESUME_CKPT = os.environ.get("HYDRA_RESUME_CKPT", str(LATEST_CKPT))
|
| 64 |
-
|
| 65 |
-
# MDLM (Masked Diffusion LM) Rao-Blackwellized ELBO loss path.
|
| 66 |
-
# HYDRA_USE_MDLM=1 : switch training loss from AR sampled-softmax CE
|
| 67 |
-
# to MDLM RB weighted CE (arXiv:2406.07524).
|
| 68 |
-
# HYDRA_MDLM_MASK_ID=N : token id used for the MASK sentinel (default:
|
| 69 |
-
# last valid id, vocab_size - 1). Ensure this id
|
| 70 |
-
# never appears in training targets — typical
|
| 71 |
-
# practice is to reserve it.
|
| 72 |
-
# HYDRA_MDLM_SCHEDULE=loglinear|linear : noise schedule (default loglinear).
|
| 73 |
-
# When enabled, the per-step flow is:
|
| 74 |
-
# 1. mdlm_masked_forward_process(y) -> (x_noised, mask_positions, weights)
|
| 75 |
-
# 2. logits = model(x_noised) (no targets -> full V logits)
|
| 76 |
-
# 3. loss = mdlm_rb_loss(logits, y, mask_positions, weights)
|
| 77 |
-
# Sampled-softmax is bypassed in this path because the RB ELBO needs
|
| 78 |
-
# full-vocab logits on masked positions.
|
| 79 |
-
USE_MDLM = os.environ.get("HYDRA_USE_MDLM", "0") == "1"
|
| 80 |
-
MDLM_MASK_ID = int(os.environ.get("HYDRA_MDLM_MASK_ID", "-1")) # -1 => default to vocab_size-1 at runtime
|
| 81 |
-
MDLM_SCHEDULE = os.environ.get("HYDRA_MDLM_SCHEDULE", "loglinear")
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
# ---------------------------------------------------------------------------
|
| 85 |
-
# Schedules
|
| 86 |
-
# ---------------------------------------------------------------------------
|
| 87 |
-
|
| 88 |
-
def get_lr_multiplier(progress: float) -> float:
    """Return the learning-rate multiplier for a training progress fraction.

    The schedule is a linear warmup over the first ``WARMUP_RATIO`` of the
    run, then a cosine decay from 1.0 down to ``FINAL_LR_FRAC``.

    Args:
        progress: fraction of the run completed, nominally in [0, 1].

    Returns:
        The multiplier to apply to each base learning rate.
    """
    if progress < WARMUP_RATIO:
        return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
    decay_span = 1.0 - WARMUP_RATIO
    if decay_span <= 0:
        # Warmup covers the whole run, so there is no decay phase to divide
        # by; stay at the peak multiplier.
        return 1.0
    # Clamp so an overshoot past progress == 1.0 holds the final value
    # instead of climbing back up the cosine.
    decay_progress = min((progress - WARMUP_RATIO) / decay_span, 1.0)
    return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * (1 + math.cos(math.pi * decay_progress))
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def get_muon_momentum(step: int) -> float:
    """Return the Muon momentum for ``step``.

    The value ramps linearly from 0.85 at step 0 to 0.95 at step 300 and
    stays at 0.95 afterwards.
    """
    ramp = step / 300
    if ramp > 1:
        ramp = 1
    return (1 - ramp) * 0.85 + ramp * 0.95
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def get_weight_decay(progress: float) -> float:
    """Return ``WEIGHT_DECAY`` scaled linearly by the remaining ``1 - progress``."""
    remaining = 1 - progress
    return WEIGHT_DECAY * remaining
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
_CKPT_WORKER_THREAD: threading.Thread | None = None
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def _ckpt_snapshot_state_dicts(
|
| 108 |
-
model: PostSemClawModel,
|
| 109 |
-
optimizer: torch.optim.Optimizer,
|
| 110 |
-
) -> tuple[dict, dict]:
|
| 111 |
-
"""Detach + CPU-clone every tensor so a bg thread can serialize safely
|
| 112 |
-
while the main loop keeps mutating live weights/optimizer state."""
|
| 113 |
-
msd = {k: (v.detach().to("cpu", copy=True) if torch.is_tensor(v) else v)
|
| 114 |
-
for k, v in model.state_dict().items()}
|
| 115 |
-
# optimizer.state_dict() is a nested dict; walk it.
|
| 116 |
-
osd_raw = optimizer.state_dict()
|
| 117 |
-
|
| 118 |
-
def _to_cpu(obj):
|
| 119 |
-
if torch.is_tensor(obj):
|
| 120 |
-
return obj.detach().to("cpu", copy=True)
|
| 121 |
-
if isinstance(obj, dict):
|
| 122 |
-
return {k: _to_cpu(v) for k, v in obj.items()}
|
| 123 |
-
if isinstance(obj, list):
|
| 124 |
-
return [_to_cpu(v) for v in obj]
|
| 125 |
-
if isinstance(obj, tuple):
|
| 126 |
-
return tuple(_to_cpu(v) for v in obj)
|
| 127 |
-
return obj
|
| 128 |
-
|
| 129 |
-
osd = _to_cpu(osd_raw)
|
| 130 |
-
return msd, osd
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def save_ckpt(
|
| 134 |
-
model: PostSemClawModel,
|
| 135 |
-
optimizer: torch.optim.Optimizer,
|
| 136 |
-
config: PostSemClawConfig,
|
| 137 |
-
step: int,
|
| 138 |
-
total_training_time: float,
|
| 139 |
-
smooth_train_loss: float,
|
| 140 |
-
bpt_ema: float,
|
| 141 |
-
epoch: int,
|
| 142 |
-
path: Path,
|
| 143 |
-
*,
|
| 144 |
-
val_bpb: float | None = None,
|
| 145 |
-
blocking: bool = False,
|
| 146 |
-
) -> None:
|
| 147 |
-
"""Save a training checkpoint.
|
| 148 |
-
|
| 149 |
-
Default behavior is async: the GPU→CPU state_dict clone runs on the main
|
| 150 |
-
thread (unavoidable; needs to happen before the next optimizer.step that
|
| 151 |
-
mutates live weights), then `torch.save` is dispatched to a daemon
|
| 152 |
-
worker thread. The next call joins any still-running prior save so only
|
| 153 |
-
one disk write is in flight.
|
| 154 |
-
|
| 155 |
-
`blocking=True` restores the original synchronous behavior — used for
|
| 156 |
-
end-of-training saves where correctness on process exit matters.
|
| 157 |
-
"""
|
| 158 |
-
global _CKPT_WORKER_THREAD
|
| 159 |
-
try:
|
| 160 |
-
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
| 161 |
-
msd, osd = _ckpt_snapshot_state_dicts(model, optimizer)
|
| 162 |
-
# asdict() recursively converts dataclass fields to a dict and
|
| 163 |
-
# renders tuples as lists. hyena_layers therefore round-trips as a
|
| 164 |
-
# JSON-safe list; config_from_dict normalizes it back to a tuple.
|
| 165 |
-
payload = {
|
| 166 |
-
"model_state_dict": msd,
|
| 167 |
-
"optimizer_state_dict": osd,
|
| 168 |
-
"config": asdict(config),
|
| 169 |
-
"step": step,
|
| 170 |
-
"epoch": epoch,
|
| 171 |
-
"train_seconds": total_training_time,
|
| 172 |
-
"smoothed_loss": smooth_train_loss,
|
| 173 |
-
"bpt_ema": bpt_ema,
|
| 174 |
-
"val_bpb": val_bpb,
|
| 175 |
-
}
|
| 176 |
-
path_str = str(path)
|
| 177 |
-
|
| 178 |
-
def _rotate(p: str) -> None:
|
| 179 |
-
"""Keep up to CKPT_ROTATIONS previous versions as p.1, p.2, ..."""
|
| 180 |
-
if CKPT_ROTATIONS <= 0:
|
| 181 |
-
return
|
| 182 |
-
try:
|
| 183 |
-
# Walk from oldest to newest so we don't clobber newer with older.
|
| 184 |
-
for i in range(CKPT_ROTATIONS, 0, -1):
|
| 185 |
-
src = f"{p}.{i-1}" if i > 1 else p
|
| 186 |
-
dst = f"{p}.{i}"
|
| 187 |
-
if os.path.exists(src):
|
| 188 |
-
os.replace(src, dst)
|
| 189 |
-
except Exception as e:
|
| 190 |
-
# Rotation is best-effort; never block a save on it.
|
| 191 |
-
print(f"[ckpt] rotate warn {p}: {type(e).__name__}: {e}", flush=True)
|
| 192 |
-
|
| 193 |
-
def _write():
|
| 194 |
-
try:
|
| 195 |
-
_rotate(path_str)
|
| 196 |
-
tmp = path_str + ".tmp"
|
| 197 |
-
torch.save(payload, tmp)
|
| 198 |
-
os.replace(tmp, path_str)
|
| 199 |
-
print(f"[ckpt] saved {path_str} (step={step})", flush=True)
|
| 200 |
-
except Exception as e:
|
| 201 |
-
print(f"[ckpt] SAVE FAILED {path_str}: {type(e).__name__}: {e}", flush=True)
|
| 202 |
-
|
| 203 |
-
if blocking:
|
| 204 |
-
_write()
|
| 205 |
-
return
|
| 206 |
-
|
| 207 |
-
# Join previous writer so at most one torch.save runs at a time.
|
| 208 |
-
if _CKPT_WORKER_THREAD is not None and _CKPT_WORKER_THREAD.is_alive():
|
| 209 |
-
_CKPT_WORKER_THREAD.join()
|
| 210 |
-
_CKPT_WORKER_THREAD = threading.Thread(
|
| 211 |
-
target=_write, daemon=True, name=f"ckpt-save-{step}"
|
| 212 |
-
)
|
| 213 |
-
_CKPT_WORKER_THREAD.start()
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
""
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
)
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
#
|
| 327 |
-
#
|
| 328 |
-
#
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
|
| 339 |
_p_nemo.ensure_tokenizer()
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
)
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
model.
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
#
|
| 396 |
-
#
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
#
|
| 402 |
-
#
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
#
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
)
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
#
|
| 422 |
-
#
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
f"
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
print(f"
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
#
|
| 440 |
-
#
|
| 441 |
-
#
|
| 442 |
-
#
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
#
|
| 447 |
-
#
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
#
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
#
|
| 462 |
-
#
|
| 463 |
-
#
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
#
|
| 478 |
-
#
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
loss
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
group[
|
| 510 |
-
|
| 511 |
-
group["
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
#
|
| 526 |
-
#
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
f"
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
#
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
#
|
| 542 |
-
#
|
| 543 |
-
|
| 544 |
-
_last_sdr
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
)
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
#
|
| 562 |
-
#
|
| 563 |
-
#
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
_hestia_interval =
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
mdl.hestia.
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
)
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
model.hestia.
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
#
|
| 597 |
-
#
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
)
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
f"
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
#
|
| 639 |
-
#
|
| 640 |
-
#
|
| 641 |
-
#
|
| 642 |
-
#
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
#
|
| 657 |
-
#
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
#
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
f"
|
| 669 |
-
f"
|
| 670 |
-
f"
|
| 671 |
-
f"
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
gc.
|
| 678 |
-
gc.
|
| 679 |
-
|
| 680 |
-
#
|
| 681 |
-
#
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
#
|
| 698 |
-
#
|
| 699 |
-
|
| 700 |
-
mid_val_interval
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
#
|
| 705 |
-
#
|
| 706 |
-
#
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
#
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
f"
|
| 734 |
-
f"
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
htm_proj_g
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
#
|
| 761 |
-
#
|
| 762 |
-
#
|
| 763 |
-
#
|
| 764 |
-
#
|
| 765 |
-
#
|
| 766 |
-
#
|
| 767 |
-
#
|
| 768 |
-
#
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
#
|
| 783 |
-
#
|
| 784 |
-
#
|
| 785 |
-
#
|
| 786 |
-
#
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
#
|
| 797 |
-
#
|
| 798 |
-
#
|
| 799 |
-
optimizer
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
|
| 836 |
-
|
| 837 |
-
|
| 838 |
-
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
#
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
|
| 888 |
-
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
print(
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
'
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
|
| 948 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HYDRA training entry: setup, train loop, eval, summary.
|
| 2 |
+
|
| 3 |
+
Extracted from the monolithic train.py (W1 modularization). Semantics
|
| 4 |
+
preserved. Public entrypoint: `main()`.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import gc
|
| 10 |
+
import json
|
| 11 |
+
import math
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import threading
|
| 15 |
+
import time
|
| 16 |
+
from dataclasses import asdict
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
# Line-buffered stdout so `python -u train.py | tee run.log | grep step` is
|
| 22 |
+
# live (no \r overwrite, no 4k block-buffered pipe stalls). Safe on Python
|
| 23 |
+
# 3.7+ where io.TextIOWrapper.reconfigure exists.
|
| 24 |
+
try:
|
| 25 |
+
sys.stdout.reconfigure(line_buffering=True) # type: ignore[attr-defined]
|
| 26 |
+
except Exception:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
from hydra.config import (
|
| 30 |
+
ADAM_BETAS, CURRICULUM_SHORT_SEQ_LEN, CURRICULUM_SHORT_STEPS,
|
| 31 |
+
D_MODEL, D_STATE, DEVICE_BATCH_SIZE, EMA_DECAY, EMBEDDING_LR,
|
| 32 |
+
ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND,
|
| 33 |
+
FINAL_LR_FRAC, GPU_BF16_PEAK_FLOPS, HEADDIM, MATRIX_LR, N_HEADS,
|
| 34 |
+
N_LAYER, PostSemClawConfig, SCALAR_LR, SEED, TOTAL_BATCH_SIZE,
|
| 35 |
+
UNEMBEDDING_LR, USE_EMA, WARMUP_RATIO, WEIGHT_DECAY,
|
| 36 |
+
)
|
| 37 |
+
from hydra.diffusion_loss import mdlm_masked_forward_process, mdlm_rb_loss
|
| 38 |
+
from hydra.eval import run_factual_english, run_factual_probes
|
| 39 |
+
from hydra.model import PostSemClawModel
|
| 40 |
+
|
| 41 |
+
import prepare as _prepare_mod
|
| 42 |
+
from prepare import MAX_SEQ_LEN, TIME_BUDGET as _TIME_BUDGET, Tokenizer, evaluate_bpb as _evaluate_bpb_shards, get_token_bytes, make_dataloader as _make_dataloader_shards
|
| 43 |
+
|
| 44 |
+
# Streaming Nemotron path (Super3 recipe). Opt-in via HYDRA_USE_NEMOTRON=1.
# The dataloader factory and bits-per-byte evaluator are bound once here, at
# import time, from either prepare_nemotron or the shard-based prepare module.
if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
    import prepare_nemotron as _p_nemo
    make_dataloader = _p_nemo.make_dataloader
    evaluate_bpb = _p_nemo.evaluate_bpb
else:
    make_dataloader = _make_dataloader_shards
    evaluate_bpb = _evaluate_bpb_shards

# HYDRA_TIME_BUDGET overrides the default imported from prepare; it is parsed
# as an int and written back onto the prepare module.
TIME_BUDGET = int(os.environ.get("HYDRA_TIME_BUDGET", str(_TIME_BUDGET)))
_prepare_mod.TIME_BUDGET = TIME_BUDGET  # sync for evaluate_bpb

# Checkpoint locations, all under ~/.cache/autoresearch.
CACHE_DIR = Path.home() / ".cache" / "autoresearch"
LATEST_CKPT = CACHE_DIR / "latest.pt"
PRETRAIN_FINAL_CKPT = CACHE_DIR / "pretrain_final.pt"
FAILED_CKPT = CACHE_DIR / "latest_failed.pt"  # crash/FAIL path — never overwrites good
BEST_CKPT = CACHE_DIR / "best_bpb.pt"  # lowest val_bpb seen
CKPT_INTERVAL = int(os.environ.get("HYDRA_CKPT_INTERVAL", "250"))
CKPT_ROTATIONS = int(os.environ.get("HYDRA_CKPT_ROTATIONS", "3"))  # how many .N backups to keep
# Path to resume from; maybe_resume_ckpt() treats an empty value or "none"
# (case-insensitive) as "resume disabled".
RESUME_CKPT = os.environ.get("HYDRA_RESUME_CKPT", str(LATEST_CKPT))

# MDLM (Masked Diffusion LM) Rao-Blackwellized ELBO loss path.
#   HYDRA_USE_MDLM=1      : switch training loss from AR sampled-softmax CE
#                           to MDLM RB weighted CE (arXiv:2406.07524).
#   HYDRA_MDLM_MASK_ID=N  : token id used for the MASK sentinel (default:
#                           last valid id, vocab_size - 1). Ensure this id
#                           never appears in training targets — typical
#                           practice is to reserve it.
#   HYDRA_MDLM_SCHEDULE=loglinear|linear : noise schedule (default loglinear).
# When enabled, the per-step flow is:
#   1. mdlm_masked_forward_process(y) -> (x_noised, mask_positions, weights)
#   2. logits = model(x_noised)       (no targets -> full V logits)
#   3. loss = mdlm_rb_loss(logits, y, mask_positions, weights)
# Sampled-softmax is bypassed in this path because the RB ELBO needs
# full-vocab logits on masked positions.
USE_MDLM = os.environ.get("HYDRA_USE_MDLM", "0") == "1"
MDLM_MASK_ID = int(os.environ.get("HYDRA_MDLM_MASK_ID", "-1"))  # -1 => default to vocab_size-1 at runtime
MDLM_SCHEDULE = os.environ.get("HYDRA_MDLM_SCHEDULE", "loglinear")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# Schedules
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
def get_lr_multiplier(progress: float) -> float:
    """Return the learning-rate scale factor for a given training progress.

    The schedule is a linear warmup from 0 to 1 over the first
    ``WARMUP_RATIO`` of progress, followed by a cosine decay from 1 down to
    ``FINAL_LR_FRAC``.

    Args:
        progress: Fraction of the training budget consumed, nominally in
            [0, 1]. Values above 1 are treated as 1 during the decay phase.

    Returns:
        Multiplier to apply to a base learning rate.
    """
    if progress < WARMUP_RATIO:
        return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
    decay_span = 1.0 - WARMUP_RATIO
    if decay_span <= 0:
        # Warmup covers the whole run, so there is no decay phase; dividing
        # by the empty span would raise ZeroDivisionError.
        return 1.0
    # Clamp so an overshoot past the budget holds the final LR instead of
    # riding the cosine back up.
    decay_progress = min((progress - WARMUP_RATIO) / decay_span, 1.0)
    return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * (1 + math.cos(math.pi * decay_progress))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_muon_momentum(step: int) -> float:
    """Return the Muon momentum for optimizer step ``step``.

    Interpolates linearly from 0.85 at step 0 to 0.95 at step 300 and stays
    at 0.95 for every later step.
    """
    ramp = min(step / 300, 1)  # fraction of the 300-step ramp completed
    start_weight = 1 - ramp
    return start_weight * 0.85 + ramp * 0.95
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_weight_decay(progress: float) -> float:
    """Return the weight decay for the given training progress.

    Scales ``WEIGHT_DECAY`` linearly by the remaining fraction of training,
    so it equals ``WEIGHT_DECAY`` at progress 0 and 0 at progress 1.
    """
    remaining = 1 - progress
    return WEIGHT_DECAY * remaining
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Most recently started background checkpoint writer thread, or None before
# the first asynchronous save. save_ckpt() rebinds this on every async save.
_CKPT_WORKER_THREAD: threading.Thread | None = None
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _ckpt_snapshot_state_dicts(
|
| 108 |
+
model: PostSemClawModel,
|
| 109 |
+
optimizer: torch.optim.Optimizer,
|
| 110 |
+
) -> tuple[dict, dict]:
|
| 111 |
+
"""Detach + CPU-clone every tensor so a bg thread can serialize safely
|
| 112 |
+
while the main loop keeps mutating live weights/optimizer state."""
|
| 113 |
+
msd = {k: (v.detach().to("cpu", copy=True) if torch.is_tensor(v) else v)
|
| 114 |
+
for k, v in model.state_dict().items()}
|
| 115 |
+
# optimizer.state_dict() is a nested dict; walk it.
|
| 116 |
+
osd_raw = optimizer.state_dict()
|
| 117 |
+
|
| 118 |
+
def _to_cpu(obj):
|
| 119 |
+
if torch.is_tensor(obj):
|
| 120 |
+
return obj.detach().to("cpu", copy=True)
|
| 121 |
+
if isinstance(obj, dict):
|
| 122 |
+
return {k: _to_cpu(v) for k, v in obj.items()}
|
| 123 |
+
if isinstance(obj, list):
|
| 124 |
+
return [_to_cpu(v) for v in obj]
|
| 125 |
+
if isinstance(obj, tuple):
|
| 126 |
+
return tuple(_to_cpu(v) for v in obj)
|
| 127 |
+
return obj
|
| 128 |
+
|
| 129 |
+
osd = _to_cpu(osd_raw)
|
| 130 |
+
return msd, osd
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def save_ckpt(
    model: PostSemClawModel,
    optimizer: torch.optim.Optimizer,
    config: PostSemClawConfig,
    step: int,
    total_training_time: float,
    smooth_train_loss: float,
    bpt_ema: float,
    epoch: int,
    path: Path,
    *,
    val_bpb: float | None = None,
    blocking: bool = False,
) -> None:
    """Save a training checkpoint to ``path``.

    The GPU→CPU snapshot of model and optimizer state always runs on the
    calling thread, so it is taken before the next optimizer step can mutate
    live weights. What happens next depends on ``blocking``:

    * ``blocking=False`` (default): ``torch.save`` runs on a daemon worker
      thread. If ``path`` lives outside ``CACHE_DIR`` the call still waits
      for that worker, so custom destinations are durable on return.
    * ``blocking=True``: the write runs on the calling thread. Used for
      end-of-training saves where correctness on process exit matters.

    In both modes any still-running writer from an earlier call is joined
    first, so at most one disk write (and one rotation) is in flight. The
    payload is written to ``<path>.tmp``; only once that succeeds are older
    checkpoints rotated and the temp file renamed over ``path``, so a failed
    write leaves the previous checkpoint where it was.

    Failures are reported on stdout and never raised to the caller.

    Args:
        model: Model whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        config: Model config; stored via ``dataclasses.asdict``.
        step: Optimizer step count stored in the payload.
        total_training_time: Stored under ``"train_seconds"``.
        smooth_train_loss: Stored under ``"smoothed_loss"``.
        bpt_ema: Stored under ``"bpt_ema"``.
        epoch: Stored under ``"epoch"``.
        path: Destination file.
        val_bpb: Optional validation bits-per-byte stored in the payload.
        blocking: Write synchronously instead of on a worker thread.
    """
    global _CKPT_WORKER_THREAD
    try:
        dest = Path(path)
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        # A custom destination may point into a directory that does not
        # exist yet; create it so the temp-file write cannot fail on that.
        dest.parent.mkdir(parents=True, exist_ok=True)
        msd, osd = _ckpt_snapshot_state_dicts(model, optimizer)
        # asdict() recursively converts dataclass fields to a dict and
        # renders tuples as lists. hyena_layers therefore round-trips as a
        # JSON-safe list; config_from_dict normalizes it back to a tuple.
        payload = {
            "model_state_dict": msd,
            "optimizer_state_dict": osd,
            "config": asdict(config),
            "step": step,
            "epoch": epoch,
            "train_seconds": total_training_time,
            "smoothed_loss": smooth_train_loss,
            "bpt_ema": bpt_ema,
            "val_bpb": val_bpb,
        }
        path_str = str(path)

        def _rotate(p: str) -> None:
            """Keep up to CKPT_ROTATIONS previous versions as p.1, p.2, ..."""
            if CKPT_ROTATIONS <= 0:
                return
            try:
                # Walk from oldest to newest so we don't clobber newer with older.
                for i in range(CKPT_ROTATIONS, 0, -1):
                    src = f"{p}.{i-1}" if i > 1 else p
                    dst = f"{p}.{i}"
                    if os.path.exists(src):
                        os.replace(src, dst)
            except Exception as e:
                # Rotation is best-effort; never block a save on it.
                print(f"[ckpt] rotate warn {p}: {type(e).__name__}: {e}", flush=True)

        def _write() -> None:
            tmp = path_str + ".tmp"
            try:
                # Serialize first: if this raises, nothing has been rotated
                # and the existing checkpoint at path_str is untouched.
                torch.save(payload, tmp)
                _rotate(path_str)
                os.replace(tmp, path_str)
                print(f"[ckpt] saved {path_str} (step={step})", flush=True)
            except Exception as e:
                print(f"[ckpt] SAVE FAILED {path_str}: {type(e).__name__}: {e}", flush=True)

        # Join the previous writer before ANY new write, blocking or not.
        # Otherwise a blocking save can race an in-flight async save on the
        # same rotation chain and the same ".tmp" file.
        if _CKPT_WORKER_THREAD is not None and _CKPT_WORKER_THREAD.is_alive():
            _CKPT_WORKER_THREAD.join()

        if blocking:
            _write()
            return

        _CKPT_WORKER_THREAD = threading.Thread(
            target=_write, daemon=True, name=f"ckpt-save-{step}"
        )
        _CKPT_WORKER_THREAD.start()
        # Non-default checkpoint paths are usually tests or one-off utilities that
        # expect save_ckpt() to be durable when it returns. Keep the hot training
        # path async for CACHE_DIR checkpoints, but make explicit custom paths
        # deterministic.
        if dest.parent.resolve() != CACHE_DIR.resolve():
            _CKPT_WORKER_THREAD.join()
    except Exception as e:
        print(f"[ckpt] SNAPSHOT FAILED {path}: {type(e).__name__}: {e}", flush=True)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def config_from_dict(cfg_dict: dict) -> PostSemClawConfig:
    """Rebuild a PostSemClawConfig from the ``asdict()`` dict in a checkpoint.

    Keys that are not declared fields of ``PostSemClawConfig`` are dropped,
    so extra keys from older or newer checkpoints cannot break construction.
    Fields absent from ``cfg_dict`` are simply not passed to the constructor.
    ``hyena_layers``, which ``asdict`` stores as a list, is converted back to
    a sorted tuple of ints.

    This is the inverse of the ``asdict(config)`` call made by save_ckpt.
    """
    declared = {f.name for f in PostSemClawConfig.__dataclass_fields__.values()}

    kwargs = {}
    for key, value in cfg_dict.items():
        if key not in declared:
            continue
        kwargs[key] = value

    # Missing and explicit-None hyena_layers are both left alone.
    layers = kwargs.get("hyena_layers")
    if layers is not None:
        kwargs["hyena_layers"] = tuple(sorted(int(x) for x in layers))

    return PostSemClawConfig(**kwargs)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _try_load_ckpt(path: Path, model, optimizer, device):
|
| 247 |
+
"""Attempt to load a single ckpt. Returns the tuple on success, None on any failure."""
|
| 248 |
+
if not path.exists():
|
| 249 |
+
return None
|
| 250 |
+
ckpt = torch.load(str(path), map_location=device, weights_only=False)
|
| 251 |
+
state = ckpt.get("model_state_dict", ckpt)
|
| 252 |
+
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 253 |
+
if missing:
|
| 254 |
+
print(f"[ckpt] {path.name} missing={len(missing)}", flush=True)
|
| 255 |
+
if unexpected:
|
| 256 |
+
print(f"[ckpt] {path.name} unexpected={len(unexpected)}", flush=True)
|
| 257 |
+
optimizer_state = ckpt.get("optimizer_state_dict")
|
| 258 |
+
if optimizer_state is not None:
|
| 259 |
+
try:
|
| 260 |
+
optimizer.load_state_dict(optimizer_state)
|
| 261 |
+
except Exception as e:
|
| 262 |
+
print(f"[ckpt] optimizer restore failed from {path.name}: {type(e).__name__}: {e}", flush=True)
|
| 263 |
+
step = int(ckpt.get("step", 0))
|
| 264 |
+
total_training_time = float(ckpt.get("train_seconds", 0.0))
|
| 265 |
+
smooth_train_loss = float(ckpt.get("smoothed_loss", 0.0))
|
| 266 |
+
bpt_ema = float(ckpt.get("bpt_ema", 0.0))
|
| 267 |
+
epoch = int(ckpt.get("epoch", 0))
|
| 268 |
+
print(
|
| 269 |
+
f"[ckpt] resumed {path} step={step} train_seconds={total_training_time:.1f}",
|
| 270 |
+
flush=True,
|
| 271 |
+
)
|
| 272 |
+
# Warn if resuming a schedule-exhausted ckpt — user is probably warm-starting.
|
| 273 |
+
budget = float(os.environ.get("HYDRA_TIME_BUDGET", "0") or 0)
|
| 274 |
+
if budget and total_training_time >= 0.99 * budget:
|
| 275 |
+
print(
|
| 276 |
+
f"[ckpt] WARNING: resumed ckpt used {total_training_time:.0f}s of {budget:.0f}s "
|
| 277 |
+
f"budget. LR schedule is essentially exhausted. "
|
| 278 |
+
f"Set HYDRA_WARMSTART=1 to reset optimizer + scheduler and keep only weights.",
|
| 279 |
+
flush=True,
|
| 280 |
+
)
|
| 281 |
+
return step, total_training_time, smooth_train_loss, bpt_ema, epoch
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def maybe_resume_ckpt(
    model: PostSemClawModel,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> tuple[int, float, float, float, int]:
    """Resume from ``RESUME_CKPT`` or one of its rotated backups.

    Tries the primary path first, then ``<path>.1`` through
    ``<path>.CKPT_ROTATIONS``. Candidates that do not exist or fail to load
    are skipped, so a partial or killed write on the primary path falls back
    to an older rotation automatically.

    Returns:
        ``(step, train_seconds, smoothed_loss, bpt_ema, epoch)`` from the
        first checkpoint that loads, or all zeros when resume is disabled or
        no candidate is usable.
    """
    fresh_start = (0, 0.0, 0.0, 0.0, 0)

    if not RESUME_CKPT or RESUME_CKPT.lower() == "none":
        print("[ckpt] resume disabled; starting fresh", flush=True)
        return fresh_start

    resume_path = Path(os.path.expanduser(RESUME_CKPT))
    rotations = [Path(str(resume_path) + f".{i}") for i in range(1, CKPT_ROTATIONS + 1)]

    for cand in [resume_path, *rotations]:
        if not cand.exists():
            continue
        try:
            result = _try_load_ckpt(cand, model, optimizer, device)
            if result is None:
                continue
            if cand != resume_path:
                print(f"[ckpt] fell back to rotation {cand.name}", flush=True)
            return result
        except Exception as e:
            print(f"[ckpt] {cand.name} load failed: {type(e).__name__}: {e}", flush=True)
            continue

    print(f"[ckpt] no usable checkpoint in {resume_path} + rotations; starting fresh", flush=True)
    return fresh_start
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# ---------------------------------------------------------------------------
|
| 319 |
+
# Main entry
|
| 320 |
+
# ---------------------------------------------------------------------------
|
| 321 |
+
|
| 322 |
+
def main() -> None:
|
| 323 |
+
t_start = time.time()
|
| 324 |
+
torch.manual_seed(SEED)
|
| 325 |
+
torch.cuda.manual_seed(SEED)
|
| 326 |
+
# Precision / kernel-selection knobs for peak throughput on Ampere.
|
| 327 |
+
# - high : matmul uses TF32 (Ampere's 10-bit mantissa accum) for fp32 ops
|
| 328 |
+
# - allow_tf32 : explicit for both matmul + cudnn paths
|
| 329 |
+
# - cudnn.benchmark : env-gated (HYDRA_CUDNN_BENCHMARK, default OFF).
|
| 330 |
+
# TRUE can lock in a locally-better-but-globally-slower algorithm
|
| 331 |
+
# after the autotune phase ends, causing tps to degrade 15-20%
|
| 332 |
+
# over the first ~100 steps. Observed 2026-04-22 and confirmed by
|
| 333 |
+
# differential profiling. Default is now FALSE; set =1 only if you
|
| 334 |
+
# see a specific workload where benchmark helps sustained tps.
|
| 335 |
+
torch.set_float32_matmul_precision("high")
|
| 336 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 337 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 338 |
+
torch.backends.cudnn.benchmark = os.environ.get("HYDRA_CUDNN_BENCHMARK", "0") == "1"
|
| 339 |
+
device = torch.device("cuda")
|
| 340 |
+
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 341 |
+
|
| 342 |
+
# Streaming path skips prepare.py (which normally trains the tokenizer
|
| 343 |
+
# and builds the retina), so we must materialize both before model init.
|
| 344 |
if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
|
| 345 |
_p_nemo.ensure_tokenizer()
|
| 346 |
+
# Retina: HF Hub cache hit for this (vocab, n_bits, target_active) combo
|
| 347 |
+
# returns in seconds; otherwise build_retina streams Nemotron docs to
|
| 348 |
+
# compute cooccurrence + train SOM, then uploads back to the cache.
|
| 349 |
+
import subsystems.sdr_retina as _sdr_retina
|
| 350 |
+
_sdr_retina.build_retina()
|
| 351 |
+
tokenizer = Tokenizer.from_directory()
|
| 352 |
+
vocab_size = tokenizer.get_vocab_size()
|
| 353 |
+
print(f"Vocab size: {vocab_size:,}")
|
| 354 |
+
|
| 355 |
+
config = PostSemClawConfig(
|
| 356 |
+
sequence_len=MAX_SEQ_LEN,
|
| 357 |
+
vocab_size=vocab_size,
|
| 358 |
+
n_layer=N_LAYER,
|
| 359 |
+
d_model=D_MODEL,
|
| 360 |
+
d_state=D_STATE,
|
| 361 |
+
headdim=HEADDIM,
|
| 362 |
+
n_heads=N_HEADS,
|
| 363 |
+
expand=EXPAND,
|
| 364 |
+
engram_n_columns=ENGRAM_N_COLUMNS,
|
| 365 |
+
engram_key_dim=ENGRAM_KEY_DIM,
|
| 366 |
+
engram_layer_idx=ENGRAM_LAYER_IDX,
|
| 367 |
+
)
|
| 368 |
+
print(f"Model config: {asdict(config)}")
|
| 369 |
+
|
| 370 |
+
with torch.device("meta"):
|
| 371 |
+
model = PostSemClawModel(config)
|
| 372 |
+
model.to_empty(device=device)
|
| 373 |
+
model.init_weights()
|
| 374 |
+
|
| 375 |
+
param_counts = model.num_scaling_params()
|
| 376 |
+
print("Parameter counts:")
|
| 377 |
+
for key, value in param_counts.items():
|
| 378 |
+
print(f" {key:24s}: {value:,}")
|
| 379 |
+
num_params = param_counts['total']
|
| 380 |
+
num_flops_per_token = model.estimate_flops()
|
| 381 |
+
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
| 382 |
+
|
| 383 |
+
tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
|
| 384 |
+
assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
|
| 385 |
+
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
|
| 386 |
+
|
| 387 |
+
optimizer = model.setup_optimizer(
|
| 388 |
+
unembedding_lr=UNEMBEDDING_LR,
|
| 389 |
+
embedding_lr=EMBEDDING_LR,
|
| 390 |
+
scalar_lr=SCALAR_LR,
|
| 391 |
+
adam_betas=ADAM_BETAS,
|
| 392 |
+
matrix_lr=MATRIX_LR,
|
| 393 |
+
weight_decay=WEIGHT_DECAY,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
step, total_training_time, smooth_train_loss, bpt_ema, resume_epoch = maybe_resume_ckpt(
|
| 397 |
+
model, optimizer, device,
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
# Learnability #4: inform the model of the BOS token id so it can mask
|
| 401 |
+
# doc-separator positions in packed sequences. Always set (the mask only
|
| 402 |
+
# fires when HYDRA_DOC_SEP_MASK=1 is also on).
|
| 403 |
+
if hasattr(model, 'set_bos_token_id'):
|
| 404 |
+
model.set_bos_token_id(tokenizer.get_bos_token_id())
|
| 405 |
+
|
| 406 |
+
# Learnability #2: EMA shadow copy of weights. AveragedModel clones every
|
| 407 |
+
# parameter; we update it after every optimizer step and save it at the
|
| 408 |
+
# end alongside the raw checkpoint. Defaults OFF.
|
| 409 |
+
ema_model = None
|
| 410 |
+
if USE_EMA:
|
| 411 |
+
try:
|
| 412 |
+
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
|
| 413 |
+
# decay=EMA_DECAY; avg_fn uses get_ema_multi_avg_fn for numerical
|
| 414 |
+
# stability across bf16/fp32 mixed parameter groups.
|
| 415 |
+
ema_model = AveragedModel(
|
| 416 |
+
model,
|
| 417 |
+
multi_avg_fn=get_ema_multi_avg_fn(EMA_DECAY),
|
| 418 |
+
)
|
| 419 |
+
print(f"[EMA] enabled with decay={EMA_DECAY}")
|
| 420 |
+
except Exception as _e:
|
| 421 |
+
print(f"[EMA] disabled — AveragedModel init failed: {_e}")
|
| 422 |
+
ema_model = None
|
| 423 |
+
|
| 424 |
+
print("torch.compile: Muon step compiled; AdamW uses torch._fused_adamw_ (model blocks use native CUDA kernels)")
|
| 425 |
+
|
| 426 |
+
# Learnability #7: curriculum short-then-long. If enabled, build the
|
| 427 |
+
# initial dataloader at the short seq_len; we swap to full MAX_SEQ_LEN
|
| 428 |
+
# after CURRICULUM_SHORT_STEPS optimizer steps (see loop below).
|
| 429 |
+
_curriculum_active = CURRICULUM_SHORT_STEPS > 0 and CURRICULUM_SHORT_SEQ_LEN < MAX_SEQ_LEN
|
| 430 |
+
_current_seq_len = CURRICULUM_SHORT_SEQ_LEN if _curriculum_active else MAX_SEQ_LEN
|
| 431 |
+
if _curriculum_active:
|
| 432 |
+
print(
|
| 433 |
+
f"[CURRICULUM] starting at T={_current_seq_len} for "
|
| 434 |
+
f"{CURRICULUM_SHORT_STEPS} steps, then switching to T={MAX_SEQ_LEN}"
|
| 435 |
+
)
|
| 436 |
+
train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, _current_seq_len, "train")
|
| 437 |
+
x, y, epoch = next(train_loader) # prefetch first batch
|
| 438 |
+
if resume_epoch > 0:
|
| 439 |
+
epoch = max(epoch, resume_epoch)
|
| 440 |
+
|
| 441 |
+
print(f"Time budget: {TIME_BUDGET}s")
|
| 442 |
+
print(f"Gradient accumulation steps: {grad_accum_steps}")
|
| 443 |
+
|
| 444 |
+
# Token→byte LUT for bits-per-byte computation. evaluate_bpb in prepare.py
|
| 445 |
+
# uses total_nats / (ln(2) * total_bytes); our live metric needs to match.
|
| 446 |
+
# Without this, `bpb = loss/ln(2)` is actually bits-per-TOKEN, which at
|
| 447 |
+
# vocab=8192 scales by ~4 and makes live train bpb non-comparable with
|
| 448 |
+
# val_bpb (champion 1.279 bpb vs train printing "8.04").
|
| 449 |
+
token_bytes = get_token_bytes(device=device)
|
| 450 |
+
|
| 451 |
+
# -----------------------------------------------------------------------
|
| 452 |
+
# Training loop
|
| 453 |
+
# -----------------------------------------------------------------------
|
| 454 |
+
|
| 455 |
+
t_start_training = time.time()
|
| 456 |
+
|
| 457 |
+
# Async postprocessing — run SOM + Hestia on background threads so
|
| 458 |
+
# the GPU doesn't idle during their CPU-bound work.
|
| 459 |
+
_ASYNC_POSTPROCESS = os.environ.get("HYDRA_ASYNC_POSTPROCESS", "1") == "1"
|
| 460 |
+
_som_thread: threading.Thread | None = None
|
| 461 |
+
_hestia_thread: threading.Thread | None = None
|
| 462 |
+
_hestia_stream: torch.cuda.Stream | None = (
|
| 463 |
+
torch.cuda.Stream() if _ASYNC_POSTPROCESS else None
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
# HYDRA_PROFILE_STEPS=N prints a per-phase cpu/gpu time breakdown for the
|
| 467 |
+
# first N steps (or, when N<0, for every 100th step instead). Zero overhead
|
| 468 |
+
# when disabled. Used to find what's eating CPU budget when GPU should
|
| 469 |
+
# be the bottleneck.
|
| 470 |
+
_profile_steps = int(os.environ.get("HYDRA_PROFILE_STEPS", "0"))
|
| 471 |
+
|
| 472 |
+
while True:
|
| 473 |
+
torch.cuda.synchronize()
|
| 474 |
+
t0 = time.time()
|
| 475 |
+
_prof = _profile_steps and (step < _profile_steps or (_profile_steps < 0 and step % 100 == 0))
|
| 476 |
+
_gpu_ms = 0.0
|
| 477 |
+
_data_ms = 0.0
|
| 478 |
+
for micro_step in range(grad_accum_steps):
|
| 479 |
+
if _prof:
|
| 480 |
+
torch.cuda.synchronize(); _t_micro = time.time()
|
| 481 |
+
if USE_MDLM:
|
| 482 |
+
# MDLM path: corrupt y -> x_noised, run model to get full-V logits,
|
| 483 |
+
# compute RB weighted CE on masked positions. x (original input) is
|
| 484 |
+
# unused in this path — the model only sees the noised version of y.
|
| 485 |
+
_mask_id = MDLM_MASK_ID if MDLM_MASK_ID >= 0 else (vocab_size - 1)
|
| 486 |
+
x_noised, mask_positions, loss_weights = mdlm_masked_forward_process(
|
| 487 |
+
y, mask_token_id=_mask_id, alpha_schedule=MDLM_SCHEDULE,
|
| 488 |
+
)
|
| 489 |
+
with autocast_ctx:
|
| 490 |
+
logits = model(x_noised) # targets=None -> (B, T, V) logits
|
| 491 |
+
loss = mdlm_rb_loss(logits, y, mask_positions, loss_weights)
|
| 492 |
+
else:
|
| 493 |
+
with autocast_ctx:
|
| 494 |
+
loss = model(x, y)
|
| 495 |
+
train_loss = loss.detach()
|
| 496 |
+
loss = loss / grad_accum_steps
|
| 497 |
+
loss.backward()
|
| 498 |
+
if _prof:
|
| 499 |
+
torch.cuda.synchronize()
|
| 500 |
+
_gpu_ms += (time.time() - _t_micro) * 1000
|
| 501 |
+
_t_data = time.time()
|
| 502 |
+
x, y, epoch = next(train_loader)
|
| 503 |
+
if _prof:
|
| 504 |
+
_data_ms += (time.time() - _t_data) * 1000
|
| 505 |
+
if _prof:
|
| 506 |
+
torch.cuda.synchronize(); _t_fb = time.time()
|
| 507 |
+
|
| 508 |
+
# Progress and schedules
|
| 509 |
+
progress = min(total_training_time / TIME_BUDGET, 1.0)
|
| 510 |
+
lrm = get_lr_multiplier(progress)
|
| 511 |
+
muon_momentum = get_muon_momentum(step)
|
| 512 |
+
muon_weight_decay = get_weight_decay(progress)
|
| 513 |
+
for group in optimizer.param_groups:
|
| 514 |
+
group["lr"] = group["initial_lr"] * lrm
|
| 515 |
+
if group['kind'] == 'muon':
|
| 516 |
+
group["momentum"] = muon_momentum
|
| 517 |
+
group["weight_decay"] = muon_weight_decay
|
| 518 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 519 |
+
optimizer.step()
|
| 520 |
+
if _prof:
|
| 521 |
+
torch.cuda.synchronize(); _t_opt = time.time()
|
| 522 |
+
|
| 523 |
+
# Learnability #2: EMA update after every optimizer step.
|
| 524 |
+
if ema_model is not None:
|
| 525 |
+
try:
|
| 526 |
+
ema_model.update_parameters(model)
|
| 527 |
+
except Exception as _e:
|
| 528 |
+
print(f"[EMA] update failed at step {step}: {_e}", flush=True)
|
| 529 |
+
|
| 530 |
+
# Learnability #7: curriculum transition. After
|
| 531 |
+
# CURRICULUM_SHORT_STEPS optimizer steps, rebuild the dataloader at
|
| 532 |
+
# MAX_SEQ_LEN. Done once, then the flag flips off.
|
| 533 |
+
if _curriculum_active and step + 1 >= CURRICULUM_SHORT_STEPS:
|
| 534 |
+
print(
|
| 535 |
+
f"[CURRICULUM] step={step+1} — switching from T={_current_seq_len} "
|
| 536 |
+
f"to T={MAX_SEQ_LEN}",
|
| 537 |
+
flush=True,
|
| 538 |
+
)
|
| 539 |
+
_current_seq_len = MAX_SEQ_LEN
|
| 540 |
+
_curriculum_active = False
|
| 541 |
+
train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, _current_seq_len, "train")
|
| 542 |
+
# Prefetch the next batch at the new seq_len so the following
|
| 543 |
+
# loop iteration consumes fresh data.
|
| 544 |
+
x, y, epoch = next(train_loader)
|
| 545 |
+
|
| 546 |
+
# Online SOM update — retina is now a plain Python attribute (not a
|
| 547 |
+
# registered buffer) so mutations do not invalidate torch.compile guards.
|
| 548 |
+
# Runs fully on CPU; safe to overlap with GPU forward pass.
|
| 549 |
+
_last_sdr = getattr(model, "_last_sdr", None)
|
| 550 |
+
if _last_sdr is not None:
|
| 551 |
+
if _ASYNC_POSTPROCESS:
|
| 552 |
+
if _som_thread is not None:
|
| 553 |
+
_som_thread.join()
|
| 554 |
+
# Clone tensors before next step overwrites them
|
| 555 |
+
_som_x = x.clone()
|
| 556 |
+
_som_sdr = _last_sdr.clone()
|
| 557 |
+
_som_thread = threading.Thread(
|
| 558 |
+
target=model.sdr_semantic.maybe_som_update,
|
| 559 |
+
args=(_som_x, _som_sdr),
|
| 560 |
+
daemon=True,
|
| 561 |
+
)
|
| 562 |
+
_som_thread.start()
|
| 563 |
+
else:
|
| 564 |
+
model.sdr_semantic.maybe_som_update(x, _last_sdr)
|
| 565 |
+
|
| 566 |
+
# Hestia QAT — anneal temperature every step, snap every N steps.
|
| 567 |
+
# apply_to walks all Linear modules (CPU) then does .data.copy_ (GPU).
|
| 568 |
+
# Background thread + separate CUDA stream lets this overlap with
|
| 569 |
+
# the next forward pass on the default stream.
|
| 570 |
+
_hestia_progress = (time.time() - t_start_training) / max(TIME_BUDGET, 1)
|
| 571 |
+
_hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
|
| 572 |
+
if step % _hestia_interval == 0:
|
| 573 |
+
if _ASYNC_POSTPROCESS:
|
| 574 |
+
if _hestia_thread is not None:
|
| 575 |
+
_hestia_thread.join()
|
| 576 |
+
|
| 577 |
+
def _hestia_bg(mdl: torch.nn.Module, prog: float) -> None:
|
| 578 |
+
assert _hestia_stream is not None
|
| 579 |
+
with torch.cuda.stream(_hestia_stream):
|
| 580 |
+
mdl.hestia.anneal_temperature(prog)
|
| 581 |
+
mdl.hestia.apply_to(mdl)
|
| 582 |
+
|
| 583 |
+
_hestia_thread = threading.Thread(
|
| 584 |
+
target=_hestia_bg,
|
| 585 |
+
args=(model, _hestia_progress),
|
| 586 |
+
daemon=True,
|
| 587 |
+
)
|
| 588 |
+
_hestia_thread.start()
|
| 589 |
+
else:
|
| 590 |
+
model.hestia.anneal_temperature(_hestia_progress)
|
| 591 |
+
model.hestia.apply_to(model)
|
| 592 |
+
else:
|
| 593 |
+
# anneal_temperature is cheap (~1 us), keep inline
|
| 594 |
+
model.hestia.anneal_temperature(_hestia_progress)
|
| 595 |
+
|
| 596 |
+
model.zero_grad(set_to_none=True)
|
| 597 |
+
|
| 598 |
+
train_loss_f = train_loss.item()
|
| 599 |
+
if math.isnan(train_loss_f) or train_loss_f > 100:
|
| 600 |
+
print("FAIL")
|
| 601 |
+
# Save to a DIFFERENT file — never clobber a good latest.pt with
|
| 602 |
+
# a NaN/diverged state. The good ckpt from the last periodic save
|
| 603 |
+
# is the right place to resume from.
|
| 604 |
+
save_ckpt(
|
| 605 |
+
model,
|
| 606 |
+
optimizer,
|
| 607 |
+
config,
|
| 608 |
+
step,
|
| 609 |
+
total_training_time,
|
| 610 |
+
smooth_train_loss,
|
| 611 |
+
bpt_ema,
|
| 612 |
+
epoch,
|
| 613 |
+
FAILED_CKPT,
|
| 614 |
+
blocking=True,
|
| 615 |
+
)
|
| 616 |
+
raise SystemExit(1)
|
| 617 |
+
|
| 618 |
+
torch.cuda.synchronize()
|
| 619 |
+
t1 = time.time()
|
| 620 |
+
dt = t1 - t0
|
| 621 |
+
|
| 622 |
+
if _prof:
|
| 623 |
+
fb = (_t_fb - t0) * 1000
|
| 624 |
+
opt = (_t_opt - _t_fb) * 1000
|
| 625 |
+
rest = (t1 - _t_opt) * 1000
|
| 626 |
+
print(
|
| 627 |
+
f"[PROF step={step:05d}] gpu={_gpu_ms:.0f}ms data_fetch={_data_ms:.0f}ms "
|
| 628 |
+
f"(sum_fb={fb:.0f}) opt={opt:.0f}ms rest={rest:.0f}ms total={dt*1000:.0f}ms",
|
| 629 |
+
flush=True,
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
if step > 10:
|
| 633 |
+
total_training_time += dt
|
| 634 |
+
|
| 635 |
+
ema_beta = 0.9
|
| 636 |
+
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
|
| 637 |
+
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** (step + 1))
|
| 638 |
+
pct_done = 100 * progress
|
| 639 |
+
tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
|
| 640 |
+
mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / GPU_BF16_PEAK_FLOPS
|
| 641 |
+
remaining = max(0, TIME_BUDGET - total_training_time)
|
| 642 |
+
|
| 643 |
+
# Bytes-per-token for the batch now held in `y` (the prefetched one). evaluate_bpb in prepare.py
|
| 644 |
+
# computes bits-per-BYTE (total_nats / (ln2 * total_bytes)); to match
|
| 645 |
+
# that semantics live, we EMA-smooth the per-batch bytes/token and
|
| 646 |
+
# divide. Without this, the old `bpb = loss/ln2` was actually
|
| 647 |
+
# bits-per-token — ~4× larger than val_bpb at vocab=8192 and
|
| 648 |
+
# therefore not comparable to the champion 1.279 bpb metric.
|
| 649 |
+
with torch.no_grad():
|
| 650 |
+
y_flat = y.view(-1)
|
| 651 |
+
nbytes_batch = token_bytes[y_flat]
|
| 652 |
+
mask = nbytes_batch > 0
|
| 653 |
+
mask_count = mask.sum().clamp(min=1).float()
|
| 654 |
+
avg_bytes_per_tok = (nbytes_batch.float() * mask.float()).sum() / mask_count
|
| 655 |
+
bpt_batch = float(avg_bytes_per_tok.item())
|
| 656 |
+
if step == 0 or bpt_ema <= 0.0:
|
| 657 |
+
bpt_ema = bpt_batch
|
| 658 |
+
else:
|
| 659 |
+
bpt_ema = 0.98 * bpt_ema + 0.02 * bpt_batch
|
| 660 |
+
|
| 661 |
+
# Dual metric: bpb (byte-normalized, comparable with val_bpb) AND
|
| 662 |
+
# bpt (bits per token, the raw loss in bits). bpt_div exposes the
|
| 663 |
+
# current avg bytes-per-token so the conversion is transparent.
|
| 664 |
+
bpt = debiased_smooth_loss / math.log(2)
|
| 665 |
+
bpb = bpt / max(bpt_ema, 1e-6)
|
| 666 |
+
vram_mib = torch.cuda.memory_allocated() / 1024 / 1024
|
| 667 |
+
current_lr = optimizer.param_groups[0]["lr"]
|
| 668 |
+
|
| 669 |
+
# Per-step line-buffered log. NOT \r-overwritten so tee/grep see it.
|
| 670 |
+
# Keep key=value pairs grep-friendly.
|
| 671 |
+
ppl = 2.0 ** bpb # perplexity (byte-level)
|
| 672 |
+
print(
|
| 673 |
+
f"step={step:05d} loss={debiased_smooth_loss:.4f} bpb={bpb:.4f} ppl={ppl:.3f} "
|
| 674 |
+
f"bpt={bpt:.3f} bpt_div={bpt_ema:.2f} "
|
| 675 |
+
f"tps={tok_per_sec} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
|
| 676 |
+
f"lr={current_lr:.2e} vram={vram_mib:.0f}MiB "
|
| 677 |
+
f"pct={pct_done:.1f} epoch={epoch} remaining={remaining:.0f}s",
|
| 678 |
+
flush=True,
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
if step == 0:
|
| 682 |
+
gc.collect()
|
| 683 |
+
gc.freeze()
|
| 684 |
+
gc.disable()
|
| 685 |
+
# No periodic gc.collect() — we disabled+froze at step 0 on purpose,
|
| 686 |
+
# so a manual collect every 5k steps just re-scans frozen objects
|
| 687 |
+
# (burned ~900 ms/event in production) for no live-garbage reason.
|
| 688 |
+
|
| 689 |
+
if CKPT_INTERVAL > 0 and step > 0 and step % CKPT_INTERVAL == 0:
|
| 690 |
+
save_ckpt(
|
| 691 |
+
model,
|
| 692 |
+
optimizer,
|
| 693 |
+
config,
|
| 694 |
+
step,
|
| 695 |
+
total_training_time,
|
| 696 |
+
smooth_train_loss,
|
| 697 |
+
bpt_ema,
|
| 698 |
+
epoch,
|
| 699 |
+
LATEST_CKPT,
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
# Periodic mid-training validation so we can see the model learning
|
| 703 |
+
# English in real time (not just at the end). Small val batch so it
|
| 704 |
+
# doesn't eat significant training time.
|
| 705 |
+
mid_val_interval = int(os.environ.get("HYDRA_MID_VAL_INTERVAL", "500"))
|
| 706 |
+
if mid_val_interval > 0 and step > 0 and step % mid_val_interval == 0:
|
| 707 |
+
model.eval()
|
| 708 |
+
try:
|
| 709 |
+
# Defrag GPU memory before eval allocates fresh chunks —
|
| 710 |
+
# without this the eval path can OOM on 6GB cards even
|
| 711 |
+
# though total usage fits, because the allocator's free
|
| 712 |
+
# blocks are fragmented.
|
| 713 |
+
torch.cuda.empty_cache()
|
| 714 |
+
_orig_mid = _prepare_mod.EVAL_TOKENS
|
| 715 |
+
_prepare_mod.EVAL_TOKENS = 262144 # ~260K tokens, fast
|
| 716 |
+
with torch.no_grad():
|
| 717 |
+
with autocast_ctx:
|
| 718 |
+
mid_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
|
| 719 |
+
_prepare_mod.EVAL_TOKENS = _orig_mid
|
| 720 |
+
mid_ppl = 2.0 ** mid_bpb
|
| 721 |
+
print(f"[MID_VAL] step={step} val_bpb={mid_bpb:.4f} val_ppl={mid_ppl:.3f}", flush=True)
|
| 722 |
+
|
| 723 |
+
# Per-layer diagnostic panel. Only printed when HYDRA_LAYER_DIAGNOSTICS=1
|
| 724 |
+
# is set (otherwise the layer_* keys are absent from _metrics).
|
| 725 |
+
_diag_metrics = model.get_secondary_metrics()
|
| 726 |
+
_layer_keys = sorted([k for k in _diag_metrics.keys() if k.startswith('layer_')])
|
| 727 |
+
if _layer_keys:
|
| 728 |
+
# Condense: one row per layer showing the four core signals.
|
| 729 |
+
n_layers = len(model.blocks)
|
| 730 |
+
print(f"[LAYER_DIAG] step={step}", flush=True)
|
| 731 |
+
for li in range(n_layers):
|
| 732 |
+
d_ratio = _diag_metrics.get(f'layer_{li}_delta_ratio', float('nan'))
|
| 733 |
+
out_n = _diag_metrics.get(f'layer_{li}_out_norm', float('nan'))
|
| 734 |
+
g_norm = _diag_metrics.get(f'layer_{li}_grad_norm', float('nan'))
|
| 735 |
+
eff_r = _diag_metrics.get(f'layer_{li}_eff_rank', float('nan'))
|
| 736 |
+
f_std = _diag_metrics.get(f'layer_{li}_feat_std', float('nan'))
|
| 737 |
+
print(
|
| 738 |
+
f"[LAYER_DIAG] L{li:02d} delta_ratio={d_ratio:.4f} "
|
| 739 |
+
f"out_norm={out_n:.4f} grad_norm={g_norm:.3e} "
|
| 740 |
+
f"eff_rank={eff_r:.1f} feat_std={f_std:.4f}",
|
| 741 |
+
flush=True,
|
| 742 |
+
)
|
| 743 |
+
htm_proj_g = _diag_metrics.get('htm_proj_grad_norm', None)
|
| 744 |
+
if htm_proj_g is not None:
|
| 745 |
+
print(f"[LAYER_DIAG] htm_proj grad_norm={htm_proj_g:.3e}", flush=True)
|
| 746 |
+
except Exception as e:
|
| 747 |
+
print(f"[MID_VAL] failed: {e}", flush=True)
|
| 748 |
+
model.train()
|
| 749 |
+
|
| 750 |
+
step += 1
|
| 751 |
+
|
| 752 |
+
if step > 10 and total_training_time >= TIME_BUDGET:
|
| 753 |
+
break
|
| 754 |
+
|
| 755 |
+
# Drain async postprocessing threads before eval
|
| 756 |
+
if _som_thread is not None:
|
| 757 |
+
_som_thread.join()
|
| 758 |
+
if _hestia_thread is not None:
|
| 759 |
+
_hestia_thread.join()
|
| 760 |
+
if _hestia_stream is not None:
|
| 761 |
+
_hestia_stream.synchronize()
|
| 762 |
+
|
| 763 |
+
total_tokens = step * TOTAL_BATCH_SIZE
|
| 764 |
+
|
| 765 |
+
# ----------------------------------------------------------------------
|
| 766 |
+
# SAVE ORDER (critical):
|
| 767 |
+
# 1. Save PRETRAIN_FINAL_CKPT with val_bpb=None (hedge against eval OOM)
|
| 768 |
+
# 2. Save LATEST_CKPT with val_bpb=None (hedge against eval OOM)
|
| 769 |
+
# 3. Run eval (may OOM on small GPUs; we survive it)
|
| 770 |
+
# 4. Re-save both ckpts with val_bpb filled in
|
| 771 |
+
# This way we NEVER lose the final trained weights to an eval crash.
|
| 772 |
+
# Previous ordering put eval first, so an eval-time OOM destroyed the
|
| 773 |
+
# only record of a 6h training run (2026-04-22 incident).
|
| 774 |
+
# ----------------------------------------------------------------------
|
| 775 |
+
|
| 776 |
+
save_ckpt(
|
| 777 |
+
model, optimizer, config, step, total_training_time,
|
| 778 |
+
smooth_train_loss, bpt_ema, epoch, PRETRAIN_FINAL_CKPT,
|
| 779 |
+
val_bpb=None, blocking=True,
|
| 780 |
+
)
|
| 781 |
+
save_ckpt(
|
| 782 |
+
model, optimizer, config, step, total_training_time,
|
| 783 |
+
smooth_train_loss, bpt_ema, epoch, LATEST_CKPT,
|
| 784 |
+
val_bpb=None, blocking=True,
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
# Now it's safe to eval — ckpts are on disk regardless of what happens here.
|
| 788 |
+
# HYDRA_EVAL_BATCH overrides DEVICE_BATCH_SIZE (env-tunable; default halves
|
| 789 |
+
# the training batch because eval holds activations for full sequence and
|
| 790 |
+
# does not benefit from overlap with backward). HYDRA_EVAL_TOKENS controls
|
| 791 |
+
# how many val tokens to sweep (default 1 M, short enough for autoresearch
|
| 792 |
+
# 5-min budgets).
|
| 793 |
+
val_bpb: float | None = None
|
| 794 |
+
# Eval batch: default to 4 on cloud GPUs (enough freed VRAM after optimizer
|
| 795 |
+
# clear), fall back to DEVICE_BATCH_SIZE//2 on tiny cards. Env-overridable.
|
| 796 |
+
_eval_B = int(os.environ.get("HYDRA_EVAL_BATCH",
|
| 797 |
+
str(max(1, DEVICE_BATCH_SIZE // 2) if DEVICE_BATCH_SIZE <= 8 else 4)))
|
| 798 |
+
# Eval tokens: default 1M (1,048,576) — gives statistically meaningful BPB
|
| 799 |
+
# (256 forward passes at B=4, seq=1024). Env-overridable for fast/slow sweeps.
|
| 800 |
+
_eval_tokens = int(os.environ.get("HYDRA_EVAL_TOKENS", str(1048576)))
|
| 801 |
+
try:
|
| 802 |
+
# Aggressive VRAM reclaim for 6GB cards. Peak training VRAM = 5.1GB
|
| 803 |
+
# which leaves < 1GB for the eval forward — the driver can't satisfy
|
| 804 |
+
# the allocation. Free EVERY tensor we don't strictly need:
|
| 805 |
+
# - optimizer grads (set_to_none releases tensor)
|
| 806 |
+
# - optimizer.state (fp32 Muon NS workspace, AdamW moments — ~size-of-params each)
|
| 807 |
+
# - model internal caches (HTM subsample cache, SDR stash)
|
| 808 |
+
# After this, VRAM should be ~params only (bf16 ≈ 120MB at 60M params).
|
| 809 |
+
optimizer.zero_grad(set_to_none=True)
|
| 810 |
+
if hasattr(optimizer, 'state') and optimizer.state:
|
| 811 |
+
for p, st in list(optimizer.state.items()):
|
| 812 |
+
st.clear()
|
| 813 |
+
optimizer.state.clear()
|
| 814 |
+
for p in model.parameters():
|
| 815 |
+
if p.grad is not None:
|
| 816 |
+
p.grad = None
|
| 817 |
+
if hasattr(model, '_htm_cache'):
|
| 818 |
+
model._htm_cache = None
|
| 819 |
+
if hasattr(model, '_last_sdr'):
|
| 820 |
+
model._last_sdr = None
|
| 821 |
+
import gc as _gc
|
| 822 |
+
_gc.collect()
|
| 823 |
+
torch.cuda.empty_cache()
|
| 824 |
+
torch.cuda.synchronize()
|
| 825 |
+
try:
|
| 826 |
+
_free_mb = torch.cuda.mem_get_info()[0] / 1024 / 1024
|
| 827 |
+
print(f"[VAL] free_vram_mb={_free_mb:.0f} (cleared optimizer state)", flush=True)
|
| 828 |
+
except Exception:
|
| 829 |
+
pass
|
| 830 |
+
print(f"[VAL] running eval on {_eval_tokens} tokens at B={_eval_B}...", flush=True)
|
| 831 |
+
model.eval()
|
| 832 |
+
_orig = _prepare_mod.EVAL_TOKENS
|
| 833 |
+
_prepare_mod.EVAL_TOKENS = _eval_tokens
|
| 834 |
+
# Nemotron path reads HYDRA_STREAM_EVAL_TOKENS env var directly,
|
| 835 |
+
# not _prepare_mod.EVAL_TOKENS. Sync both so eval budget is
|
| 836 |
+
# respected regardless of which dataloader path is active.
|
| 837 |
+
_orig_stream = os.environ.get("HYDRA_STREAM_EVAL_TOKENS")
|
| 838 |
+
os.environ["HYDRA_STREAM_EVAL_TOKENS"] = str(_eval_tokens)
|
| 839 |
+
with autocast_ctx:
|
| 840 |
+
val_bpb = evaluate_bpb(model, tokenizer, _eval_B)
|
| 841 |
+
_prepare_mod.EVAL_TOKENS = _orig
|
| 842 |
+
if _orig_stream is not None:
|
| 843 |
+
os.environ["HYDRA_STREAM_EVAL_TOKENS"] = _orig_stream
|
| 844 |
+
else:
|
| 845 |
+
os.environ.pop("HYDRA_STREAM_EVAL_TOKENS", None)
|
| 846 |
+
val_ppl = 2 ** val_bpb
|
| 847 |
+
print(f"[VAL] step={step} val_bpb={val_bpb:.4f} val_ppl={val_ppl:.3f}", flush=True)
|
| 848 |
+
except torch.cuda.OutOfMemoryError as e:
|
| 849 |
+
print(f"[VAL] SKIPPED (OOM): {e}", flush=True)
|
| 850 |
+
torch.cuda.empty_cache()
|
| 851 |
+
except Exception as e:
|
| 852 |
+
import traceback as _tb
|
| 853 |
+
print(f"[VAL] SKIPPED ({type(e).__name__}): {e}", flush=True)
|
| 854 |
+
_tb.print_exc()
|
| 855 |
+
try:
|
| 856 |
+
_free = torch.cuda.mem_get_info()[0] / 1024 / 1024
|
| 857 |
+
print(f"[VAL] post-crash free_vram_mb={_free:.0f}", flush=True)
|
| 858 |
+
except Exception:
|
| 859 |
+
pass
|
| 860 |
+
|
| 861 |
+
# Final ckpts with val_bpb filled in (if eval succeeded).
|
| 862 |
+
save_ckpt(
|
| 863 |
+
model, optimizer, config, step, total_training_time,
|
| 864 |
+
smooth_train_loss, bpt_ema, epoch, LATEST_CKPT,
|
| 865 |
+
val_bpb=val_bpb, blocking=True,
|
| 866 |
+
)
|
| 867 |
+
save_ckpt(
|
| 868 |
+
model, optimizer, config, step, total_training_time,
|
| 869 |
+
smooth_train_loss, bpt_ema, epoch, PRETRAIN_FINAL_CKPT,
|
| 870 |
+
val_bpb=val_bpb, blocking=True,
|
| 871 |
+
)
|
| 872 |
+
|
| 873 |
+
# Learnability #2: persist EMA weights alongside the raw checkpoint.
|
| 874 |
+
# latest_ema.pt contains ema_model.module (the Averaged params) so it
|
| 875 |
+
# can be loaded by evaluation / inference code that expects the same
|
| 876 |
+
# state_dict shape as the raw model.
|
| 877 |
+
if ema_model is not None:
|
| 878 |
+
try:
|
| 879 |
+
ema_ckpt_path = CACHE_DIR / "latest_ema.pt"
|
| 880 |
+
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
| 881 |
+
torch.save({
|
| 882 |
+
"model_state_dict": ema_model.module.state_dict(),
|
| 883 |
+
"config": asdict(config),
|
| 884 |
+
"step": step,
|
| 885 |
+
"epoch": epoch,
|
| 886 |
+
"train_seconds": total_training_time,
|
| 887 |
+
"val_bpb": val_bpb,
|
| 888 |
+
"ema_decay": EMA_DECAY,
|
| 889 |
+
}, str(ema_ckpt_path))
|
| 890 |
+
print(f"[EMA] saved {ema_ckpt_path} (step={step})", flush=True)
|
| 891 |
+
except Exception as _e:
|
| 892 |
+
print(f"[EMA] save failed: {_e}", flush=True)
|
| 893 |
+
|
| 894 |
+
run_factual_probes(model, tokenizer, device, autocast_ctx)
|
| 895 |
+
|
| 896 |
+
t_end = time.time()
|
| 897 |
+
startup_time = t_start_training - t_start
|
| 898 |
+
steady_state_mfu = (
|
| 899 |
+
100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10)
|
| 900 |
+
/ total_training_time / GPU_BF16_PEAK_FLOPS
|
| 901 |
+
if total_training_time > 0 else 0
|
| 902 |
+
)
|
| 903 |
+
peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
|
| 904 |
+
metrics = model.get_secondary_metrics()
|
| 905 |
+
|
| 906 |
+
print("---")
|
| 907 |
+
print(f"val_bpb: {val_bpb:.6f}" if val_bpb is not None else "val_bpb: SKIPPED")
|
| 908 |
+
print(f"training_seconds: {total_training_time:.1f}")
|
| 909 |
+
print(f"total_seconds: {t_end - t_start:.1f}")
|
| 910 |
+
print(f"peak_vram_mb: {peak_vram_mb:.1f}")
|
| 911 |
+
print(f"mfu_percent: {steady_state_mfu:.2f}")
|
| 912 |
+
print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
|
| 913 |
+
print(f"num_steps: {step}")
|
| 914 |
+
print(f"num_params_M: {num_params / 1e6:.1f}")
|
| 915 |
+
print(f"n_layer: {N_LAYER}")
|
| 916 |
+
print(f"d_model: {D_MODEL}")
|
| 917 |
+
print(f"engram_hit_rate: {metrics.get('engram_hit_rate', 0.0):.4f}")
|
| 918 |
+
print(f"sdr_active_bits: {metrics.get('sdr_active_bits', 0):.1f}")
|
| 919 |
+
print(f"htm_anomaly: {metrics.get('htm_anomaly', 0):.4f}")
|
| 920 |
+
|
| 921 |
+
# Per-layer summary panel — only printed when diagnostics were active.
|
| 922 |
+
_layer_keys = sorted([k for k in metrics.keys() if k.startswith('layer_')])
|
| 923 |
+
if _layer_keys:
|
| 924 |
+
n_layers = len(model.blocks)
|
| 925 |
+
print("--- per-layer diagnostic panel ---")
|
| 926 |
+
for li in range(n_layers):
|
| 927 |
+
d_ratio = metrics.get(f'layer_{li}_delta_ratio', float('nan'))
|
| 928 |
+
out_n = metrics.get(f'layer_{li}_out_norm', float('nan'))
|
| 929 |
+
g_norm = metrics.get(f'layer_{li}_grad_norm', float('nan'))
|
| 930 |
+
eff_r = metrics.get(f'layer_{li}_eff_rank', float('nan'))
|
| 931 |
+
f_std = metrics.get(f'layer_{li}_feat_std', float('nan'))
|
| 932 |
+
print(
|
| 933 |
+
f"L{li:02d} delta_ratio={d_ratio:.4f} out_norm={out_n:.4f} "
|
| 934 |
+
f"grad_norm={g_norm:.3e} eff_rank={eff_r:.1f} feat_std={f_std:.4f}"
|
| 935 |
+
)
|
| 936 |
+
|
| 937 |
+
# Emit full metrics dictionary as JSON for sweep aggregation. Path from
|
| 938 |
+
# HYDRA_METRICS_OUT env var; default=/tmp/hydra_run_metrics.json. Always
|
| 939 |
+
# written (even without diagnostics) so the aggregator can compare runs.
|
| 940 |
+
_metrics_out = os.environ.get("HYDRA_METRICS_OUT", "/tmp/hydra_run_metrics.json")
|
| 941 |
+
try:
|
| 942 |
+
_dump = dict(metrics)
|
| 943 |
+
_dump.update({
|
| 944 |
+
'val_bpb': (float(val_bpb) if val_bpb is not None else None),
|
| 945 |
+
'val_ppl': (float(val_ppl) if val_ppl is not None else None),
|
| 946 |
+
'n_layer': int(N_LAYER),
|
| 947 |
+
'd_model': int(D_MODEL),
|
| 948 |
+
'num_params_M': float(num_params / 1e6),
|
| 949 |
+
'num_steps': int(step),
|
| 950 |
+
'total_tokens_M': float(total_tokens / 1e6),
|
| 951 |
+
'peak_vram_mb': float(peak_vram_mb),
|
| 952 |
+
'training_seconds': float(total_training_time),
|
| 953 |
+
'sdr_target_active': int(os.environ.get("HYDRA_SDR_TARGET_ACTIVE", "327")),
|
| 954 |
+
})
|
| 955 |
+
Path(_metrics_out).parent.mkdir(parents=True, exist_ok=True)
|
| 956 |
+
with open(_metrics_out, 'w') as _f:
|
| 957 |
+
json.dump(_dump, _f, indent=2, sort_keys=True)
|
| 958 |
+
print(f"[METRICS] wrote {_metrics_out}", flush=True)
|
| 959 |
+
# Also emit a single-line JSON to stdout so the sweep aggregator can
|
| 960 |
+
# scrape it from HF Jobs logs without pulling files out of the container.
|
| 961 |
+
print("[METRICS_JSON] " + json.dumps(_dump, sort_keys=True), flush=True)
|
| 962 |
+
except Exception as _e:
|
| 963 |
+
print(f"[METRICS] write failed: {_e}", flush=True)
|
| 964 |
+
|
| 965 |
+
run_factual_english(model, tokenizer, MAX_SEQ_LEN)
|
| 966 |
+
# startup_time is informative but not printed (preserve historical output)
|
| 967 |
+
_ = startup_time
|
overlay/kernels/cuda/decode_kernels.cu
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
-
/*
|
| 2 |
-
* CuTe DSL decode kernels for Mamba-3 autoregressive generation.
|
| 3 |
-
*
|
| 4 |
-
* Phase 2: Optimized single-token SSM step for inference.
|
| 5 |
-
* Phase 1: Not needed (training only, no generation).
|
| 6 |
-
*
|
| 7 |
-
* Fuses: input_proj + conv_step + ssm_step + output_proj
|
| 8 |
-
* into a single kernel launch for minimal latency.
|
| 9 |
-
*/
|
| 10 |
-
// Stub: Phase 2 implementation
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* CuTe DSL decode kernels for Mamba-3 autoregressive generation.
|
| 3 |
+
*
|
| 4 |
+
* Phase 2: Optimized single-token SSM step for inference.
|
| 5 |
+
* Phase 1: Not needed (training only, no generation).
|
| 6 |
+
*
|
| 7 |
+
* Fuses: input_proj + conv_step + ssm_step + output_proj
|
| 8 |
+
* into a single kernel launch for minimal latency.
|
| 9 |
+
*/
|
| 10 |
+
// Stub: Phase 2 implementation
|
overlay/kernels/cuda/flashfftconv/LICENSE
CHANGED
|
@@ -1,201 +1,201 @@
|
|
| 1 |
-
Apache License
|
| 2 |
-
Version 2.0, January 2004
|
| 3 |
-
http://www.apache.org/licenses/
|
| 4 |
-
|
| 5 |
-
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
-
|
| 7 |
-
1. Definitions.
|
| 8 |
-
|
| 9 |
-
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
-
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
-
|
| 12 |
-
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
-
the copyright owner that is granting the License.
|
| 14 |
-
|
| 15 |
-
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
-
other entities that control, are controlled by, or are under common
|
| 17 |
-
control with that entity. For the purposes of this definition,
|
| 18 |
-
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
-
direction or management of such entity, whether by contract or
|
| 20 |
-
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
-
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
-
|
| 23 |
-
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
-
exercising permissions granted by this License.
|
| 25 |
-
|
| 26 |
-
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
-
including but not limited to software source code, documentation
|
| 28 |
-
source, and configuration files.
|
| 29 |
-
|
| 30 |
-
"Object" form shall mean any form resulting from mechanical
|
| 31 |
-
transformation or translation of a Source form, including but
|
| 32 |
-
not limited to compiled object code, generated documentation,
|
| 33 |
-
and conversions to other media types.
|
| 34 |
-
|
| 35 |
-
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
-
Object form, made available under the License, as indicated by a
|
| 37 |
-
copyright notice that is included in or attached to the work
|
| 38 |
-
(an example is provided in the Appendix below).
|
| 39 |
-
|
| 40 |
-
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
-
form, that is based on (or derived from) the Work and for which the
|
| 42 |
-
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
-
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
-
of this License, Derivative Works shall not include works that remain
|
| 45 |
-
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
-
the Work and Derivative Works thereof.
|
| 47 |
-
|
| 48 |
-
"Contribution" shall mean any work of authorship, including
|
| 49 |
-
the original version of the Work and any modifications or additions
|
| 50 |
-
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
-
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
-
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
-
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
-
means any form of electronic, verbal, or written communication sent
|
| 55 |
-
to the Licensor or its representatives, including but not limited to
|
| 56 |
-
communication on electronic mailing lists, source code control systems,
|
| 57 |
-
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
-
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
-
excluding communication that is conspicuously marked or otherwise
|
| 60 |
-
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
-
|
| 62 |
-
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
-
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
-
subsequently incorporated within the Work.
|
| 65 |
-
|
| 66 |
-
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
-
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
-
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
-
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
-
Work and such Derivative Works in Source or Object form.
|
| 72 |
-
|
| 73 |
-
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
-
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
-
(except as stated in this section) patent license to make, have made,
|
| 77 |
-
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
-
where such license applies only to those patent claims licensable
|
| 79 |
-
by such Contributor that are necessarily infringed by their
|
| 80 |
-
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
-
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
-
institute patent litigation against any entity (including a
|
| 83 |
-
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
-
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
-
or contributory patent infringement, then any patent licenses
|
| 86 |
-
granted to You under this License for that Work shall terminate
|
| 87 |
-
as of the date such litigation is filed.
|
| 88 |
-
|
| 89 |
-
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
-
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
-
modifications, and in Source or Object form, provided that You
|
| 92 |
-
meet the following conditions:
|
| 93 |
-
|
| 94 |
-
(a) You must give any other recipients of the Work or
|
| 95 |
-
Derivative Works a copy of this License; and
|
| 96 |
-
|
| 97 |
-
(b) You must cause any modified files to carry prominent notices
|
| 98 |
-
stating that You changed the files; and
|
| 99 |
-
|
| 100 |
-
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
-
that You distribute, all copyright, patent, trademark, and
|
| 102 |
-
attribution notices from the Source form of the Work,
|
| 103 |
-
excluding those notices that do not pertain to any part of
|
| 104 |
-
the Derivative Works; and
|
| 105 |
-
|
| 106 |
-
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
-
distribution, then any Derivative Works that You distribute must
|
| 108 |
-
include a readable copy of the attribution notices contained
|
| 109 |
-
within such NOTICE file, excluding those notices that do not
|
| 110 |
-
pertain to any part of the Derivative Works, in at least one
|
| 111 |
-
of the following places: within a NOTICE text file distributed
|
| 112 |
-
as part of the Derivative Works; within the Source form or
|
| 113 |
-
documentation, if provided along with the Derivative Works; or,
|
| 114 |
-
within a display generated by the Derivative Works, if and
|
| 115 |
-
wherever such third-party notices normally appear. The contents
|
| 116 |
-
of the NOTICE file are for informational purposes only and
|
| 117 |
-
do not modify the License. You may add Your own attribution
|
| 118 |
-
notices within Derivative Works that You distribute, alongside
|
| 119 |
-
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
-
that such additional attribution notices cannot be construed
|
| 121 |
-
as modifying the License.
|
| 122 |
-
|
| 123 |
-
You may add Your own copyright statement to Your modifications and
|
| 124 |
-
may provide additional or different license terms and conditions
|
| 125 |
-
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
-
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
-
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
-
the conditions stated in this License.
|
| 129 |
-
|
| 130 |
-
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
-
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
-
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
-
this License, without any additional terms or conditions.
|
| 134 |
-
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
-
the terms of any separate license agreement you may have executed
|
| 136 |
-
with Licensor regarding such Contributions.
|
| 137 |
-
|
| 138 |
-
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
-
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
-
except as required for reasonable and customary use in describing the
|
| 141 |
-
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
-
|
| 143 |
-
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
-
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
-
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
-
implied, including, without limitation, any warranties or conditions
|
| 148 |
-
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
-
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
-
appropriateness of using or redistributing the Work and assume any
|
| 151 |
-
risks associated with Your exercise of permissions under this License.
|
| 152 |
-
|
| 153 |
-
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
-
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
-
unless required by applicable law (such as deliberate and grossly
|
| 156 |
-
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
-
liable to You for damages, including any direct, indirect, special,
|
| 158 |
-
incidental, or consequential damages of any character arising as a
|
| 159 |
-
result of this License or out of the use or inability to use the
|
| 160 |
-
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
-
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
-
other commercial damages or losses), even if such Contributor
|
| 163 |
-
has been advised of the possibility of such damages.
|
| 164 |
-
|
| 165 |
-
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
-
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
-
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
-
or other liability obligations and/or rights consistent with this
|
| 169 |
-
License. However, in accepting such obligations, You may act only
|
| 170 |
-
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
-
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
-
defend, and hold each Contributor harmless for any liability
|
| 173 |
-
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
-
of your accepting any such warranty or additional liability.
|
| 175 |
-
|
| 176 |
-
END OF TERMS AND CONDITIONS
|
| 177 |
-
|
| 178 |
-
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
-
|
| 180 |
-
To apply the Apache License to your work, attach the following
|
| 181 |
-
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
-
replaced with your own identifying information. (Don't include
|
| 183 |
-
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
-
comment syntax for the file format. We also recommend that a
|
| 185 |
-
file or class name and description of purpose be included on the
|
| 186 |
-
same "printed page" as the copyright notice for easier
|
| 187 |
-
identification within third-party archives.
|
| 188 |
-
|
| 189 |
-
Copyright [yyyy] [name of copyright owner]
|
| 190 |
-
|
| 191 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
-
you may not use this file except in compliance with the License.
|
| 193 |
-
You may obtain a copy of the License at
|
| 194 |
-
|
| 195 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
-
|
| 197 |
-
Unless required by applicable law or agreed to in writing, software
|
| 198 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
-
See the License for the specific language governing permissions and
|
| 201 |
-
limitations under the License.
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
overlay/kernels/cuda/flashfftconv/README.md
CHANGED
|
@@ -1,57 +1,57 @@
|
|
| 1 |
-
# flashfftconv (vendored)
|
| 2 |
-
|
| 3 |
-
Vendored from https://github.com/HazyResearch/flash-fft-conv (Apache 2.0 license).
|
| 4 |
-
|
| 5 |
-
**Upstream commit:** see `UPSTREAM_COMMIT`.
|
| 6 |
-
|
| 7 |
-
## What this is
|
| 8 |
-
|
| 9 |
-
HazyResearch's Monarch-matrix-decomposition FFT convolution CUDA kernel. Provides a
|
| 10 |
-
drop-in replacement for `torch.fft.rfft + complex-mult + irfft` that runs ~2-3x
|
| 11 |
-
faster than cuFFT for the specific power-of-two lengths it supports (256, 512,
|
| 12 |
-
1024, 2048, 4096, 8192, ..., up to 4M).
|
| 13 |
-
|
| 14 |
-
In HYDRA, we use it to accelerate `subsystems/hyena_pure.fftconv_ref`. The
|
| 15 |
-
accelerated path is opt-in via `HYDRA_HYENA_FLASH_FFT=1`; default behavior is
|
| 16 |
-
unchanged (pure PyTorch fallback).
|
| 17 |
-
|
| 18 |
-
## How to build
|
| 19 |
-
|
| 20 |
-
The vendored tree contains:
|
| 21 |
-
- `flashfftconv/` — pure-Python wrappers (imports `monarch_cuda` CUDA extension)
|
| 22 |
-
- `csrc/` — CUDA source files and setup.py for the native extension
|
| 23 |
-
|
| 24 |
-
Build instructions:
|
| 25 |
-
|
| 26 |
-
```bash
|
| 27 |
-
cd /home/mikeb/work/feather/kernels/cuda/flashfftconv/csrc
|
| 28 |
-
|
| 29 |
-
# Edit `csrc/setup.py` first: change the cc_flag line to match your GPU arch
|
| 30 |
-
# (RTX 3060 = 8.6, A100 = 8.0, H100 = 9.0). Example for RTX 3060:
|
| 31 |
-
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86']
|
| 32 |
-
|
| 33 |
-
# Build with the local CUDA toolchain (must match your torch.version.cuda):
|
| 34 |
-
CUDA_HOME=/usr/local/cuda-12.1 .venv/bin/pip install -e .
|
| 35 |
-
```
|
| 36 |
-
|
| 37 |
-
Then install the Python wrappers:
|
| 38 |
-
|
| 39 |
-
```bash
|
| 40 |
-
cd /home/mikeb/work/feather/kernels/cuda/flashfftconv
|
| 41 |
-
.venv/bin/pip install -e .
|
| 42 |
-
```
|
| 43 |
-
|
| 44 |
-
## Runtime usage
|
| 45 |
-
|
| 46 |
-
Once installed, set `HYDRA_HYENA_FLASH_FFT=1` and training will use it.
|
| 47 |
-
`subsystems/hyena_pure.fftconv_ref` auto-detects via `try: import flashfftconv`
|
| 48 |
-
and falls back to pure PyTorch on import failure.
|
| 49 |
-
|
| 50 |
-
## Known caveats
|
| 51 |
-
|
| 52 |
-
- Seqlen must be a power of 2 AND in the supported set: {256, 512, 1024, 2048,
|
| 53 |
-
4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304}.
|
| 54 |
-
For HYDRA, `fft_size = 2 * seq_len` → seq_len in {128, 256, 512, 1024, 2048, ...}.
|
| 55 |
-
- dtype must be fp16 or bf16 (fp32 not supported).
|
| 56 |
-
- GPU arch must be compiled into the extension (see setup.py cc_flag).
|
| 57 |
-
- CUDA toolchain major.minor should match `torch.version.cuda` major (12.x ↔ 12.x).
|
|
|
|
| 1 |
+
# flashfftconv (vendored)
|
| 2 |
+
|
| 3 |
+
Vendored from https://github.com/HazyResearch/flash-fft-conv (Apache 2.0 license).
|
| 4 |
+
|
| 5 |
+
**Upstream commit:** see `UPSTREAM_COMMIT`.
|
| 6 |
+
|
| 7 |
+
## What this is
|
| 8 |
+
|
| 9 |
+
HazyResearch's Monarch-matrix-decomposition FFT convolution CUDA kernel. Provides a
|
| 10 |
+
drop-in replacement for `torch.fft.rfft + complex-mult + irfft` that runs ~2-3x
|
| 11 |
+
faster than cuFFT for the specific power-of-two lengths it supports (256, 512,
|
| 12 |
+
1024, 2048, 4096, 8192, ..., up to 4M).
|
| 13 |
+
|
| 14 |
+
In HYDRA, we use it to accelerate `subsystems/hyena_pure.fftconv_ref`. The
|
| 15 |
+
accelerated path is opt-in via `HYDRA_HYENA_FLASH_FFT=1`; default behavior is
|
| 16 |
+
unchanged (pure PyTorch fallback).
|
| 17 |
+
|
| 18 |
+
## How to build
|
| 19 |
+
|
| 20 |
+
The vendored tree contains:
|
| 21 |
+
- `flashfftconv/` — pure-Python wrappers (imports `monarch_cuda` CUDA extension)
|
| 22 |
+
- `csrc/` — CUDA source files and setup.py for the native extension
|
| 23 |
+
|
| 24 |
+
Build instructions:
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
cd /home/mikeb/work/feather/kernels/cuda/flashfftconv/csrc
|
| 28 |
+
|
| 29 |
+
# Edit `csrc/setup.py` first: change the cc_flag line to match your GPU arch
|
| 30 |
+
# (RTX 3060 = 8.6, A100 = 8.0, H100 = 9.0). Example for RTX 3060:
|
| 31 |
+
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86']
|
| 32 |
+
|
| 33 |
+
# Build with the local CUDA toolchain (must match your torch.version.cuda):
|
| 34 |
+
CUDA_HOME=/usr/local/cuda-12.1 .venv/bin/pip install -e .
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
Then install the Python wrappers:
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
cd /home/mikeb/work/feather/kernels/cuda/flashfftconv
|
| 41 |
+
.venv/bin/pip install -e .
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## Runtime usage
|
| 45 |
+
|
| 46 |
+
Once installed, set `HYDRA_HYENA_FLASH_FFT=1` and training will use it.
|
| 47 |
+
`subsystems/hyena_pure.fftconv_ref` auto-detects via `try: import flashfftconv`
|
| 48 |
+
and falls back to pure PyTorch on import failure.
|
| 49 |
+
|
| 50 |
+
## Known caveats
|
| 51 |
+
|
| 52 |
+
- The FFT size (flashfftconv's `seqlen`) must be a power of 2 AND in the supported set: {256, 512, 1024, 2048,
|
| 53 |
+
4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304}.
|
| 54 |
+
For HYDRA, `fft_size = 2 * seq_len` → seq_len in {128, 256, 512, 1024, 2048, ...}.
|
| 55 |
+
- dtype must be fp16 or bf16 (fp32 not supported).
|
| 56 |
+
- GPU arch must be compiled into the extension (see setup.py cc_flag).
|
| 57 |
+
- The CUDA toolchain's major version should match the major version of `torch.version.cuda` (12.x ↔ 12.x).
|
overlay/kernels/cuda/flashfftconv/UPSTREAM_COMMIT
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
b8771028717f46d5b22cbb8e12833f35033d621b
|
|
|
|
| 1 |
+
b8771028717f46d5b22cbb8e12833f35033d621b
|
overlay/kernels/cuda/flashfftconv/csrc/.gitignore
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
-
*.npy
|
| 2 |
-
*.json
|
| 3 |
-
*.png
|
| 4 |
-
|
| 5 |
-
*/*.npy
|
| 6 |
-
*/*.json
|
| 7 |
-
*/*.png
|
| 8 |
-
|
| 9 |
-
*.DS_Store
|
| 10 |
*/*.DS_Store
|
|
|
|
| 1 |
+
*.npy
|
| 2 |
+
*.json
|
| 3 |
+
*.png
|
| 4 |
+
|
| 5 |
+
*/*.npy
|
| 6 |
+
*/*.json
|
| 7 |
+
*/*.png
|
| 8 |
+
|
| 9 |
+
*.DS_Store
|
| 10 |
*/*.DS_Store
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly.h
CHANGED
|
@@ -1,374 +1,374 @@
|
|
| 1 |
-
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
-
|
| 3 |
-
#include <torch/extension.h>
|
| 4 |
-
|
| 5 |
-
#include <vector>
|
| 6 |
-
|
| 7 |
-
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
| 8 |
-
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 9 |
-
#define CHECK_IS_HALF_OR_BFLOAT(x) TORCH_CHECK(x.dtype() == torch::kFloat16 || x.dtype() == torch::kBFloat16, #x " must be float16 or bfloat16")
|
| 10 |
-
#define CHECK_INPUT(x) \
|
| 11 |
-
CHECK_CUDA(x); \
|
| 12 |
-
CHECK_CONTIGUOUS(x); \
|
| 13 |
-
CHECK_IS_HALF_OR_BFLOAT(x)
|
| 14 |
-
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
std::vector<torch::Tensor> butterfly_cuda(
|
| 18 |
-
torch::Tensor x,
|
| 19 |
-
torch::Tensor d_f_T,
|
| 20 |
-
torch::Tensor twiddle_factors_real,
|
| 21 |
-
torch::Tensor twiddle_factors_imag,
|
| 22 |
-
std::optional<at::Tensor> x_gate = std::nullopt
|
| 23 |
-
);
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
std::vector<torch::Tensor> butterfly_bf16_cuda(
|
| 27 |
-
torch::Tensor x,
|
| 28 |
-
torch::Tensor d_f_T_real,
|
| 29 |
-
torch::Tensor d_f_T_imag,
|
| 30 |
-
torch::Tensor twiddle_factors_real,
|
| 31 |
-
torch::Tensor twiddle_factors_imag,
|
| 32 |
-
std::optional<at::Tensor> out_gate = std::nullopt
|
| 33 |
-
);
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
std::vector<torch::Tensor> butterfly_padded_cuda(
|
| 37 |
-
torch::Tensor x,
|
| 38 |
-
torch::Tensor d_f_T,
|
| 39 |
-
torch::Tensor twiddle_factors_real,
|
| 40 |
-
torch::Tensor twiddle_factors_imag,
|
| 41 |
-
int M,
|
| 42 |
-
std::optional<at::Tensor> x_gate = std::nullopt
|
| 43 |
-
);
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
std::vector<torch::Tensor> butterfly_padded_bf16_cuda(
|
| 47 |
-
torch::Tensor x,
|
| 48 |
-
torch::Tensor d_f_T_real,
|
| 49 |
-
torch::Tensor d_f_T_imag,
|
| 50 |
-
torch::Tensor twiddle_factors_real,
|
| 51 |
-
torch::Tensor twiddle_factors_imag,
|
| 52 |
-
int M,
|
| 53 |
-
std::optional<at::Tensor> x_gate = std::nullopt
|
| 54 |
-
);
|
| 55 |
-
|
| 56 |
-
torch::Tensor butterfly_ifft_cuda(
|
| 57 |
-
torch::Tensor x_real,
|
| 58 |
-
torch::Tensor x_imag,
|
| 59 |
-
torch::Tensor d_f_T,
|
| 60 |
-
torch::Tensor twiddle_factors_real,
|
| 61 |
-
torch::Tensor twiddle_factors_imag,
|
| 62 |
-
std::optional<at::Tensor> out_gate = std::nullopt
|
| 63 |
-
);
|
| 64 |
-
|
| 65 |
-
torch::Tensor butterfly_ifft_bf16_cuda(
|
| 66 |
-
torch::Tensor x_real,
|
| 67 |
-
torch::Tensor x_imag,
|
| 68 |
-
torch::Tensor d_f_real,
|
| 69 |
-
torch::Tensor d_f_imag,
|
| 70 |
-
torch::Tensor twiddle_factors_real,
|
| 71 |
-
torch::Tensor twiddle_factors_imag,
|
| 72 |
-
std::optional<at::Tensor> x_gate = std::nullopt
|
| 73 |
-
);
|
| 74 |
-
|
| 75 |
-
torch::Tensor butterfly_ifft_padded_cuda(
|
| 76 |
-
torch::Tensor x_real,
|
| 77 |
-
torch::Tensor x_imag,
|
| 78 |
-
torch::Tensor d_f,
|
| 79 |
-
torch::Tensor twiddle_factors_real,
|
| 80 |
-
torch::Tensor twiddle_factors_imag,
|
| 81 |
-
int N,
|
| 82 |
-
std::optional<at::Tensor> out_gate = std::nullopt
|
| 83 |
-
);
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
torch::Tensor butterfly_ifft_padded_bf16_cuda(
|
| 87 |
-
torch::Tensor x_real,
|
| 88 |
-
torch::Tensor x_imag,
|
| 89 |
-
torch::Tensor d_f_real,
|
| 90 |
-
torch::Tensor d_f_imag,
|
| 91 |
-
torch::Tensor twiddle_factors_real,
|
| 92 |
-
torch::Tensor twiddle_factors_imag,
|
| 93 |
-
int N,
|
| 94 |
-
std::optional<at::Tensor> out_gate = std::nullopt
|
| 95 |
-
);
|
| 96 |
-
|
| 97 |
-
std::vector<torch::Tensor> butterfly(
|
| 98 |
-
torch::Tensor x,
|
| 99 |
-
torch::Tensor d_f_T,
|
| 100 |
-
torch::Tensor twiddle_factors_real,
|
| 101 |
-
torch::Tensor twiddle_factors_imag
|
| 102 |
-
){
|
| 103 |
-
CHECK_INPUT(x);
|
| 104 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 105 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag);
|
| 109 |
-
}
|
| 110 |
-
|
| 111 |
-
std::vector<torch::Tensor> butterfly_gated(
|
| 112 |
-
torch::Tensor x,
|
| 113 |
-
torch::Tensor d_f_T,
|
| 114 |
-
torch::Tensor twiddle_factors_real,
|
| 115 |
-
torch::Tensor twiddle_factors_imag,
|
| 116 |
-
torch::Tensor x_gate
|
| 117 |
-
){
|
| 118 |
-
CHECK_INPUT(x);
|
| 119 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 120 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 121 |
-
|
| 122 |
-
CHECK_INPUT(x_gate);
|
| 123 |
-
|
| 124 |
-
return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, x_gate);
|
| 125 |
-
}
|
| 126 |
-
|
| 127 |
-
std::vector<torch::Tensor> butterfly_bf16(
|
| 128 |
-
torch::Tensor x,
|
| 129 |
-
torch::Tensor d_f_T_real,
|
| 130 |
-
torch::Tensor d_f_T_imag,
|
| 131 |
-
torch::Tensor twiddle_factors_real,
|
| 132 |
-
torch::Tensor twiddle_factors_imag
|
| 133 |
-
){
|
| 134 |
-
CHECK_INPUT(x);
|
| 135 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 136 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 137 |
-
CHECK_INPUT(d_f_T_real);
|
| 138 |
-
CHECK_INPUT(d_f_T_imag);
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag);
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
std::vector<torch::Tensor> butterfly_gated_bf16(
|
| 145 |
-
torch::Tensor x,
|
| 146 |
-
torch::Tensor d_f_T_real,
|
| 147 |
-
torch::Tensor d_f_T_imag,
|
| 148 |
-
torch::Tensor twiddle_factors_real,
|
| 149 |
-
torch::Tensor twiddle_factors_imag,
|
| 150 |
-
torch::Tensor x_gate
|
| 151 |
-
){
|
| 152 |
-
CHECK_INPUT(x);
|
| 153 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 154 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 155 |
-
CHECK_INPUT(d_f_T_real);
|
| 156 |
-
CHECK_INPUT(d_f_T_imag);
|
| 157 |
-
CHECK_INPUT(x_gate);
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, x_gate);
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
torch::Tensor butterfly_ifft(
|
| 164 |
-
torch::Tensor x_real,
|
| 165 |
-
torch::Tensor x_imag,
|
| 166 |
-
torch::Tensor d_f_T,
|
| 167 |
-
torch::Tensor twiddle_factors_real,
|
| 168 |
-
torch::Tensor twiddle_factors_imag
|
| 169 |
-
){
|
| 170 |
-
CHECK_INPUT(x_real);
|
| 171 |
-
CHECK_INPUT(x_imag);
|
| 172 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 173 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 174 |
-
|
| 175 |
-
return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag);
|
| 176 |
-
}
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
torch::Tensor butterfly_ifft_gated(
|
| 180 |
-
torch::Tensor x_real,
|
| 181 |
-
torch::Tensor x_imag,
|
| 182 |
-
torch::Tensor d_f_T,
|
| 183 |
-
torch::Tensor twiddle_factors_real,
|
| 184 |
-
torch::Tensor twiddle_factors_imag,
|
| 185 |
-
torch::Tensor out_gate
|
| 186 |
-
){
|
| 187 |
-
CHECK_INPUT(x_real);
|
| 188 |
-
CHECK_INPUT(x_imag);
|
| 189 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 190 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 191 |
-
CHECK_INPUT(out_gate);
|
| 192 |
-
|
| 193 |
-
return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag, out_gate);
|
| 194 |
-
}
|
| 195 |
-
|
| 196 |
-
torch::Tensor butterfly_ifft_bf16(
|
| 197 |
-
torch::Tensor x_real,
|
| 198 |
-
torch::Tensor x_imag,
|
| 199 |
-
torch::Tensor d_f_real,
|
| 200 |
-
torch::Tensor d_f_imag,
|
| 201 |
-
torch::Tensor twiddle_factors_real,
|
| 202 |
-
torch::Tensor twiddle_factors_imag
|
| 203 |
-
){
|
| 204 |
-
CHECK_INPUT(x_real);
|
| 205 |
-
CHECK_INPUT(x_imag);
|
| 206 |
-
CHECK_INPUT(d_f_real);
|
| 207 |
-
CHECK_INPUT(d_f_imag);
|
| 208 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 209 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag);
|
| 213 |
-
}
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
torch::Tensor butterfly_ifft_gated_bf16(
|
| 217 |
-
torch::Tensor x_real,
|
| 218 |
-
torch::Tensor x_imag,
|
| 219 |
-
torch::Tensor d_f_real,
|
| 220 |
-
torch::Tensor d_f_imag,
|
| 221 |
-
torch::Tensor twiddle_factors_real,
|
| 222 |
-
torch::Tensor twiddle_factors_imag,
|
| 223 |
-
torch::Tensor out_gate
|
| 224 |
-
){
|
| 225 |
-
CHECK_INPUT(x_real);
|
| 226 |
-
CHECK_INPUT(x_imag);
|
| 227 |
-
CHECK_INPUT(d_f_real);
|
| 228 |
-
CHECK_INPUT(d_f_imag);
|
| 229 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 230 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 231 |
-
CHECK_INPUT(out_gate);
|
| 232 |
-
|
| 233 |
-
return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, out_gate);
|
| 234 |
-
}
|
| 235 |
-
|
| 236 |
-
std::vector<torch::Tensor> butterfly_padded(
|
| 237 |
-
torch::Tensor x,
|
| 238 |
-
torch::Tensor d_f_T,
|
| 239 |
-
torch::Tensor twiddle_factors_real,
|
| 240 |
-
torch::Tensor twiddle_factors_imag,
|
| 241 |
-
int M
|
| 242 |
-
){
|
| 243 |
-
CHECK_INPUT(x);
|
| 244 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 245 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M);
|
| 249 |
-
}
|
| 250 |
-
|
| 251 |
-
std::vector<torch::Tensor> butterfly_padded_bf16(
|
| 252 |
-
torch::Tensor x,
|
| 253 |
-
torch::Tensor d_f_T_real,
|
| 254 |
-
torch::Tensor d_f_T_imag,
|
| 255 |
-
torch::Tensor twiddle_factors_real,
|
| 256 |
-
torch::Tensor twiddle_factors_imag,
|
| 257 |
-
int M
|
| 258 |
-
){
|
| 259 |
-
CHECK_INPUT(x);
|
| 260 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 261 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M);
|
| 265 |
-
}
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
std::vector<torch::Tensor> butterfly_padded_gated(
|
| 269 |
-
torch::Tensor x,
|
| 270 |
-
torch::Tensor d_f_T,
|
| 271 |
-
torch::Tensor twiddle_factors_real,
|
| 272 |
-
torch::Tensor twiddle_factors_imag,
|
| 273 |
-
int M,
|
| 274 |
-
torch::Tensor x_gate
|
| 275 |
-
){
|
| 276 |
-
CHECK_INPUT(x);
|
| 277 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 278 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
|
| 282 |
-
}
|
| 283 |
-
|
| 284 |
-
std::vector<torch::Tensor> butterfly_padded_gated_bf16(
|
| 285 |
-
torch::Tensor x,
|
| 286 |
-
torch::Tensor d_f_T_real,
|
| 287 |
-
torch::Tensor d_f_T_imag,
|
| 288 |
-
torch::Tensor twiddle_factors_real,
|
| 289 |
-
torch::Tensor twiddle_factors_imag,
|
| 290 |
-
int M,
|
| 291 |
-
torch::Tensor x_gate
|
| 292 |
-
){
|
| 293 |
-
CHECK_INPUT(x);
|
| 294 |
-
CHECK_INPUT(d_f_T_real);
|
| 295 |
-
CHECK_INPUT(d_f_T_imag);
|
| 296 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 297 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
|
| 301 |
-
}
|
| 302 |
-
|
| 303 |
-
torch::Tensor butterfly_ifft_padded(
|
| 304 |
-
torch::Tensor x_real,
|
| 305 |
-
torch::Tensor x_imag,
|
| 306 |
-
torch::Tensor d_f,
|
| 307 |
-
torch::Tensor twiddle_factors_real,
|
| 308 |
-
torch::Tensor twiddle_factors_imag,
|
| 309 |
-
int N
|
| 310 |
-
){
|
| 311 |
-
CHECK_INPUT(x_real);
|
| 312 |
-
CHECK_INPUT(x_imag);
|
| 313 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 314 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 315 |
-
|
| 316 |
-
return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N);
|
| 317 |
-
}
|
| 318 |
-
|
| 319 |
-
torch::Tensor butterfly_ifft_padded_gated(
|
| 320 |
-
torch::Tensor x_real,
|
| 321 |
-
torch::Tensor x_imag,
|
| 322 |
-
torch::Tensor d_f,
|
| 323 |
-
torch::Tensor twiddle_factors_real,
|
| 324 |
-
torch::Tensor twiddle_factors_imag,
|
| 325 |
-
int N,
|
| 326 |
-
torch::Tensor out_gate
|
| 327 |
-
){
|
| 328 |
-
CHECK_INPUT(x_real);
|
| 329 |
-
CHECK_INPUT(x_imag);
|
| 330 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 331 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 332 |
-
|
| 333 |
-
return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
|
| 334 |
-
}
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
torch::Tensor butterfly_ifft_padded_bf16(
|
| 338 |
-
torch::Tensor x_real,
|
| 339 |
-
torch::Tensor x_imag,
|
| 340 |
-
torch::Tensor d_f_real,
|
| 341 |
-
torch::Tensor d_f_imag,
|
| 342 |
-
torch::Tensor twiddle_factors_real,
|
| 343 |
-
torch::Tensor twiddle_factors_imag,
|
| 344 |
-
int N
|
| 345 |
-
){
|
| 346 |
-
CHECK_INPUT(x_real);
|
| 347 |
-
CHECK_INPUT(x_imag);
|
| 348 |
-
CHECK_INPUT(d_f_real);
|
| 349 |
-
CHECK_INPUT(d_f_imag);
|
| 350 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 351 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 352 |
-
|
| 353 |
-
return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N);
|
| 354 |
-
}
|
| 355 |
-
|
| 356 |
-
torch::Tensor butterfly_ifft_padded_gated_bf16(
|
| 357 |
-
torch::Tensor x_real,
|
| 358 |
-
torch::Tensor x_imag,
|
| 359 |
-
torch::Tensor d_f_real,
|
| 360 |
-
torch::Tensor d_f_imag,
|
| 361 |
-
torch::Tensor twiddle_factors_real,
|
| 362 |
-
torch::Tensor twiddle_factors_imag,
|
| 363 |
-
int N,
|
| 364 |
-
torch::Tensor out_gate
|
| 365 |
-
){
|
| 366 |
-
CHECK_INPUT(x_real);
|
| 367 |
-
CHECK_INPUT(x_imag);
|
| 368 |
-
CHECK_INPUT(d_f_real);
|
| 369 |
-
CHECK_INPUT(d_f_imag);
|
| 370 |
-
CHECK_INPUT(twiddle_factors_real);
|
| 371 |
-
CHECK_INPUT(twiddle_factors_imag);
|
| 372 |
-
|
| 373 |
-
return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
|
| 374 |
}
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
|
| 7 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
| 8 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 9 |
+
#define CHECK_IS_HALF_OR_BFLOAT(x) TORCH_CHECK(x.dtype() == torch::kFloat16 || x.dtype() == torch::kBFloat16, #x " must be float16 or bfloat16")
|
| 10 |
+
#define CHECK_INPUT(x) \
|
| 11 |
+
CHECK_CUDA(x); \
|
| 12 |
+
CHECK_CONTIGUOUS(x); \
|
| 13 |
+
CHECK_IS_HALF_OR_BFLOAT(x)
|
| 14 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
std::vector<torch::Tensor> butterfly_cuda(
|
| 18 |
+
torch::Tensor x,
|
| 19 |
+
torch::Tensor d_f_T,
|
| 20 |
+
torch::Tensor twiddle_factors_real,
|
| 21 |
+
torch::Tensor twiddle_factors_imag,
|
| 22 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 23 |
+
);
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
std::vector<torch::Tensor> butterfly_bf16_cuda(
|
| 27 |
+
torch::Tensor x,
|
| 28 |
+
torch::Tensor d_f_T_real,
|
| 29 |
+
torch::Tensor d_f_T_imag,
|
| 30 |
+
torch::Tensor twiddle_factors_real,
|
| 31 |
+
torch::Tensor twiddle_factors_imag,
|
| 32 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 33 |
+
);
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
std::vector<torch::Tensor> butterfly_padded_cuda(
|
| 37 |
+
torch::Tensor x,
|
| 38 |
+
torch::Tensor d_f_T,
|
| 39 |
+
torch::Tensor twiddle_factors_real,
|
| 40 |
+
torch::Tensor twiddle_factors_imag,
|
| 41 |
+
int M,
|
| 42 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 43 |
+
);
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
std::vector<torch::Tensor> butterfly_padded_bf16_cuda(
|
| 47 |
+
torch::Tensor x,
|
| 48 |
+
torch::Tensor d_f_T_real,
|
| 49 |
+
torch::Tensor d_f_T_imag,
|
| 50 |
+
torch::Tensor twiddle_factors_real,
|
| 51 |
+
torch::Tensor twiddle_factors_imag,
|
| 52 |
+
int M,
|
| 53 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 54 |
+
);
|
| 55 |
+
|
| 56 |
+
torch::Tensor butterfly_ifft_cuda(
|
| 57 |
+
torch::Tensor x_real,
|
| 58 |
+
torch::Tensor x_imag,
|
| 59 |
+
torch::Tensor d_f_T,
|
| 60 |
+
torch::Tensor twiddle_factors_real,
|
| 61 |
+
torch::Tensor twiddle_factors_imag,
|
| 62 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 63 |
+
);
|
| 64 |
+
|
| 65 |
+
torch::Tensor butterfly_ifft_bf16_cuda(
|
| 66 |
+
torch::Tensor x_real,
|
| 67 |
+
torch::Tensor x_imag,
|
| 68 |
+
torch::Tensor d_f_real,
|
| 69 |
+
torch::Tensor d_f_imag,
|
| 70 |
+
torch::Tensor twiddle_factors_real,
|
| 71 |
+
torch::Tensor twiddle_factors_imag,
|
| 72 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 73 |
+
);
|
| 74 |
+
|
| 75 |
+
torch::Tensor butterfly_ifft_padded_cuda(
|
| 76 |
+
torch::Tensor x_real,
|
| 77 |
+
torch::Tensor x_imag,
|
| 78 |
+
torch::Tensor d_f,
|
| 79 |
+
torch::Tensor twiddle_factors_real,
|
| 80 |
+
torch::Tensor twiddle_factors_imag,
|
| 81 |
+
int N,
|
| 82 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 83 |
+
);
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
torch::Tensor butterfly_ifft_padded_bf16_cuda(
|
| 87 |
+
torch::Tensor x_real,
|
| 88 |
+
torch::Tensor x_imag,
|
| 89 |
+
torch::Tensor d_f_real,
|
| 90 |
+
torch::Tensor d_f_imag,
|
| 91 |
+
torch::Tensor twiddle_factors_real,
|
| 92 |
+
torch::Tensor twiddle_factors_imag,
|
| 93 |
+
int N,
|
| 94 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 95 |
+
);
|
| 96 |
+
|
| 97 |
+
std::vector<torch::Tensor> butterfly(
|
| 98 |
+
torch::Tensor x,
|
| 99 |
+
torch::Tensor d_f_T,
|
| 100 |
+
torch::Tensor twiddle_factors_real,
|
| 101 |
+
torch::Tensor twiddle_factors_imag
|
| 102 |
+
){
|
| 103 |
+
CHECK_INPUT(x);
|
| 104 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 105 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag);
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
std::vector<torch::Tensor> butterfly_gated(
|
| 112 |
+
torch::Tensor x,
|
| 113 |
+
torch::Tensor d_f_T,
|
| 114 |
+
torch::Tensor twiddle_factors_real,
|
| 115 |
+
torch::Tensor twiddle_factors_imag,
|
| 116 |
+
torch::Tensor x_gate
|
| 117 |
+
){
|
| 118 |
+
CHECK_INPUT(x);
|
| 119 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 120 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 121 |
+
|
| 122 |
+
CHECK_INPUT(x_gate);
|
| 123 |
+
|
| 124 |
+
return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, x_gate);
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
std::vector<torch::Tensor> butterfly_bf16(
|
| 128 |
+
torch::Tensor x,
|
| 129 |
+
torch::Tensor d_f_T_real,
|
| 130 |
+
torch::Tensor d_f_T_imag,
|
| 131 |
+
torch::Tensor twiddle_factors_real,
|
| 132 |
+
torch::Tensor twiddle_factors_imag
|
| 133 |
+
){
|
| 134 |
+
CHECK_INPUT(x);
|
| 135 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 136 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 137 |
+
CHECK_INPUT(d_f_T_real);
|
| 138 |
+
CHECK_INPUT(d_f_T_imag);
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
std::vector<torch::Tensor> butterfly_gated_bf16(
|
| 145 |
+
torch::Tensor x,
|
| 146 |
+
torch::Tensor d_f_T_real,
|
| 147 |
+
torch::Tensor d_f_T_imag,
|
| 148 |
+
torch::Tensor twiddle_factors_real,
|
| 149 |
+
torch::Tensor twiddle_factors_imag,
|
| 150 |
+
torch::Tensor x_gate
|
| 151 |
+
){
|
| 152 |
+
CHECK_INPUT(x);
|
| 153 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 154 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 155 |
+
CHECK_INPUT(d_f_T_real);
|
| 156 |
+
CHECK_INPUT(d_f_T_imag);
|
| 157 |
+
CHECK_INPUT(x_gate);
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, x_gate);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
torch::Tensor butterfly_ifft(
|
| 164 |
+
torch::Tensor x_real,
|
| 165 |
+
torch::Tensor x_imag,
|
| 166 |
+
torch::Tensor d_f_T,
|
| 167 |
+
torch::Tensor twiddle_factors_real,
|
| 168 |
+
torch::Tensor twiddle_factors_imag
|
| 169 |
+
){
|
| 170 |
+
CHECK_INPUT(x_real);
|
| 171 |
+
CHECK_INPUT(x_imag);
|
| 172 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 173 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 174 |
+
|
| 175 |
+
return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag);
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
torch::Tensor butterfly_ifft_gated(
|
| 180 |
+
torch::Tensor x_real,
|
| 181 |
+
torch::Tensor x_imag,
|
| 182 |
+
torch::Tensor d_f_T,
|
| 183 |
+
torch::Tensor twiddle_factors_real,
|
| 184 |
+
torch::Tensor twiddle_factors_imag,
|
| 185 |
+
torch::Tensor out_gate
|
| 186 |
+
){
|
| 187 |
+
CHECK_INPUT(x_real);
|
| 188 |
+
CHECK_INPUT(x_imag);
|
| 189 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 190 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 191 |
+
CHECK_INPUT(out_gate);
|
| 192 |
+
|
| 193 |
+
return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag, out_gate);
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
torch::Tensor butterfly_ifft_bf16(
|
| 197 |
+
torch::Tensor x_real,
|
| 198 |
+
torch::Tensor x_imag,
|
| 199 |
+
torch::Tensor d_f_real,
|
| 200 |
+
torch::Tensor d_f_imag,
|
| 201 |
+
torch::Tensor twiddle_factors_real,
|
| 202 |
+
torch::Tensor twiddle_factors_imag
|
| 203 |
+
){
|
| 204 |
+
CHECK_INPUT(x_real);
|
| 205 |
+
CHECK_INPUT(x_imag);
|
| 206 |
+
CHECK_INPUT(d_f_real);
|
| 207 |
+
CHECK_INPUT(d_f_imag);
|
| 208 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 209 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag);
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
torch::Tensor butterfly_ifft_gated_bf16(
|
| 217 |
+
torch::Tensor x_real,
|
| 218 |
+
torch::Tensor x_imag,
|
| 219 |
+
torch::Tensor d_f_real,
|
| 220 |
+
torch::Tensor d_f_imag,
|
| 221 |
+
torch::Tensor twiddle_factors_real,
|
| 222 |
+
torch::Tensor twiddle_factors_imag,
|
| 223 |
+
torch::Tensor out_gate
|
| 224 |
+
){
|
| 225 |
+
CHECK_INPUT(x_real);
|
| 226 |
+
CHECK_INPUT(x_imag);
|
| 227 |
+
CHECK_INPUT(d_f_real);
|
| 228 |
+
CHECK_INPUT(d_f_imag);
|
| 229 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 230 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 231 |
+
CHECK_INPUT(out_gate);
|
| 232 |
+
|
| 233 |
+
return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, out_gate);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
std::vector<torch::Tensor> butterfly_padded(
|
| 237 |
+
torch::Tensor x,
|
| 238 |
+
torch::Tensor d_f_T,
|
| 239 |
+
torch::Tensor twiddle_factors_real,
|
| 240 |
+
torch::Tensor twiddle_factors_imag,
|
| 241 |
+
int M
|
| 242 |
+
){
|
| 243 |
+
CHECK_INPUT(x);
|
| 244 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 245 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
std::vector<torch::Tensor> butterfly_padded_bf16(
|
| 252 |
+
torch::Tensor x,
|
| 253 |
+
torch::Tensor d_f_T_real,
|
| 254 |
+
torch::Tensor d_f_T_imag,
|
| 255 |
+
torch::Tensor twiddle_factors_real,
|
| 256 |
+
torch::Tensor twiddle_factors_imag,
|
| 257 |
+
int M
|
| 258 |
+
){
|
| 259 |
+
CHECK_INPUT(x);
|
| 260 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 261 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M);
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
std::vector<torch::Tensor> butterfly_padded_gated(
|
| 269 |
+
torch::Tensor x,
|
| 270 |
+
torch::Tensor d_f_T,
|
| 271 |
+
torch::Tensor twiddle_factors_real,
|
| 272 |
+
torch::Tensor twiddle_factors_imag,
|
| 273 |
+
int M,
|
| 274 |
+
torch::Tensor x_gate
|
| 275 |
+
){
|
| 276 |
+
CHECK_INPUT(x);
|
| 277 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 278 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
std::vector<torch::Tensor> butterfly_padded_gated_bf16(
|
| 285 |
+
torch::Tensor x,
|
| 286 |
+
torch::Tensor d_f_T_real,
|
| 287 |
+
torch::Tensor d_f_T_imag,
|
| 288 |
+
torch::Tensor twiddle_factors_real,
|
| 289 |
+
torch::Tensor twiddle_factors_imag,
|
| 290 |
+
int M,
|
| 291 |
+
torch::Tensor x_gate
|
| 292 |
+
){
|
| 293 |
+
CHECK_INPUT(x);
|
| 294 |
+
CHECK_INPUT(d_f_T_real);
|
| 295 |
+
CHECK_INPUT(d_f_T_imag);
|
| 296 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 297 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
torch::Tensor butterfly_ifft_padded(
|
| 304 |
+
torch::Tensor x_real,
|
| 305 |
+
torch::Tensor x_imag,
|
| 306 |
+
torch::Tensor d_f,
|
| 307 |
+
torch::Tensor twiddle_factors_real,
|
| 308 |
+
torch::Tensor twiddle_factors_imag,
|
| 309 |
+
int N
|
| 310 |
+
){
|
| 311 |
+
CHECK_INPUT(x_real);
|
| 312 |
+
CHECK_INPUT(x_imag);
|
| 313 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 314 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 315 |
+
|
| 316 |
+
return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N);
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
torch::Tensor butterfly_ifft_padded_gated(
|
| 320 |
+
torch::Tensor x_real,
|
| 321 |
+
torch::Tensor x_imag,
|
| 322 |
+
torch::Tensor d_f,
|
| 323 |
+
torch::Tensor twiddle_factors_real,
|
| 324 |
+
torch::Tensor twiddle_factors_imag,
|
| 325 |
+
int N,
|
| 326 |
+
torch::Tensor out_gate
|
| 327 |
+
){
|
| 328 |
+
CHECK_INPUT(x_real);
|
| 329 |
+
CHECK_INPUT(x_imag);
|
| 330 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 331 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 332 |
+
|
| 333 |
+
return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
torch::Tensor butterfly_ifft_padded_bf16(
|
| 338 |
+
torch::Tensor x_real,
|
| 339 |
+
torch::Tensor x_imag,
|
| 340 |
+
torch::Tensor d_f_real,
|
| 341 |
+
torch::Tensor d_f_imag,
|
| 342 |
+
torch::Tensor twiddle_factors_real,
|
| 343 |
+
torch::Tensor twiddle_factors_imag,
|
| 344 |
+
int N
|
| 345 |
+
){
|
| 346 |
+
CHECK_INPUT(x_real);
|
| 347 |
+
CHECK_INPUT(x_imag);
|
| 348 |
+
CHECK_INPUT(d_f_real);
|
| 349 |
+
CHECK_INPUT(d_f_imag);
|
| 350 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 351 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 352 |
+
|
| 353 |
+
return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N);
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
torch::Tensor butterfly_ifft_padded_gated_bf16(
|
| 357 |
+
torch::Tensor x_real,
|
| 358 |
+
torch::Tensor x_imag,
|
| 359 |
+
torch::Tensor d_f_real,
|
| 360 |
+
torch::Tensor d_f_imag,
|
| 361 |
+
torch::Tensor twiddle_factors_real,
|
| 362 |
+
torch::Tensor twiddle_factors_imag,
|
| 363 |
+
int N,
|
| 364 |
+
torch::Tensor out_gate
|
| 365 |
+
){
|
| 366 |
+
CHECK_INPUT(x_real);
|
| 367 |
+
CHECK_INPUT(x_imag);
|
| 368 |
+
CHECK_INPUT(d_f_real);
|
| 369 |
+
CHECK_INPUT(d_f_imag);
|
| 370 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 371 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 372 |
+
|
| 373 |
+
return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
|
| 374 |
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda.cu
CHANGED
|
@@ -1,699 +1,699 @@
|
|
| 1 |
-
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
-
|
| 3 |
-
#include <torch/extension.h>
|
| 4 |
-
|
| 5 |
-
#include <vector>
|
| 6 |
-
#include <stdio.h>
|
| 7 |
-
#include <mma.h>
|
| 8 |
-
#include <cuda_fp16.h>
|
| 9 |
-
#include <cuda_bf16.h>
|
| 10 |
-
#include "shared.h"
|
| 11 |
-
|
| 12 |
-
using namespace nvcuda;
|
| 13 |
-
|
| 14 |
-
__global__ void butterfly_cuda_kernel_64(
|
| 15 |
-
const __half2 *__restrict__ x,
|
| 16 |
-
const __half2 *__restrict__ x_gate,
|
| 17 |
-
const complex_half_t *__restrict__ d_f,
|
| 18 |
-
const __half2 *__restrict__ twiddle_factors_real,
|
| 19 |
-
const __half2 *__restrict__ twiddle_factors_imag,
|
| 20 |
-
__half2 *__restrict__ out_real,
|
| 21 |
-
__half2 *__restrict__ out_imag,
|
| 22 |
-
uint B,
|
| 23 |
-
uint H,
|
| 24 |
-
int N)
|
| 25 |
-
{
|
| 26 |
-
const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 27 |
-
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 28 |
-
int idx;
|
| 29 |
-
int shared_offset;
|
| 30 |
-
const int B_Y = blockDim.y;
|
| 31 |
-
const int n = N / B_Y;
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
extern __shared__ half x_shared[];
|
| 35 |
-
half *d_f_real = &x_shared[N * N];
|
| 36 |
-
half *d_f_imag = &d_f_real[N * N];
|
| 37 |
-
half *twiddles_real_shared = &d_f_imag[N * N];
|
| 38 |
-
half *twiddles_imag_shared = &twiddles_real_shared[N * N];
|
| 39 |
-
half *out_real_shared = &twiddles_imag_shared[N * N];
|
| 40 |
-
half *out_imag_shared = &out_real_shared[N * N];
|
| 41 |
-
|
| 42 |
-
// #pragma unroll
|
| 43 |
-
for (int i = 0; i < n; i++)
|
| 44 |
-
{
|
| 45 |
-
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 46 |
-
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 47 |
-
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 48 |
-
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 49 |
-
|
| 50 |
-
// #pragma unroll
|
| 51 |
-
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x;
|
| 52 |
-
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 53 |
-
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 54 |
-
|
| 55 |
-
d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
|
| 56 |
-
d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
|
| 57 |
-
}
|
| 58 |
-
|
| 59 |
-
__half2 tmp_real, tmp_imag;
|
| 60 |
-
|
| 61 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4];
|
| 62 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
|
| 63 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
|
| 64 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4];
|
| 65 |
-
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[4][4];
|
| 66 |
-
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
|
| 67 |
-
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[4];
|
| 68 |
-
|
| 69 |
-
__syncthreads();
|
| 70 |
-
|
| 71 |
-
for (int i = 0; i < 4; i++)
|
| 72 |
-
{
|
| 73 |
-
wmma::load_matrix_sync(a_frag_real[i], d_f_real + i * N * 16 + threadIdx.y * 16, N);
|
| 74 |
-
wmma::load_matrix_sync(a_frag_imag[i], d_f_imag + i * N * 16 + threadIdx.y * 16, N);
|
| 75 |
-
wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 76 |
-
wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 77 |
-
}
|
| 78 |
-
|
| 79 |
-
for (int t = 0; t < 16; t++)
|
| 80 |
-
{
|
| 81 |
-
|
| 82 |
-
for (int i = 0; i < n; i++)
|
| 83 |
-
{
|
| 84 |
-
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 85 |
-
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 86 |
-
if(x_gate != nullptr){
|
| 87 |
-
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 88 |
-
}else{
|
| 89 |
-
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 90 |
-
}
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
__syncthreads();
|
| 94 |
-
|
| 95 |
-
for (int i = 0; i < 4; i++)
|
| 96 |
-
{
|
| 97 |
-
for (int j = 0; j < 4; j++)
|
| 98 |
-
{
|
| 99 |
-
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
|
| 100 |
-
}
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
#pragma unroll
|
| 104 |
-
for (int j = 0; j < 4; j++)
|
| 105 |
-
{
|
| 106 |
-
wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
|
| 107 |
-
|
| 108 |
-
for (int k = 0; k < 4; k++)
|
| 109 |
-
{
|
| 110 |
-
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 111 |
-
}
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
#pragma unroll
|
| 115 |
-
|
| 116 |
-
for (int j = 0; j < 4; j++)
|
| 117 |
-
{
|
| 118 |
-
wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
|
| 119 |
-
|
| 120 |
-
for (int k = 0; k < 4; k++)
|
| 121 |
-
{
|
| 122 |
-
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 123 |
-
}
|
| 124 |
-
}
|
| 125 |
-
|
| 126 |
-
#pragma unroll
|
| 127 |
-
for (int j = 0; j < 4; j++)
|
| 128 |
-
{
|
| 129 |
-
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 130 |
-
{
|
| 131 |
-
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
|
| 132 |
-
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
|
| 133 |
-
reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
|
| 134 |
-
reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
|
| 135 |
-
}
|
| 136 |
-
|
| 137 |
-
wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
|
| 138 |
-
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
|
| 139 |
-
}
|
| 140 |
-
|
| 141 |
-
__syncthreads();
|
| 142 |
-
|
| 143 |
-
#pragma unroll
|
| 144 |
-
for (int i = 0; i < n; i++)
|
| 145 |
-
{
|
| 146 |
-
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 147 |
-
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 148 |
-
out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 149 |
-
}
|
| 150 |
-
|
| 151 |
-
__syncthreads();
|
| 152 |
-
}
|
| 153 |
-
}
|
| 154 |
-
|
| 155 |
-
__global__ void butterfly_cuda_kernel_32(
|
| 156 |
-
const __half2 *__restrict__ x,
|
| 157 |
-
const __half2 *__restrict__ x_gate,
|
| 158 |
-
const complex_half_t *__restrict__ d_f,
|
| 159 |
-
const __half2 *__restrict__ twiddle_factors_real,
|
| 160 |
-
const __half2 *__restrict__ twiddle_factors_imag,
|
| 161 |
-
__half2 *__restrict__ out_real,
|
| 162 |
-
__half2 *__restrict__ out_imag,
|
| 163 |
-
uint B,
|
| 164 |
-
uint H,
|
| 165 |
-
int N)
|
| 166 |
-
{
|
| 167 |
-
const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 168 |
-
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 169 |
-
int idx;
|
| 170 |
-
|
| 171 |
-
int shared_offset;
|
| 172 |
-
const int B_Y = blockDim.y;
|
| 173 |
-
const int n = N / B_Y;
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
__shared__ half x_shared[32 * 64];
|
| 177 |
-
__shared__ half d_f_real[32 * 32];
|
| 178 |
-
__shared__ half d_f_imag[32 * 32];
|
| 179 |
-
__shared__ half twiddles_real_shared[32 * 64];
|
| 180 |
-
__shared__ half twiddles_imag_shared[32 * 64];
|
| 181 |
-
__shared__ half out_real_shared[32 * 64];
|
| 182 |
-
__shared__ half out_imag_shared[32 * 64];
|
| 183 |
-
|
| 184 |
-
// #pragma unroll
|
| 185 |
-
for (int i = 0; i < n; i++)
|
| 186 |
-
{
|
| 187 |
-
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 188 |
-
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 189 |
-
if(x_gate == nullptr){
|
| 190 |
-
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 191 |
-
}else{
|
| 192 |
-
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 193 |
-
}
|
| 194 |
-
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 195 |
-
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 196 |
-
|
| 197 |
-
// #pragma unroll
|
| 198 |
-
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 199 |
-
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 200 |
-
}
|
| 201 |
-
|
| 202 |
-
__syncthreads();
|
| 203 |
-
|
| 204 |
-
if (threadIdx.y < N / 16)
|
| 205 |
-
{
|
| 206 |
-
__half2 tmp_real, tmp_imag;
|
| 207 |
-
|
| 208 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
|
| 209 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
|
| 210 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
|
| 211 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
|
| 212 |
-
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[2][2];
|
| 213 |
-
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
|
| 214 |
-
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[2][2];
|
| 215 |
-
|
| 216 |
-
int t = threadIdx.y * 32;
|
| 217 |
-
|
| 218 |
-
for (int i = 0; i < 2; i++)
|
| 219 |
-
{
|
| 220 |
-
for (int j = 0; j < 2; j++)
|
| 221 |
-
{
|
| 222 |
-
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
|
| 223 |
-
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
|
| 224 |
-
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 225 |
-
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 226 |
-
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 227 |
-
}
|
| 228 |
-
}
|
| 229 |
-
|
| 230 |
-
#pragma unroll
|
| 231 |
-
for (int i = 0; i < 2; i++)
|
| 232 |
-
{
|
| 233 |
-
for (int j = 0; j < 2; j++)
|
| 234 |
-
{
|
| 235 |
-
wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
|
| 236 |
-
|
| 237 |
-
for (int k = 0; k < 2; k++)
|
| 238 |
-
{
|
| 239 |
-
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
|
| 240 |
-
}
|
| 241 |
-
}
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
#pragma unroll
|
| 245 |
-
for (int i = 0; i < 2; i++)
|
| 246 |
-
{
|
| 247 |
-
for (int j = 0; j < 2; j++)
|
| 248 |
-
{
|
| 249 |
-
wmma::fill_fragment(acc_frag_imag[i][j], __float2half(0.0f));
|
| 250 |
-
|
| 251 |
-
for (int k = 0; k < 2; k++)
|
| 252 |
-
{
|
| 253 |
-
wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
|
| 254 |
-
}
|
| 255 |
-
}
|
| 256 |
-
}
|
| 257 |
-
|
| 258 |
-
#pragma unroll
|
| 259 |
-
for (int i = 0; i < 2; i++)
|
| 260 |
-
{
|
| 261 |
-
for (int j = 0; j < 2; j++)
|
| 262 |
-
{
|
| 263 |
-
for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
|
| 264 |
-
{
|
| 265 |
-
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k];
|
| 266 |
-
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k];
|
| 267 |
-
reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]));
|
| 268 |
-
reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]));
|
| 269 |
-
}
|
| 270 |
-
|
| 271 |
-
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 272 |
-
wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
|
| 273 |
-
}
|
| 274 |
-
}
|
| 275 |
-
}
|
| 276 |
-
|
| 277 |
-
__syncthreads();
|
| 278 |
-
|
| 279 |
-
#pragma unroll
|
| 280 |
-
for (int i = 0; i < n; i++)
|
| 281 |
-
{
|
| 282 |
-
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 283 |
-
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 284 |
-
out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 285 |
-
}
|
| 286 |
-
}
|
| 287 |
-
|
| 288 |
-
__global__ void butterfly_cuda_kernel_128(
|
| 289 |
-
const __half2 *__restrict__ x,
|
| 290 |
-
const __half2 *__restrict__ x_gate,
|
| 291 |
-
const complex_half_t *__restrict__ d_f,
|
| 292 |
-
const __half2 *__restrict__ twiddle_factors_real,
|
| 293 |
-
const __half2 *__restrict__ twiddle_factors_imag,
|
| 294 |
-
__half2 *__restrict__ out_real,
|
| 295 |
-
__half2 *__restrict__ out_imag,
|
| 296 |
-
uint B,
|
| 297 |
-
uint H,
|
| 298 |
-
int N)
|
| 299 |
-
{
|
| 300 |
-
const int offset = blockIdx.y * H * 128 * 32 * gridDim.x * 2 + blockIdx.z * 16 * 128 * 32 * gridDim.x * 2 + blockIdx.x * 64 + threadIdx.x;
|
| 301 |
-
const int tw_offset = blockIdx.x * 64 + threadIdx.x;
|
| 302 |
-
int idx;
|
| 303 |
-
|
| 304 |
-
int shared_offset;
|
| 305 |
-
const int B_Y = blockDim.y;
|
| 306 |
-
const int n = N / B_Y;
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
extern __shared__ half shared_real[];
|
| 310 |
-
half *shared_imag = &shared_real[128 * 128];
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[8];
|
| 314 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
|
| 315 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
|
| 316 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[8];
|
| 317 |
-
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[8][8];
|
| 318 |
-
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
|
| 319 |
-
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[8];
|
| 320 |
-
|
| 321 |
-
for (int i = 0; i < n; i++)
|
| 322 |
-
{
|
| 323 |
-
for(int j=0; j< 4; j++){
|
| 324 |
-
shared_offset = (threadIdx.y + i * B_Y) * 128 + threadIdx.x + j * blockDim.x;
|
| 325 |
-
shared_real[shared_offset] = d_f[shared_offset].real();
|
| 326 |
-
shared_imag[shared_offset] = d_f[shared_offset].imag();
|
| 327 |
-
}
|
| 328 |
-
}
|
| 329 |
-
|
| 330 |
-
__syncthreads();
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
for (int i = 0; i < 8; i++){
|
| 334 |
-
wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 335 |
-
wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 336 |
-
}
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
__syncthreads();
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
for (int i = 0; i < n; i++)
|
| 344 |
-
{
|
| 345 |
-
for(int j=0; j< 2; j++){
|
| 346 |
-
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
|
| 347 |
-
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 348 |
-
reinterpret_cast<__half2*>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 349 |
-
reinterpret_cast<__half2*>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 350 |
-
}
|
| 351 |
-
}
|
| 352 |
-
|
| 353 |
-
__syncthreads();
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
for (int i = 0; i < 8; i++){
|
| 357 |
-
wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
|
| 358 |
-
wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
|
| 359 |
-
}
|
| 360 |
-
|
| 361 |
-
__syncthreads();
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
for(int t=0; t< 16; t++){
|
| 365 |
-
for (int i = 0; i < n; i++)
|
| 366 |
-
{
|
| 367 |
-
for(int j=0; j< 2; j++){
|
| 368 |
-
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 369 |
-
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 370 |
-
if(x_gate != nullptr){
|
| 371 |
-
reinterpret_cast<__half2*>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 372 |
-
}else{
|
| 373 |
-
reinterpret_cast<__half2*>(shared_real)[shared_offset] = x[offset + idx];
|
| 374 |
-
}
|
| 375 |
-
|
| 376 |
-
}
|
| 377 |
-
}
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
__syncthreads();
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
for (int i = 0; i < 8; i++)
|
| 384 |
-
{
|
| 385 |
-
for (int j = 0; j < 8; j++)
|
| 386 |
-
{
|
| 387 |
-
wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
|
| 388 |
-
}
|
| 389 |
-
}
|
| 390 |
-
|
| 391 |
-
__syncthreads();
|
| 392 |
-
|
| 393 |
-
#pragma unroll
|
| 394 |
-
for (int j = 0; j < 8; j++)
|
| 395 |
-
{
|
| 396 |
-
wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
|
| 397 |
-
|
| 398 |
-
for (int k = 0; k < 8; k++)
|
| 399 |
-
{
|
| 400 |
-
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 401 |
-
}
|
| 402 |
-
}
|
| 403 |
-
|
| 404 |
-
#pragma unroll
|
| 405 |
-
|
| 406 |
-
for (int j = 0; j < 8; j++)
|
| 407 |
-
{
|
| 408 |
-
wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
|
| 409 |
-
|
| 410 |
-
for (int k = 0; k < 8; k++)
|
| 411 |
-
{
|
| 412 |
-
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 413 |
-
}
|
| 414 |
-
}
|
| 415 |
-
|
| 416 |
-
__half2 tmp_real, tmp_imag;
|
| 417 |
-
#pragma unroll
|
| 418 |
-
for (int j = 0; j < 8; j++)
|
| 419 |
-
{
|
| 420 |
-
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 421 |
-
{
|
| 422 |
-
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
|
| 423 |
-
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
|
| 424 |
-
reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
|
| 425 |
-
reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
|
| 426 |
-
}
|
| 427 |
-
|
| 428 |
-
wmma::store_matrix_sync(shared_real + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
|
| 429 |
-
wmma::store_matrix_sync(shared_imag + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
|
| 430 |
-
}
|
| 431 |
-
|
| 432 |
-
__syncthreads();
|
| 433 |
-
|
| 434 |
-
#pragma unroll
|
| 435 |
-
for (int i = 0; i < n; i++)
|
| 436 |
-
{
|
| 437 |
-
for(int j=0; j< 2; j++){
|
| 438 |
-
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 439 |
-
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 440 |
-
out_real[offset + idx] = reinterpret_cast<__half2*>(shared_real)[shared_offset];
|
| 441 |
-
out_imag[offset + idx] = reinterpret_cast<__half2*>(shared_imag)[shared_offset];
|
| 442 |
-
}
|
| 443 |
-
}
|
| 444 |
-
|
| 445 |
-
__syncthreads();
|
| 446 |
-
}
|
| 447 |
-
}
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
__global__ void butterfly_cuda_kernel_16(
|
| 451 |
-
const __half2 *__restrict__ x,
|
| 452 |
-
const __half2 *__restrict__ x_gate,
|
| 453 |
-
const complex_half_t *__restrict__ d_f,
|
| 454 |
-
const __half2 *__restrict__ twiddle_factors_real,
|
| 455 |
-
const __half2 *__restrict__ twiddle_factors_imag,
|
| 456 |
-
__half2 *__restrict__ out_real,
|
| 457 |
-
__half2 *__restrict__ out_imag,
|
| 458 |
-
uint B,
|
| 459 |
-
uint H,
|
| 460 |
-
int N)
|
| 461 |
-
{
|
| 462 |
-
const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 463 |
-
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 464 |
-
int idx;
|
| 465 |
-
|
| 466 |
-
int shared_offset;
|
| 467 |
-
const int B_Y = blockDim.y;
|
| 468 |
-
const int n = N / B_Y;
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
__shared__ half x_shared[16 * 64];
|
| 472 |
-
__shared__ half d_f_real[16 * 16];
|
| 473 |
-
__shared__ half d_f_imag[16 * 16];
|
| 474 |
-
__shared__ half twiddles_real_shared[16 * 64];
|
| 475 |
-
__shared__ half twiddles_imag_shared[16 * 64];
|
| 476 |
-
__shared__ half out_real_shared[16 * 64];
|
| 477 |
-
__shared__ half out_imag_shared[16 * 64];
|
| 478 |
-
|
| 479 |
-
// #pragma unroll
|
| 480 |
-
for (int i = 0; i < n; i++)
|
| 481 |
-
{
|
| 482 |
-
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 483 |
-
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 484 |
-
|
| 485 |
-
if(x_gate != NULL)
|
| 486 |
-
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 487 |
-
else
|
| 488 |
-
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 489 |
-
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 490 |
-
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 491 |
-
|
| 492 |
-
// #pragma unroll
|
| 493 |
-
|
| 494 |
-
if(threadIdx.x < 16 ){
|
| 495 |
-
shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
|
| 496 |
-
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 497 |
-
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 498 |
-
}
|
| 499 |
-
}
|
| 500 |
-
|
| 501 |
-
__syncthreads();
|
| 502 |
-
|
| 503 |
-
if (threadIdx.y < 4)
|
| 504 |
-
{
|
| 505 |
-
__half2 tmp_real, tmp_imag;
|
| 506 |
-
|
| 507 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
|
| 508 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real;
|
| 509 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
|
| 510 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
|
| 511 |
-
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
|
| 512 |
-
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
|
| 513 |
-
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag;
|
| 514 |
-
|
| 515 |
-
wmma::load_matrix_sync(a_frag_real, d_f_real, N);
|
| 516 |
-
wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
|
| 517 |
-
wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
|
| 518 |
-
wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
|
| 519 |
-
wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
wmma::fill_fragment(acc_frag_imag, __float2half(0.0f));
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
|
| 536 |
-
{
|
| 537 |
-
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real.x)[k];
|
| 538 |
-
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag.x)[k];
|
| 539 |
-
reinterpret_cast<__half2 *>(acc_frag_real.x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]));
|
| 540 |
-
reinterpret_cast<__half2 *>(acc_frag_imag.x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]));
|
| 541 |
-
}
|
| 542 |
-
|
| 543 |
-
wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
|
| 544 |
-
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
|
| 545 |
-
}
|
| 546 |
-
|
| 547 |
-
__syncthreads();
|
| 548 |
-
|
| 549 |
-
#pragma unroll
|
| 550 |
-
for (int i = 0; i < n; i++)
|
| 551 |
-
{
|
| 552 |
-
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 553 |
-
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 554 |
-
out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 555 |
-
}
|
| 556 |
-
}
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
std::vector<torch::Tensor> butterfly_cuda(
|
| 560 |
-
torch::Tensor x,
|
| 561 |
-
torch::Tensor d_f,
|
| 562 |
-
torch::Tensor twiddle_factors_real,
|
| 563 |
-
torch::Tensor twiddle_factors_imag,
|
| 564 |
-
std::optional<at::Tensor> x_gate = std::nullopt)
|
| 565 |
-
{
|
| 566 |
-
|
| 567 |
-
uint B = x.size(0);
|
| 568 |
-
uint H = x.size(1);
|
| 569 |
-
// uint m = x.size(1);
|
| 570 |
-
|
| 571 |
-
// const int TILE_SIZE = 16;
|
| 572 |
-
uint N = x.size(2);
|
| 573 |
-
uint M = x.size(3);
|
| 574 |
-
dim3 gridDim;
|
| 575 |
-
dim3 blockDim;
|
| 576 |
-
|
| 577 |
-
gridDim.y = B;
|
| 578 |
-
gridDim.z = H;
|
| 579 |
-
|
| 580 |
-
torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
|
| 581 |
-
torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());
|
| 582 |
-
|
| 583 |
-
//set blockDims
|
| 584 |
-
switch(N){
|
| 585 |
-
case 128:
|
| 586 |
-
blockDim.x = 32;
|
| 587 |
-
blockDim.y = 8;
|
| 588 |
-
break;
|
| 589 |
-
default:
|
| 590 |
-
blockDim.x = 32;
|
| 591 |
-
blockDim.y = 4;
|
| 592 |
-
break;
|
| 593 |
-
}
|
| 594 |
-
|
| 595 |
-
//set gridDim.x
|
| 596 |
-
switch(N){
|
| 597 |
-
case 128:
|
| 598 |
-
switch (M){
|
| 599 |
-
case 16384:
|
| 600 |
-
gridDim.x = 128;
|
| 601 |
-
break;
|
| 602 |
-
case 8192:
|
| 603 |
-
gridDim.x = 64;
|
| 604 |
-
break;
|
| 605 |
-
case 4096:
|
| 606 |
-
gridDim.x = 32;
|
| 607 |
-
break;
|
| 608 |
-
default:
|
| 609 |
-
gridDim.x = 256;
|
| 610 |
-
break;
|
| 611 |
-
}
|
| 612 |
-
break;
|
| 613 |
-
default:
|
| 614 |
-
switch (M){
|
| 615 |
-
case 16384:
|
| 616 |
-
gridDim.x = 256;
|
| 617 |
-
break;
|
| 618 |
-
case 8192:
|
| 619 |
-
gridDim.x = 128;
|
| 620 |
-
break;
|
| 621 |
-
case 4096:
|
| 622 |
-
gridDim.x = 64;
|
| 623 |
-
break;
|
| 624 |
-
default:
|
| 625 |
-
gridDim.x = 512;
|
| 626 |
-
break;
|
| 627 |
-
}
|
| 628 |
-
break;
|
| 629 |
-
}
|
| 630 |
-
|
| 631 |
-
switch (N)
|
| 632 |
-
{
|
| 633 |
-
case 16:
|
| 634 |
-
butterfly_cuda_kernel_16<<<gridDim, blockDim>>>(
|
| 635 |
-
static_cast<__half2 *>(x.data_ptr()),
|
| 636 |
-
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 637 |
-
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 638 |
-
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 639 |
-
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 640 |
-
static_cast<__half2 *>(out_real.data_ptr()),
|
| 641 |
-
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 642 |
-
B,
|
| 643 |
-
H,
|
| 644 |
-
N);
|
| 645 |
-
break;
|
| 646 |
-
case 32:
|
| 647 |
-
butterfly_cuda_kernel_32<<<gridDim, blockDim>>>(
|
| 648 |
-
static_cast<__half2 *>(x.data_ptr()),
|
| 649 |
-
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 650 |
-
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 651 |
-
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 652 |
-
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 653 |
-
static_cast<__half2 *>(out_real.data_ptr()),
|
| 654 |
-
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 655 |
-
B,
|
| 656 |
-
H,
|
| 657 |
-
N);
|
| 658 |
-
break;
|
| 659 |
-
|
| 660 |
-
case 64:
|
| 661 |
-
gridDim.z = H / 16;
|
| 662 |
-
cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 663 |
-
|
| 664 |
-
butterfly_cuda_kernel_64<<<gridDim, blockDim, 57344>>>(
|
| 665 |
-
static_cast<__half2 *>(x.data_ptr()),
|
| 666 |
-
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 667 |
-
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 668 |
-
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 669 |
-
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 670 |
-
static_cast<__half2 *>(out_real.data_ptr()),
|
| 671 |
-
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 672 |
-
B,
|
| 673 |
-
H,
|
| 674 |
-
N);
|
| 675 |
-
break;
|
| 676 |
-
case 128:
|
| 677 |
-
gridDim.z = H / 16;
|
| 678 |
-
cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 679 |
-
|
| 680 |
-
butterfly_cuda_kernel_128<<<gridDim, blockDim, 65536>>>(
|
| 681 |
-
static_cast<__half2 *>(x.data_ptr()),
|
| 682 |
-
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 683 |
-
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 684 |
-
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 685 |
-
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 686 |
-
static_cast<__half2 *>(out_real.data_ptr()),
|
| 687 |
-
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 688 |
-
B,
|
| 689 |
-
H,
|
| 690 |
-
N);
|
| 691 |
-
break;
|
| 692 |
-
|
| 693 |
-
default:
|
| 694 |
-
printf("Not yet implemented \n");
|
| 695 |
-
break;
|
| 696 |
-
}
|
| 697 |
-
|
| 698 |
-
return {out_real, out_imag};
|
| 699 |
}
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_fp16.h>
|
| 9 |
+
#include <cuda_bf16.h>
|
| 10 |
+
#include "shared.h"
|
| 11 |
+
|
| 12 |
+
using namespace nvcuda;
|
| 13 |
+
|
| 14 |
+
__global__ void butterfly_cuda_kernel_64(
|
| 15 |
+
const __half2 *__restrict__ x,
|
| 16 |
+
const __half2 *__restrict__ x_gate,
|
| 17 |
+
const complex_half_t *__restrict__ d_f,
|
| 18 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 19 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 20 |
+
__half2 *__restrict__ out_real,
|
| 21 |
+
__half2 *__restrict__ out_imag,
|
| 22 |
+
uint B,
|
| 23 |
+
uint H,
|
| 24 |
+
int N)
|
| 25 |
+
{
|
| 26 |
+
const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 27 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 28 |
+
int idx;
|
| 29 |
+
int shared_offset;
|
| 30 |
+
const int B_Y = blockDim.y;
|
| 31 |
+
const int n = N / B_Y;
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
extern __shared__ half x_shared[];
|
| 35 |
+
half *d_f_real = &x_shared[N * N];
|
| 36 |
+
half *d_f_imag = &d_f_real[N * N];
|
| 37 |
+
half *twiddles_real_shared = &d_f_imag[N * N];
|
| 38 |
+
half *twiddles_imag_shared = &twiddles_real_shared[N * N];
|
| 39 |
+
half *out_real_shared = &twiddles_imag_shared[N * N];
|
| 40 |
+
half *out_imag_shared = &out_real_shared[N * N];
|
| 41 |
+
|
| 42 |
+
// #pragma unroll
|
| 43 |
+
for (int i = 0; i < n; i++)
|
| 44 |
+
{
|
| 45 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 46 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 47 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 48 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 49 |
+
|
| 50 |
+
// #pragma unroll
|
| 51 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x;
|
| 52 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 53 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 54 |
+
|
| 55 |
+
d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
|
| 56 |
+
d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
__half2 tmp_real, tmp_imag;
|
| 60 |
+
|
| 61 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4];
|
| 62 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
|
| 63 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
|
| 64 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4];
|
| 65 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[4][4];
|
| 66 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
|
| 67 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[4];
|
| 68 |
+
|
| 69 |
+
__syncthreads();
|
| 70 |
+
|
| 71 |
+
for (int i = 0; i < 4; i++)
|
| 72 |
+
{
|
| 73 |
+
wmma::load_matrix_sync(a_frag_real[i], d_f_real + i * N * 16 + threadIdx.y * 16, N);
|
| 74 |
+
wmma::load_matrix_sync(a_frag_imag[i], d_f_imag + i * N * 16 + threadIdx.y * 16, N);
|
| 75 |
+
wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 76 |
+
wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
for (int t = 0; t < 16; t++)
|
| 80 |
+
{
|
| 81 |
+
|
| 82 |
+
for (int i = 0; i < n; i++)
|
| 83 |
+
{
|
| 84 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 85 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 86 |
+
if(x_gate != nullptr){
|
| 87 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 88 |
+
}else{
|
| 89 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
__syncthreads();
|
| 94 |
+
|
| 95 |
+
for (int i = 0; i < 4; i++)
|
| 96 |
+
{
|
| 97 |
+
for (int j = 0; j < 4; j++)
|
| 98 |
+
{
|
| 99 |
+
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
#pragma unroll
|
| 104 |
+
for (int j = 0; j < 4; j++)
|
| 105 |
+
{
|
| 106 |
+
wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
|
| 107 |
+
|
| 108 |
+
for (int k = 0; k < 4; k++)
|
| 109 |
+
{
|
| 110 |
+
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
#pragma unroll
|
| 115 |
+
|
| 116 |
+
for (int j = 0; j < 4; j++)
|
| 117 |
+
{
|
| 118 |
+
wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
|
| 119 |
+
|
| 120 |
+
for (int k = 0; k < 4; k++)
|
| 121 |
+
{
|
| 122 |
+
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
#pragma unroll
|
| 127 |
+
for (int j = 0; j < 4; j++)
|
| 128 |
+
{
|
| 129 |
+
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 130 |
+
{
|
| 131 |
+
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
|
| 132 |
+
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
|
| 133 |
+
reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
|
| 134 |
+
reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
|
| 138 |
+
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
__syncthreads();
|
| 142 |
+
|
| 143 |
+
#pragma unroll
|
| 144 |
+
for (int i = 0; i < n; i++)
|
| 145 |
+
{
|
| 146 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 147 |
+
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 148 |
+
out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
__syncthreads();
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
// Butterfly stage for N = 32 on fp16 data.
//
// For the (batch, head, column block) selected by blockIdx.{y, z, x}, the
// block stages a 32-row x 64-column tile of x (multiplied element-wise by
// x_gate when x_gate is non-null), the matching twiddle tiles and the 32x32
// matrix d_f in shared memory. Warps with threadIdx.y < N / 16 then use WMMA
// to multiply d_f (read column-major, real and imaginary parts separately)
// by their 32-column strip of the tile, multiply the complex result
// element-wise by the complex twiddles, and stage it for the final
// write to out_real / out_imag.
//
// All __half2 pointers are indexed in __half2 units; one row of x, of the
// twiddles and of the outputs is 32 * gridDim.x __half2 wide. B is unused.
// NOTE(review): the shared arrays and the 2x2 fragment grids are sized for
// N == 32; other values of N are not handled here.
__global__ void butterfly_cuda_kernel_32(
    const __half2 *__restrict__ x,
    const __half2 *__restrict__ x_gate,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    __half2 *__restrict__ out_imag,
    uint B,
    uint H,
    int N)
{
    // Position of this thread's column inside row 0 of the selected (batch, head) slab.
    const int base = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    // Same column inside row 0 of the twiddle matrices.
    const int tw_base = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    __shared__ half in_tile[32 * 64];
    __shared__ half dft_re[32 * 32];
    __shared__ half dft_im[32 * 32];
    __shared__ half tw_re[32 * 64];
    __shared__ half tw_im[32 * 64];
    __shared__ half res_re[32 * 64];
    __shared__ half res_im[32 * 64];

    // __half2 views of the tiles that are copied two halves at a time.
    __half2 *in_tile2 = reinterpret_cast<__half2 *>(in_tile);
    __half2 *tw_re2 = reinterpret_cast<__half2 *>(tw_re);
    __half2 *tw_im2 = reinterpret_cast<__half2 *>(tw_im);
    __half2 *res_re2 = reinterpret_cast<__half2 *>(res_re);
    __half2 *res_im2 = reinterpret_cast<__half2 *>(res_im);

    // Stage input, twiddles and d_f; each pass covers blockDim.y rows.
    for (int p = 0; p < passes; ++p)
    {
        const unsigned int row = threadIdx.y + p * rows_per_pass;
        const int g = row * 32 * gridDim.x;
        const int s = row * 32 + threadIdx.x;

        if (x_gate != nullptr)
        {
            in_tile2[s] = __hmul2(x[g + base], x_gate[g + base]);
        }
        else
        {
            in_tile2[s] = x[g + base];
        }

        tw_re2[s] = twiddle_factors_real[tw_base + g];
        tw_im2[s] = twiddle_factors_imag[tw_base + g];

        // d_f is split into separate real / imaginary planes, one element per thread.
        dft_re[s] = d_f[s].real();
        dft_im[s] = d_f[s].imag();
    }

    __syncthreads();

    if (threadIdx.y < N / 16)
    {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> dft_frag_re[2][2];
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> twf_re[2][2];
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> twf_im[2][2];
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> dft_frag_im[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> in_frag[2][2];
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_re[2][2];
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_im[2][2];

        // First column of the 32-column strip owned by this warp.
        const int strip = threadIdx.y * 32;

        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                wmma::load_matrix_sync(dft_frag_re[r][c], dft_re + c * N * 16 + r * 16, N);
                wmma::load_matrix_sync(dft_frag_im[r][c], dft_im + c * N * 16 + r * 16, N);
                wmma::load_matrix_sync(in_frag[r][c], in_tile + r * 2 * N * 16 + c * 16 + strip, 2 * N);
                wmma::load_matrix_sync(twf_re[r][c], tw_re + r * 2 * N * 16 + c * 16 + strip, 2 * N);
                wmma::load_matrix_sync(twf_im[r][c], tw_im + r * 2 * N * 16 + c * 16 + strip, 2 * N);
            }
        }

        // Real part of d_f times the input strip.
        #pragma unroll
        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                wmma::fill_fragment(acc_re[r][c], __float2half(0.0f));

                for (int k = 0; k < 2; ++k)
                {
                    wmma::mma_sync(acc_re[r][c], dft_frag_re[r][k], in_frag[k][c], acc_re[r][c]);
                }
            }
        }

        // Imaginary part of d_f times the input strip.
        #pragma unroll
        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                wmma::fill_fragment(acc_im[r][c], __float2half(0.0f));

                for (int k = 0; k < 2; ++k)
                {
                    wmma::mma_sync(acc_im[r][c], dft_frag_im[r][k], in_frag[k][c], acc_im[r][c]);
                }
            }
        }

        // Complex multiply by the twiddles, two halves at a time, then stage the tile.
        // NOTE(review): accumulator and twiddle fragment elements are paired by
        // register index; this presumes both fragment kinds use the same
        // element-to-thread mapping -- confirm against the target architecture.
        #pragma unroll
        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                for (int e = 0; e < acc_re[r][c].num_elements / 2; ++e)
                {
                    const __half2 re = reinterpret_cast<__half2 *>(acc_re[r][c].x)[e];
                    const __half2 im = reinterpret_cast<__half2 *>(acc_im[r][c].x)[e];
                    const __half2 wr = reinterpret_cast<__half2 *>(twf_re[r][c].x)[e];
                    const __half2 wi = reinterpret_cast<__half2 *>(twf_im[r][c].x)[e];

                    reinterpret_cast<__half2 *>(acc_re[r][c].x)[e] = __hsub2(__hmul2(re, wr), __hmul2(im, wi));
                    reinterpret_cast<__half2 *>(acc_im[r][c].x)[e] = __hadd2(__hmul2(re, wi), __hmul2(im, wr));
                }

                wmma::store_matrix_sync(res_re + r * 2 * N * 16 + c * 16 + strip, acc_re[r][c], 2 * N, wmma::mem_row_major);
                wmma::store_matrix_sync(res_im + r * 2 * N * 16 + c * 16 + strip, acc_im[r][c], 2 * N, wmma::mem_row_major);
            }
        }
    }

    __syncthreads();

    // Copy the staged result back out with the same row/column mapping used for loading.
    #pragma unroll
    for (int p = 0; p < passes; ++p)
    {
        const unsigned int row = threadIdx.y + p * rows_per_pass;
        const int g = base + row * 32 * gridDim.x;

        out_real[g] = res_re2[row * 32 + threadIdx.x];
        out_imag[g] = res_im2[row * 32 + threadIdx.x];
    }
}
|
| 287 |
+
|
| 288 |
+
// Butterfly stage for N = 128 on fp16 data.
//
// The dynamic shared buffer holds two 128x128 half planes (shared_real and
// shared_imag) that are reused in turn for d_f, the twiddles, the input tile
// and the results. Each warp (threadIdx.y) keeps its 16-row band of d_f
// (read column-major) and of the twiddles in WMMA fragments, then loops over
// 16 consecutive 128-row slabs of x: load the slab (times x_gate when it is
// non-null), multiply by the real and imaginary parts of d_f, apply the
// complex twiddles element-wise, and write the slab to out_real / out_imag.
//
// All __half2 pointers are indexed in __half2 units; one row of x, of the
// twiddles and of the outputs is 64 * gridDim.x __half2 wide. B is unused.
// NOTE(review): the tile constants assume N == 128 and enough warps
// (blockDim.y) to cover all eight 16-row bands.
__global__ void butterfly_cuda_kernel_128(
    const __half2 *__restrict__ x,
    const __half2 *__restrict__ x_gate,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    __half2 *__restrict__ out_imag,
    uint B,
    uint H,
    int N)
{
    // Position of this thread's first column inside row 0 of the first slab handled by the block.
    const int base = blockIdx.y * H * 128 * 32 * gridDim.x * 2 + blockIdx.z * 16 * 128 * 32 * gridDim.x * 2 + blockIdx.x * 64 + threadIdx.x;
    // Same column inside row 0 of the twiddle matrices.
    const int tw_base = blockIdx.x * 64 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    extern __shared__ half shared_real[];
    half *shared_imag = &shared_real[128 * 128];
    __half2 *shared_real2 = reinterpret_cast<__half2 *>(shared_real);
    __half2 *shared_imag2 = reinterpret_cast<__half2 *>(shared_imag);

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> dft_frag_re[8];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> twf_re[8];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> twf_im[8];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> dft_frag_im[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> in_frag[8][8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_re[8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_im[8];

    // Phase 1: split d_f into the real / imaginary planes (one half per thread, 4 column steps).
    for (int p = 0; p < passes; ++p)
    {
        const unsigned int row = threadIdx.y + p * rows_per_pass;
        for (int q = 0; q < 4; ++q)
        {
            const int s = row * 128 + threadIdx.x + q * blockDim.x;
            shared_real[s] = d_f[s].real();
            shared_imag[s] = d_f[s].imag();
        }
    }

    __syncthreads();

    // Pull this warp's band of d_f into fragments before the planes are reused.
    for (int k = 0; k < 8; ++k)
    {
        wmma::load_matrix_sync(dft_frag_re[k], shared_real + k * 128 * 16 + threadIdx.y * 16, 128);
        wmma::load_matrix_sync(dft_frag_im[k], shared_imag + k * 128 * 16 + threadIdx.y * 16, 128);
    }

    __syncthreads();

    // Phase 2: stage the twiddles in the same planes (one __half2 per thread, 2 column steps).
    for (int p = 0; p < passes; ++p)
    {
        const unsigned int row = threadIdx.y + p * rows_per_pass;
        for (int q = 0; q < 2; ++q)
        {
            const int g = row * 32 * 2 * gridDim.x + q * blockDim.x;
            const int s = row * 64 + threadIdx.x + q * blockDim.x;
            shared_real2[s] = twiddle_factors_real[tw_base + g];
            shared_imag2[s] = twiddle_factors_imag[tw_base + g];
        }
    }

    __syncthreads();

    for (int k = 0; k < 8; ++k)
    {
        wmma::load_matrix_sync(twf_re[k], shared_real + threadIdx.y * 128 * 16 + k * 16, 128);
        wmma::load_matrix_sync(twf_im[k], shared_imag + threadIdx.y * 128 * 16 + k * 16, 128);
    }

    __syncthreads();

    // Phase 3: process 16 consecutive 128-row slabs with the cached d_f / twiddle fragments.
    for (int t = 0; t < 16; ++t)
    {
        // Distance (in __half2) from the block's first slab to slab t.
        const unsigned int slab = t * 128 * 32 * 2 * gridDim.x;

        for (int p = 0; p < passes; ++p)
        {
            const unsigned int row = threadIdx.y + p * rows_per_pass;
            for (int q = 0; q < 2; ++q)
            {
                const int g = row * 32 * 2 * gridDim.x + q * blockDim.x + slab;
                const int s = row * 64 + threadIdx.x + q * blockDim.x;
                if (x_gate == nullptr)
                {
                    shared_real2[s] = x[base + g];
                }
                else
                {
                    shared_real2[s] = __hmul2(x[g + base], x_gate[g + base]);
                }
            }
        }

        __syncthreads();

        // Every warp needs the whole 128x128 input tile as the B operand.
        for (int r = 0; r < 8; ++r)
        {
            for (int c = 0; c < 8; ++c)
            {
                wmma::load_matrix_sync(in_frag[r][c], shared_real + r * 128 * 16 + c * 16, 128);
            }
        }

        // The planes are overwritten with results below, so all loads must finish first.
        __syncthreads();

        #pragma unroll
        for (int c = 0; c < 8; ++c)
        {
            wmma::fill_fragment(acc_re[c], __float2half(0.0f));

            for (int k = 0; k < 8; ++k)
            {
                wmma::mma_sync(acc_re[c], dft_frag_re[k], in_frag[k][c], acc_re[c]);
            }
        }

        #pragma unroll
        for (int c = 0; c < 8; ++c)
        {
            wmma::fill_fragment(acc_im[c], __float2half(0.0f));

            for (int k = 0; k < 8; ++k)
            {
                wmma::mma_sync(acc_im[c], dft_frag_im[k], in_frag[k][c], acc_im[c]);
            }
        }

        // Complex multiply by the twiddles, then stage this warp's 16-row band.
        // NOTE(review): accumulator and twiddle fragment elements are paired by
        // register index; this presumes both fragment kinds use the same
        // element-to-thread mapping -- confirm against the target architecture.
        #pragma unroll
        for (int c = 0; c < 8; ++c)
        {
            for (int e = 0; e < acc_re[c].num_elements / 2; ++e)
            {
                const __half2 re = reinterpret_cast<__half2 *>(acc_re[c].x)[e];
                const __half2 im = reinterpret_cast<__half2 *>(acc_im[c].x)[e];
                const __half2 wr = reinterpret_cast<__half2 *>(twf_re[c].x)[e];
                const __half2 wi = reinterpret_cast<__half2 *>(twf_im[c].x)[e];

                reinterpret_cast<__half2 *>(acc_re[c].x)[e] = __hsub2(__hmul2(re, wr), __hmul2(im, wi));
                reinterpret_cast<__half2 *>(acc_im[c].x)[e] = __hadd2(__hmul2(re, wi), __hmul2(im, wr));
            }

            wmma::store_matrix_sync(shared_real + threadIdx.y * 128 * 16 + c * 16, acc_re[c], 128, wmma::mem_row_major);
            wmma::store_matrix_sync(shared_imag + threadIdx.y * 128 * 16 + c * 16, acc_im[c], 128, wmma::mem_row_major);
        }

        __syncthreads();

        #pragma unroll
        for (int p = 0; p < passes; ++p)
        {
            const unsigned int row = threadIdx.y + p * rows_per_pass;
            for (int q = 0; q < 2; ++q)
            {
                const int g = row * 32 * 2 * gridDim.x + q * blockDim.x + slab;
                const int s = row * 64 + threadIdx.x + q * blockDim.x;
                out_real[base + g] = shared_real2[s];
                out_imag[base + g] = shared_imag2[s];
            }
        }

        // The next slab overwrites shared_real, so wait for the write-out to finish.
        __syncthreads();
    }
}
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
// Butterfly stage for N = 16 on fp16 data.
//
// For the (batch, head, column block) selected by blockIdx.{y, z, x}, the
// block stages a 16-row x 64-column tile of x (multiplied element-wise by
// x_gate when x_gate is non-null), the matching twiddle tiles and the 16x16
// matrix d_f in shared memory. Warps with threadIdx.y < 4 each take one
// 16-column strip: a single WMMA multiply of d_f (read column-major, real and
// imaginary parts separately) by the strip, followed by an element-wise
// complex multiply with the twiddles. The result is copied to
// out_real / out_imag.
//
// All __half2 pointers are indexed in __half2 units; one row of x, of the
// twiddles and of the outputs is 32 * gridDim.x __half2 wide. B is unused.
// NOTE(review): the shared arrays are sized for N == 16.
__global__ void butterfly_cuda_kernel_16(
    const __half2 *__restrict__ x,
    const __half2 *__restrict__ x_gate,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    __half2 *__restrict__ out_imag,
    uint B,
    uint H,
    int N)
{
    // Position of this thread's column inside row 0 of the selected (batch, head) slab.
    const int base = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    // Same column inside row 0 of the twiddle matrices.
    const int tw_base = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    __shared__ half in_tile[16 * 64];
    __shared__ half dft_re[16 * 16];
    __shared__ half dft_im[16 * 16];
    __shared__ half tw_re[16 * 64];
    __shared__ half tw_im[16 * 64];
    __shared__ half res_re[16 * 64];
    __shared__ half res_im[16 * 64];

    // __half2 views of the tiles that are copied two halves at a time.
    __half2 *in_tile2 = reinterpret_cast<__half2 *>(in_tile);
    __half2 *tw_re2 = reinterpret_cast<__half2 *>(tw_re);
    __half2 *tw_im2 = reinterpret_cast<__half2 *>(tw_im);
    __half2 *res_re2 = reinterpret_cast<__half2 *>(res_re);
    __half2 *res_im2 = reinterpret_cast<__half2 *>(res_im);

    // Stage input, twiddles and d_f; each pass covers blockDim.y rows.
    for (int p = 0; p < passes; ++p)
    {
        const unsigned int row = threadIdx.y + p * rows_per_pass;
        const int g = row * 32 * gridDim.x;
        const int s = row * 32 + threadIdx.x;

        if (x_gate == nullptr)
        {
            in_tile2[s] = x[g + base];
        }
        else
        {
            in_tile2[s] = __hmul2(x[g + base], x_gate[g + base]);
        }

        tw_re2[s] = twiddle_factors_real[tw_base + g];
        tw_im2[s] = twiddle_factors_imag[tw_base + g];

        // d_f rows are only 16 elements wide, so only the first 16 threads in x copy it.
        if (threadIdx.x < 16)
        {
            const int d = row * 16 + threadIdx.x;
            dft_re[d] = d_f[d].real();
            dft_im[d] = d_f[d].imag();
        }
    }

    __syncthreads();

    if (threadIdx.y < 4)
    {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> dft_frag_re;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> twf_re;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> twf_im;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> dft_frag_im;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> in_frag;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_re;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_im;

        // First column of the 16-column strip owned by this warp.
        const unsigned int strip = threadIdx.y * 16;

        wmma::load_matrix_sync(dft_frag_re, dft_re, N);
        wmma::load_matrix_sync(dft_frag_im, dft_im, N);
        wmma::load_matrix_sync(in_frag, in_tile + strip, 64);
        wmma::load_matrix_sync(twf_re, tw_re + strip, 64);
        wmma::load_matrix_sync(twf_im, tw_im + strip, 64);

        wmma::fill_fragment(acc_re, __float2half(0.0f));
        wmma::mma_sync(acc_re, dft_frag_re, in_frag, acc_re);

        wmma::fill_fragment(acc_im, __float2half(0.0f));
        wmma::mma_sync(acc_im, dft_frag_im, in_frag, acc_im);

        // Complex multiply by the twiddles, two halves at a time.
        // NOTE(review): accumulator and twiddle fragment elements are paired by
        // register index; this presumes both fragment kinds use the same
        // element-to-thread mapping -- confirm against the target architecture.
        for (int e = 0; e < acc_re.num_elements / 2; ++e)
        {
            const __half2 re = reinterpret_cast<__half2 *>(acc_re.x)[e];
            const __half2 im = reinterpret_cast<__half2 *>(acc_im.x)[e];
            const __half2 wr = reinterpret_cast<__half2 *>(twf_re.x)[e];
            const __half2 wi = reinterpret_cast<__half2 *>(twf_im.x)[e];

            reinterpret_cast<__half2 *>(acc_re.x)[e] = __hsub2(__hmul2(re, wr), __hmul2(im, wi));
            reinterpret_cast<__half2 *>(acc_im.x)[e] = __hadd2(__hmul2(re, wi), __hmul2(im, wr));
        }

        wmma::store_matrix_sync(res_re + strip, acc_re, 64, wmma::mem_row_major);
        wmma::store_matrix_sync(res_im + strip, acc_im, 64, wmma::mem_row_major);
    }

    __syncthreads();

    // Copy the staged result back out with the same row/column mapping used for loading.
    #pragma unroll
    for (int p = 0; p < passes; ++p)
    {
        const unsigned int row = threadIdx.y + p * rows_per_pass;
        const int g = base + row * 32 * gridDim.x;

        out_real[g] = res_re2[row * 32 + threadIdx.x];
        out_imag[g] = res_im2[row * 32 + threadIdx.x];
    }
}
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
// Host entry point for the fp16 butterfly kernels.
//
// x is read as a 4-D tensor (B, H, N, M); its data pointer is reinterpreted
// as __half2, so it must hold fp16 data laid out the way the kernels index it.
// d_f is reinterpreted as complex_half_t and the twiddle tensors as __half2.
// When x_gate is provided it is multiplied element-wise into x inside the
// kernel. Returns {out_real, out_imag}, both shaped like x with x's options.
//
// Launch geometry (derived from the kernels' index arithmetic):
//   - every block covers 64 columns of M (128 columns for N == 128), so
//     gridDim.x = M / 64 (or M / 128);
//   - gridDim.y = B;
//   - gridDim.z = H, except for N == 64 / 128, whose kernels iterate over 16
//     heads per block, giving gridDim.z = H / 16.
//
// Raises (via TORCH_CHECK) when x is not 4-D, N is not one of 16/32/64/128,
// M or H is not divisible as required, x_gate has a different element count
// than x, or the CUDA runtime reports an error for the attribute call or the
// launch.
//
// NOTE(review): the kernels are launched without an explicit stream argument,
// i.e. not on PyTorch's current stream -- confirm this is intended.
std::vector<torch::Tensor> butterfly_cuda(
    torch::Tensor x,
    torch::Tensor d_f,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag,
    std::optional<at::Tensor> x_gate = std::nullopt)
{
    TORCH_CHECK(x.dim() == 4,
                "butterfly_cuda: expected a 4-D input (B, H, N, M), got ", x.dim(), " dimension(s)");

    const uint B = x.size(0);
    const uint H = x.size(1);
    const uint N = x.size(2);
    const uint M = x.size(3);

    torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
    torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());

    // Nothing to compute; a launch with a zero-sized grid would be rejected.
    if (x.numel() == 0)
    {
        return {out_real, out_imag};
    }

    // Previously an unsupported N only printed a message and returned
    // uninitialized output tensors.
    TORCH_CHECK(N == 16 || N == 32 || N == 64 || N == 128,
                "butterfly_cuda: N = ", N, " is not implemented (supported: 16, 32, 64, 128)");

    const bool is_128 = (N == 128);
    // The N = 64 and N = 128 kernels loop over 16 heads inside one block.
    const bool heads_in_kernel = (N == 64 || N == 128);
    // Columns of M (in halves) covered by one block along gridDim.x.
    const uint cols_per_block = is_128 ? 128 : 64;

    TORCH_CHECK(M % cols_per_block == 0,
                "butterfly_cuda: M = ", M, " must be a multiple of ", cols_per_block, " for N = ", N);
    if (heads_in_kernel)
    {
        TORCH_CHECK(H % 16 == 0,
                    "butterfly_cuda: H = ", H, " must be a multiple of 16 for N = ", N);
    }
    if (x_gate.has_value())
    {
        // The kernels read x_gate at exactly the same indices as x.
        TORCH_CHECK(x_gate.value().numel() == x.numel(),
                    "butterfly_cuda: x_gate must have the same number of elements as x");
    }

    const dim3 grid(M / cols_per_block, B, heads_in_kernel ? H / 16 : H);
    const dim3 block(32, is_128 ? 8 : 4);

    const __half2 *x_ptr = static_cast<__half2 *>(x.data_ptr());
    const __half2 *gate_ptr = x_gate.has_value() ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr;
    const complex_half_t *d_f_ptr = static_cast<complex_half_t *>(d_f.data_ptr());
    const __half2 *tw_real_ptr = static_cast<__half2 *>(twiddle_factors_real.data_ptr());
    const __half2 *tw_imag_ptr = static_cast<__half2 *>(twiddle_factors_imag.data_ptr());
    __half2 *out_real_ptr = static_cast<__half2 *>(out_real.data_ptr());
    __half2 *out_imag_ptr = static_cast<__half2 *>(out_imag.data_ptr());

    cudaError_t status = cudaSuccess;

    switch (N)
    {
    case 16:
        butterfly_cuda_kernel_16<<<grid, block>>>(
            x_ptr, gate_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_real_ptr, out_imag_ptr, B, H, N);
        break;

    case 32:
        butterfly_cuda_kernel_32<<<grid, block>>>(
            x_ptr, gate_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_real_ptr, out_imag_ptr, B, H, N);
        break;

    case 64:
        // Opt in to more than the default dynamic shared memory limit.
        status = cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
        TORCH_CHECK(status == cudaSuccess,
                    "butterfly_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(status));

        butterfly_cuda_kernel_64<<<grid, block, 57344>>>(
            x_ptr, gate_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_real_ptr, out_imag_ptr, B, H, N);
        break;

    case 128:
        // Two 128x128 half planes = 65536 bytes of dynamic shared memory.
        status = cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
        TORCH_CHECK(status == cudaSuccess,
                    "butterfly_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(status));

        butterfly_cuda_kernel_128<<<grid, block, 65536>>>(
            x_ptr, gate_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_real_ptr, out_imag_ptr, B, H, N);
        break;

    default:
        // Unreachable: N was validated above.
        TORCH_CHECK(false, "butterfly_cuda: N = ", N, " is not implemented");
        break;
    }

    // Surface launch-configuration errors instead of returning garbage.
    status = cudaGetLastError();
    TORCH_CHECK(status == cudaSuccess,
                "butterfly_cuda: kernel launch failed: ", cudaGetErrorString(status));

    return {out_real, out_imag};
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda_bf16.cu
CHANGED
|
@@ -1,725 +1,725 @@
|
|
| 1 |
-
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
-
|
| 3 |
-
#include <torch/extension.h>
|
| 4 |
-
|
| 5 |
-
#include <vector>
|
| 6 |
-
#include <stdio.h>
|
| 7 |
-
#include <mma.h>
|
| 8 |
-
#include <cuda_runtime.h>
|
| 9 |
-
#include <cuda_fp16.h>
|
| 10 |
-
#include <cuda_bf16.h>
|
| 11 |
-
#include "shared.h"
|
| 12 |
-
|
| 13 |
-
using namespace nvcuda;
|
| 14 |
-
|
| 15 |
-
// Butterfly stage for N = 64 on bf16 data with fp32 accumulation.
//
// The dynamic shared buffer is carved into five N x N bf16 planes (input
// tile, d_f real, d_f imag, twiddle real, twiddle imag) followed by two
// N x N float planes for the results. Each warp (threadIdx.y) keeps its
// 16-row band of d_f (read column-major) and of the twiddles in WMMA
// fragments, then loops over 16 consecutive N-row slabs of x: load the slab
// (times x_gate when it is non-null), multiply by the real and imaginary
// parts of d_f, apply the complex twiddles element-wise in fp32, and write
// the slab to out_real / out_imag rounded to bf16.
//
// All __nv_bfloat162 pointers are indexed in __nv_bfloat162 units; one row of
// x, of the twiddles and of the outputs is 32 * gridDim.x elements wide.
// B is unused.
// NOTE(review): the 4-tile loops and the 64-row slab stride assume N == 64.
__global__ void butterfly_cuda_kernel_64(
    const __nv_bfloat162 *__restrict__ x,
    const __nv_bfloat162 *__restrict__ x_gate,
    const __nv_bfloat162 *__restrict__ d_f_real,
    const __nv_bfloat162 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_imag,
    uint B,
    uint H,
    int N)
{
    // Position of this thread's column inside row 0 of the first slab handled by the block.
    const int base = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    // Same column inside row 0 of the twiddle matrices.
    const int tw_base = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    // Carve the dynamic shared buffer into consecutive N x N planes.
    extern __shared__ __nv_bfloat16 x_shared[];
    __nv_bfloat16 *dft_re = &x_shared[N * N];
    __nv_bfloat16 *dft_im = &dft_re[N * N];
    __nv_bfloat16 *tw_re = &dft_im[N * N];
    __nv_bfloat16 *tw_im = &tw_re[N * N];
    float *res_re = reinterpret_cast<float *>(&tw_im[N * N]);
    float *res_im = &res_re[N * N];

    // Stage the twiddles and d_f, one __nv_bfloat162 per thread per row.
    for (int p = 0; p < passes; ++p)
    {
        const unsigned int row = threadIdx.y + p * rows_per_pass;
        const int g = row * 32 * gridDim.x;
        const int s = row * 32 + threadIdx.x;

        reinterpret_cast<__nv_bfloat162 *>(tw_re)[s] = twiddle_factors_real[tw_base + g];
        reinterpret_cast<__nv_bfloat162 *>(tw_im)[s] = twiddle_factors_imag[tw_base + g];

        reinterpret_cast<__nv_bfloat162 *>(dft_re)[s] = d_f_real[s];
        reinterpret_cast<__nv_bfloat162 *>(dft_im)[s] = d_f_imag[s];
    }

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_frag_re[4];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> twf_re[4];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> twf_im[4];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_frag_im[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> in_frag[4][4];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_re[4];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_im[4];

    __syncthreads();

    // Cache this warp's band of d_f and of the twiddles in fragments.
    for (int k = 0; k < 4; ++k)
    {
        wmma::load_matrix_sync(dft_frag_re[k], dft_re + k * N * 16 + threadIdx.y * 16, N);
        wmma::load_matrix_sync(dft_frag_im[k], dft_im + k * N * 16 + threadIdx.y * 16, N);
        wmma::load_matrix_sync(twf_re[k], tw_re + threadIdx.y * N * 16 + k * 16, N);
        wmma::load_matrix_sync(twf_im[k], tw_im + threadIdx.y * N * 16 + k * 16, N);
    }

    // Process 16 consecutive slabs with the cached fragments.
    for (int t = 0; t < 16; ++t)
    {
        // Distance (in __nv_bfloat162) from the block's first slab to slab t.
        const unsigned int slab = t * 64 * 32 * gridDim.x;

        for (int p = 0; p < passes; ++p)
        {
            const unsigned int row = threadIdx.y + p * rows_per_pass;
            const int g = row * 32 * gridDim.x + slab;
            const int s = row * 32 + threadIdx.x;

            if (x_gate == nullptr)
            {
                reinterpret_cast<__nv_bfloat162 *>(x_shared)[s] = x[g + base];
            }
            else
            {
                reinterpret_cast<__nv_bfloat162 *>(x_shared)[s] = __hmul2(x[g + base], x_gate[g + base]);
            }
        }

        __syncthreads();

        // Every warp needs the whole N x N input tile as the B operand.
        for (int r = 0; r < 4; ++r)
        {
            for (int c = 0; c < 4; ++c)
            {
                wmma::load_matrix_sync(in_frag[r][c], x_shared + r * N * 16 + c * 16, N);
            }
        }

        #pragma unroll
        for (int c = 0; c < 4; ++c)
        {
            wmma::fill_fragment(acc_re[c], 0.0f);

            for (int k = 0; k < 4; ++k)
            {
                wmma::mma_sync(acc_re[c], dft_frag_re[k], in_frag[k][c], acc_re[c]);
            }
        }

        #pragma unroll
        for (int c = 0; c < 4; ++c)
        {
            wmma::fill_fragment(acc_im[c], 0.0f);

            for (int k = 0; k < 4; ++k)
            {
                wmma::mma_sync(acc_im[c], dft_frag_im[k], in_frag[k][c], acc_im[c]);
            }
        }

        // Complex multiply by the twiddles in fp32, then stage this warp's band.
        // NOTE(review): accumulator and twiddle fragment elements are paired by
        // register index; this presumes both fragment kinds use the same
        // element-to-thread mapping -- confirm against the target architecture.
        #pragma unroll
        for (int c = 0; c < 4; ++c)
        {
            for (int e = 0; e < acc_re[c].num_elements / 2; ++e)
            {
                const float2 re = reinterpret_cast<float2 *>(acc_re[c].x)[e];
                const float2 im = reinterpret_cast<float2 *>(acc_im[c].x)[e];

                reinterpret_cast<float2 *>(acc_re[c].x)[e] = re * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(twf_re[c].x)[e]) - im * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(twf_im[c].x)[e]);
                reinterpret_cast<float2 *>(acc_im[c].x)[e] = re * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(twf_im[c].x)[e]) + im * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(twf_re[c].x)[e]);
            }

            wmma::store_matrix_sync(res_re + threadIdx.y * N * 16 + c * 16, acc_re[c], N, wmma::mem_row_major);
            wmma::store_matrix_sync(res_im + threadIdx.y * N * 16 + c * 16, acc_im[c], N, wmma::mem_row_major);
        }

        __syncthreads();

        // Round the fp32 results to bf16 pairs on the way out.
        #pragma unroll
        for (int p = 0; p < passes; ++p)
        {
            const unsigned int row = threadIdx.y + p * rows_per_pass;
            const int g = base + row * 32 * gridDim.x + slab;

            out_real[g] = __float22bfloat162_rn(reinterpret_cast<float2 *>(res_re)[row * 32 + threadIdx.x]);
            out_imag[g] = __float22bfloat162_rn(reinterpret_cast<float2 *>(res_im)[row * 32 + threadIdx.x]);
        }

        // The next slab overwrites x_shared and the result planes.
        __syncthreads();
    }
}
|
| 154 |
-
|
| 155 |
-
__global__ void butterfly_cuda_kernel_32(
|
| 156 |
-
const __nv_bfloat162 *__restrict__ x,
|
| 157 |
-
const __nv_bfloat162 *__restrict__ x_gate,
|
| 158 |
-
const __nv_bfloat16 *__restrict__ d_f_real,
|
| 159 |
-
const __nv_bfloat16 *__restrict__ d_f_imag,
|
| 160 |
-
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 161 |
-
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 162 |
-
__nv_bfloat162 *__restrict__ out_real,
|
| 163 |
-
__nv_bfloat162 *__restrict__ out_imag,
|
| 164 |
-
uint B,
|
| 165 |
-
uint H,
|
| 166 |
-
int N)
|
| 167 |
-
{
|
| 168 |
-
const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 169 |
-
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 170 |
-
int idx;
|
| 171 |
-
|
| 172 |
-
int shared_offset;
|
| 173 |
-
const int B_Y = blockDim.y;
|
| 174 |
-
const int n = N / B_Y;
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
__shared__ __nv_bfloat16 x_shared[32 * 64];
|
| 178 |
-
__shared__ __nv_bfloat16 d_f_real_shared[32 * 32];
|
| 179 |
-
__shared__ __nv_bfloat16 d_f_imag_shared[32 * 32];
|
| 180 |
-
__shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
|
| 181 |
-
__shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
|
| 182 |
-
__shared__ float out_real_shared[32 * 64];
|
| 183 |
-
__shared__ float out_imag_shared[32 * 64];
|
| 184 |
-
|
| 185 |
-
// #pragma unroll
|
| 186 |
-
for (int i = 0; i < n; i++)
|
| 187 |
-
{
|
| 188 |
-
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 189 |
-
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 190 |
-
if(x_gate != nullptr){
|
| 191 |
-
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 192 |
-
}else{
|
| 193 |
-
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 194 |
-
}
|
| 195 |
-
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 196 |
-
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 197 |
-
|
| 198 |
-
// #pragma unroll
|
| 199 |
-
d_f_real_shared[shared_offset] = d_f_real[shared_offset];
|
| 200 |
-
d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
|
| 201 |
-
}
|
| 202 |
-
|
| 203 |
-
__syncthreads();
|
| 204 |
-
|
| 205 |
-
if (threadIdx.y < N / 16)
|
| 206 |
-
{
|
| 207 |
-
float2 tmp_real, tmp_imag;
|
| 208 |
-
|
| 209 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
|
| 210 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
|
| 211 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
|
| 212 |
-
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
|
| 213 |
-
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[2][2];
|
| 214 |
-
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
|
| 215 |
-
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[2][2];
|
| 216 |
-
|
| 217 |
-
int t = threadIdx.y * 32;
|
| 218 |
-
|
| 219 |
-
for (int i = 0; i < 2; i++)
|
| 220 |
-
{
|
| 221 |
-
for (int j = 0; j < 2; j++)
|
| 222 |
-
{
|
| 223 |
-
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
|
| 224 |
-
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
|
| 225 |
-
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 226 |
-
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 227 |
-
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 228 |
-
}
|
| 229 |
-
}
|
| 230 |
-
|
| 231 |
-
#pragma unroll
|
| 232 |
-
for (int i = 0; i < 2; i++)
|
| 233 |
-
{
|
| 234 |
-
for (int j = 0; j < 2; j++)
|
| 235 |
-
{
|
| 236 |
-
wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
|
| 237 |
-
|
| 238 |
-
for (int k = 0; k < 2; k++)
|
| 239 |
-
{
|
| 240 |
-
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
|
| 241 |
-
}
|
| 242 |
-
}
|
| 243 |
-
}
|
| 244 |
-
|
| 245 |
-
#pragma unroll
|
| 246 |
-
for (int i = 0; i < 2; i++)
|
| 247 |
-
{
|
| 248 |
-
for (int j = 0; j < 2; j++)
|
| 249 |
-
{
|
| 250 |
-
wmma::fill_fragment(acc_frag_imag[i][j], 0.0f);
|
| 251 |
-
|
| 252 |
-
for (int k = 0; k < 2; k++)
|
| 253 |
-
{
|
| 254 |
-
wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
|
| 255 |
-
}
|
| 256 |
-
}
|
| 257 |
-
}
|
| 258 |
-
|
| 259 |
-
#pragma unroll
|
| 260 |
-
for (int i = 0; i < 2; i++)
|
| 261 |
-
{
|
| 262 |
-
for (int j = 0; j < 2; j++)
|
| 263 |
-
{
|
| 264 |
-
for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
|
| 265 |
-
{
|
| 266 |
-
tmp_real = reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k];
|
| 267 |
-
tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k];
|
| 268 |
-
reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]);
|
| 269 |
-
reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]);
|
| 270 |
-
}
|
| 271 |
-
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 272 |
-
wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
|
| 273 |
-
}
|
| 274 |
-
}
|
| 275 |
-
}
|
| 276 |
-
|
| 277 |
-
__syncthreads();
|
| 278 |
-
|
| 279 |
-
#pragma unroll
|
| 280 |
-
for (int i = 0; i < n; i++)
|
| 281 |
-
{
|
| 282 |
-
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 283 |
-
out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 284 |
-
out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 285 |
-
}
|
| 286 |
-
}
|
| 287 |
-
|
| 288 |
-
// Butterfly stage for 128-row tiles using bf16 tensor-core WMMA.
//
// For each of 16 consecutive 128-row slabs the block:
//   1. stages a tile of `x` (multiplied element-wise by `x_gate` when it is non-null) in shared memory,
//   2. multiplies the d_f_real / d_f_imag fragments (loaded as column-major matrix_a operands) with it,
//   3. applies the complex twiddle factors:
//        real' = real * tw_real - imag * tw_imag
//        imag' = real * tw_imag + imag * tw_real
//   4. rounds the fp32 results to bf16 and writes them to out_real / out_imag.
//
// Parameters:
//   x, x_gate             input and optional gate, addressed as packed bf16 pairs; x_gate may be nullptr.
//   d_f_real, d_f_imag    real/imag parts of the matrix applied to x, addressed as packed bf16 pairs.
//   twiddle_factors_*     real/imag twiddle factors, indexed with the same row stride as x.
//   out_real, out_imag    outputs, same indexing as x.
//   B                     not referenced by the kernel.
//   H                     used only to compute the per-blockIdx.y stride.
//   N                     rows staged per tile; the index math hard-codes 128, so presumably N == 128.
//
// Launch assumptions (NOTE(review): inferred from the index arithmetic, confirm against the launcher):
//   blockDim = (32, 8); dynamic shared memory of at least 2 * 128 * 128 * sizeof(__nv_bfloat16)
//   = 65536 bytes, which is also exactly the size of the 128 x 128 float staging area that reuses
//   the same buffer for the outputs.
__global__ void butterfly_cuda_kernel_128(
    const __nv_bfloat162 *__restrict__ x,
    const __nv_bfloat162 *__restrict__ x_gate,
    const __nv_bfloat162 *__restrict__ d_f_real,
    const __nv_bfloat162 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_imag,
    uint B,
    uint H,
    int N)
{
    // Element offset (in bf16 pairs) of this thread's column inside the first slab handled by the block.
    const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
    const int tw_offset = blockIdx.x * 64 + threadIdx.x;
    int idx;
    int shared_offset;
    const int B_Y = blockDim.y;
    const int n = N / B_Y;

    // One dynamic shared buffer, reused for d_f, twiddles, x tiles and finally the fp32 outputs.
    extern __shared__ __nv_bfloat16 shared_real[];
    __nv_bfloat16 *shared_imag = &shared_real[128 * 128];

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[8];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[8][8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[8];

    // Stage d_f_real / d_f_imag in shared memory (two bf16 pairs per thread per row).
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < 2; j++)
        {
            shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
            reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = d_f_real[shared_offset];
            reinterpret_cast<__nv_bfloat162 *>(shared_imag)[shared_offset] = d_f_imag[shared_offset];
        }
    }

    __syncthreads();

    // Keep this warp's d_f fragments in registers for all slabs.
    for (int i = 0; i < 8; i++)
    {
        wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
        wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
    }

    // All fragments must be loaded before the buffer is overwritten with the twiddles.
    __syncthreads();

    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < 2; j++)
        {
            idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
            shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
            reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
            reinterpret_cast<__nv_bfloat162 *>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
        }
    }

    __syncthreads();

    // Twiddle fragments are loaded as matrix_a so they can be paired element-wise with the
    // accumulator registers below.
    for (int i = 0; i < 8; i++)
    {
        wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
        wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
    }

    __syncthreads();

    for (int t = 0; t < 16; t++)
    {
        // Stage the (optionally gated) x tile of slab t.
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
                if (x_gate != nullptr)
                {
                    reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
                }
                else
                {
                    reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = x[offset + idx];
                }
            }
        }

        __syncthreads();

        for (int i = 0; i < 8; i++)
        {
            for (int j = 0; j < 8; j++)
            {
                wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
            }
        }

        // Every warp must have finished reading the x tile before the buffer is reused for outputs.
        __syncthreads();

#pragma unroll
        for (int j = 0; j < 8; j++)
        {
            wmma::fill_fragment(acc_frag_real[j], 0.0f);

            for (int k = 0; k < 8; k++)
            {
                wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
            }
        }

#pragma unroll
        for (int j = 0; j < 8; j++)
        {
            wmma::fill_fragment(acc_frag_imag[j], 0.0f);

            for (int k = 0; k < 8; k++)
            {
                wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
            }
        }

        // Complex multiply by the twiddles, two accumulator elements at a time.
        float2 tmp_real, tmp_imag;
#pragma unroll
        for (int j = 0; j < 8; j++)
        {
            for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
            {
                tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
                tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];

                reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
                reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
            }
        }

        // Real part: stage as fp32 in the shared buffer, then convert to bf16 and write out.
        for (int j = 0; j < 8; j++)
        {
            wmma::store_matrix_sync(reinterpret_cast<float *>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
        }

        __syncthreads();

#pragma unroll
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
                out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(shared_real)[shared_offset]);
            }
        }

        __syncthreads();

        // Imaginary part: same staging buffer, same procedure.
        for (int j = 0; j < 8; j++)
        {
            wmma::store_matrix_sync(reinterpret_cast<float *>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
        }

        __syncthreads();

#pragma unroll
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
                out_imag[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(shared_real)[shared_offset]);
            }
        }

        // The next iteration overwrites this same buffer with the bf16 x tile. Without a barrier a
        // fast warp could clobber fp32 values that a slower warp is still reading for out_imag.
        __syncthreads();
    }
}
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
// Butterfly stage for 16-row tiles using bf16 tensor-core WMMA.
//
// The block stages a 16 x 64 bf16 tile of `x` (multiplied element-wise by `x_gate` when it is
// non-null), the matching twiddle tiles and the 16 x 16 d_f matrices in shared memory. Each of the
// first four threadIdx.y groups then multiplies d_f (loaded as a column-major matrix_a operand) with
// one 16 x 16 sub-tile, applies the complex twiddle factors and stores fp32 results, which the whole
// block finally rounds to bf16 and writes to out_real / out_imag.
//
// B is not referenced; H only enters the per-blockIdx.y stride; N is the number of rows staged
// (the shared arrays are sized for N == 16).
__global__ void butterfly_cuda_kernel_16(
    const __nv_bfloat162 *__restrict__ x,
    const __nv_bfloat162 *__restrict__ x_gate,
    const __nv_bfloat16 *__restrict__ d_f_real,
    const __nv_bfloat16 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_imag,
    uint B,
    uint H,
    int N)
{
    // Position (in bf16 pairs) of this thread's column in global memory and in the twiddle rows.
    const int base = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_col = blockIdx.x * 32 + threadIdx.x;

    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    __shared__ __nv_bfloat16 x_tile[16 * 64];
    __shared__ __nv_bfloat16 dft_re[16 * 16];
    __shared__ __nv_bfloat16 dft_im[16 * 16];
    __shared__ __nv_bfloat16 tw_re[16 * 64];
    __shared__ __nv_bfloat16 tw_im[16 * 64];
    __shared__ float res_re[16 * 64];
    __shared__ float res_im[16 * 64];

    // Packed (two-element) views used for the staging copies.
    __nv_bfloat162 *x_tile2 = reinterpret_cast<__nv_bfloat162 *>(x_tile);
    __nv_bfloat162 *tw_re2 = reinterpret_cast<__nv_bfloat162 *>(tw_re);
    __nv_bfloat162 *tw_im2 = reinterpret_cast<__nv_bfloat162 *>(tw_im);

    for (int pass = 0; pass < passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int g = row * 32 * gridDim.x;     // row offset in global memory
        const int s = row * 32 + threadIdx.x;   // slot in the packed shared tiles

        __nv_bfloat162 v = x[g + base];
        if (x_gate != nullptr)
        {
            v = __hmul2(v, x_gate[g + base]);
        }
        x_tile2[s] = v;

        tw_re2[s] = twiddle_factors_real[tw_col + g];
        tw_im2[s] = twiddle_factors_imag[tw_col + g];

        // d_f is only 16 columns wide, so half of the threads in x copy it (one bf16 each).
        if (threadIdx.x < 16)
        {
            const int d = row * 16 + threadIdx.x;
            dft_re[d] = d_f_real[d];
            dft_im[d] = d_f_imag[d];
        }
    }

    __syncthreads();

    if (threadIdx.y < 4)
    {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_re_frag;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_im_frag;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_re_frag;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_im_frag;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> x_frag;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_re;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_im;

        const int sub_tile = threadIdx.y * 16;   // column of this group's 16 x 16 sub-tile

        wmma::load_matrix_sync(dft_re_frag, dft_re, N);
        wmma::load_matrix_sync(dft_im_frag, dft_im, N);
        wmma::load_matrix_sync(x_frag, x_tile + sub_tile, 64);
        wmma::load_matrix_sync(tw_re_frag, tw_re + sub_tile, 64);
        wmma::load_matrix_sync(tw_im_frag, tw_im + sub_tile, 64);

        wmma::fill_fragment(acc_re, 0.0f);
        wmma::mma_sync(acc_re, dft_re_frag, x_frag, acc_re);

        wmma::fill_fragment(acc_im, 0.0f);
        wmma::mma_sync(acc_im, dft_im_frag, x_frag, acc_im);

        // Complex multiply by the twiddles, two accumulator elements at a time.
        float2 *acc_re2 = reinterpret_cast<float2 *>(acc_re.x);
        float2 *acc_im2 = reinterpret_cast<float2 *>(acc_im.x);
        __nv_bfloat162 *tw_re_pairs = reinterpret_cast<__nv_bfloat162 *>(tw_re_frag.x);
        __nv_bfloat162 *tw_im_pairs = reinterpret_cast<__nv_bfloat162 *>(tw_im_frag.x);

        float2 re, im, wr, wi;
#pragma unroll
        for (int k = 0; k < acc_re.num_elements / 2; k++)
        {
            re = acc_re2[k];
            im = acc_im2[k];
            wr = __bfloat1622float2(tw_re_pairs[k]);
            wi = __bfloat1622float2(tw_im_pairs[k]);
            acc_re2[k] = re * wr - im * wi;
            acc_im2[k] = re * wi + im * wr;
        }

        wmma::store_matrix_sync(res_re + sub_tile, acc_re, 64, wmma::mem_row_major);
        wmma::store_matrix_sync(res_im + sub_tile, acc_im, 64, wmma::mem_row_major);
    }

    __syncthreads();

    // Round the staged fp32 results to bf16 pairs and write them out.
    const float2 *res_re2 = reinterpret_cast<float2 *>(res_re);
    const float2 *res_im2 = reinterpret_cast<float2 *>(res_im);

#pragma unroll
    for (int pass = 0; pass < passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int dst = base + row * 32 * gridDim.x;
        const int s = row * 32 + threadIdx.x;
        out_real[dst] = __float22bfloat162_rn(res_re2[s]);
        out_imag[dst] = __float22bfloat162_rn(res_im2[s]);
    }
}
|
| 578 |
-
|
| 579 |
-
// Host launcher for the bf16 butterfly kernels.
//
// Reads (B, H, N, M) from the four dimensions of `x`, allocates two outputs of the same shape and
// options, and launches the kernel specialised for N (16, 32, 64 or 128).
//
// Arguments:
//   x                        4-D input; its data pointer is handed to the kernels as packed bf16
//                            pairs (NOTE(review): dtype, device and contiguity are not checked
//                            here, the caller is presumably responsible for them).
//   d_f_real, d_f_imag       real/imag parts of the N x N matrix applied by the kernels.
//   twiddle_factors_real,
//   twiddle_factors_imag     real/imag twiddle factors.
//   x_gate                   optional gate multiplied element-wise into x; nullptr is passed to the
//                            kernel when it is absent.
//
// Returns {out_real, out_imag}.
//
// Raises (via TORCH_CHECK) when the shape is not supported or a CUDA call fails, instead of
// silently returning uninitialised memory.
std::vector<torch::Tensor> butterfly_bf16_cuda(
    torch::Tensor x,
    torch::Tensor d_f_real,
    torch::Tensor d_f_imag,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag,
    std::optional<at::Tensor> x_gate = std::nullopt
)
{
    TORCH_CHECK(x.dim() == 4, "butterfly_bf16_cuda: expected a 4-D input (B, H, N, M), got ", x.dim(), " dimensions");

    uint B = x.size(0);
    uint H = x.size(1);
    uint N = x.size(2);
    uint M = x.size(3);

    TORCH_CHECK(N == 16 || N == 32 || N == 64 || N == 128,
                "butterfly_bf16_cuda: N = ", N, " is not implemented (supported: 16, 32, 64, 128)");

    // The 64- and 128-point kernels loop over 16 consecutive heads per block.
    const bool sixteen_heads_per_block = (N == 64 || N == 128);
    TORCH_CHECK(!sixteen_heads_per_block || H % 16 == 0,
                "butterfly_bf16_cuda: H = ", H, " must be a multiple of 16 when N is ", N);

    // Each block covers 64 bf16 columns (32 packed pairs), or 128 for the 128-point kernel, and the
    // kernels derive the row stride from gridDim.x, so gridDim.x must be exactly M / columns.
    // This reproduces the previous table (4096, 8192, 16384, 32768) and extends it to other widths.
    const uint cols_per_block = (N == 128) ? 128 : 64;
    TORCH_CHECK(M % cols_per_block == 0,
                "butterfly_bf16_cuda: M = ", M, " must be a multiple of ", cols_per_block, " when N is ", N);

    torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
    torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());

    // Nothing to compute, and a zero-sized grid is an invalid launch configuration.
    if (x.numel() == 0)
    {
        return {out_real, out_imag};
    }

    dim3 gridDim;
    dim3 blockDim;

    gridDim.x = M / cols_per_block;
    gridDim.y = B;
    gridDim.z = sixteen_heads_per_block ? H / 16 : H;

    blockDim.x = 32;
    blockDim.y = (N == 128) ? 8 : 4;

    void *x_ptr = x.data_ptr();
    void *gate_ptr = x_gate ? x_gate.value().data_ptr() : nullptr;
    void *d_f_real_ptr = d_f_real.data_ptr();
    void *d_f_imag_ptr = d_f_imag.data_ptr();
    void *tw_real_ptr = twiddle_factors_real.data_ptr();
    void *tw_imag_ptr = twiddle_factors_imag.data_ptr();
    void *out_real_ptr = out_real.data_ptr();
    void *out_imag_ptr = out_imag.data_ptr();

    switch (N)
    {
    case 16:
        butterfly_cuda_kernel_16<<<gridDim, blockDim>>>(
            static_cast<__nv_bfloat162 *>(x_ptr),
            static_cast<__nv_bfloat162 *>(gate_ptr),
            static_cast<__nv_bfloat16 *>(d_f_real_ptr),
            static_cast<__nv_bfloat16 *>(d_f_imag_ptr),
            static_cast<__nv_bfloat162 *>(tw_real_ptr),
            static_cast<__nv_bfloat162 *>(tw_imag_ptr),
            static_cast<__nv_bfloat162 *>(out_real_ptr),
            static_cast<__nv_bfloat162 *>(out_imag_ptr),
            B,
            H,
            N);
        break;

    case 32:
        butterfly_cuda_kernel_32<<<gridDim, blockDim>>>(
            static_cast<__nv_bfloat162 *>(x_ptr),
            static_cast<__nv_bfloat162 *>(gate_ptr),
            static_cast<__nv_bfloat16 *>(d_f_real_ptr),
            static_cast<__nv_bfloat16 *>(d_f_imag_ptr),
            static_cast<__nv_bfloat162 *>(tw_real_ptr),
            static_cast<__nv_bfloat162 *>(tw_imag_ptr),
            static_cast<__nv_bfloat162 *>(out_real_ptr),
            static_cast<__nv_bfloat162 *>(out_imag_ptr),
            B,
            H,
            N);
        break;

    case 64:
    {
        // The kernel needs more dynamic shared memory than the default limit allows.
        const cudaError_t attr_status = cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
        TORCH_CHECK(attr_status == cudaSuccess,
                    "butterfly_bf16_cuda: cannot reserve 78000 bytes of shared memory: ", cudaGetErrorString(attr_status));

        butterfly_cuda_kernel_64<<<gridDim, blockDim, 78000>>>(
            static_cast<__nv_bfloat162 *>(x_ptr),
            static_cast<__nv_bfloat162 *>(gate_ptr),
            static_cast<__nv_bfloat162 *>(d_f_real_ptr),
            static_cast<__nv_bfloat162 *>(d_f_imag_ptr),
            static_cast<__nv_bfloat162 *>(tw_real_ptr),
            static_cast<__nv_bfloat162 *>(tw_imag_ptr),
            static_cast<__nv_bfloat162 *>(out_real_ptr),
            static_cast<__nv_bfloat162 *>(out_imag_ptr),
            B,
            H,
            N);
        break;
    }

    case 128:
    {
        const cudaError_t attr_status = cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
        TORCH_CHECK(attr_status == cudaSuccess,
                    "butterfly_bf16_cuda: cannot reserve 65536 bytes of shared memory: ", cudaGetErrorString(attr_status));

        butterfly_cuda_kernel_128<<<gridDim, blockDim, 65536>>>(
            static_cast<__nv_bfloat162 *>(x_ptr),
            static_cast<__nv_bfloat162 *>(gate_ptr),
            static_cast<__nv_bfloat162 *>(d_f_real_ptr),
            static_cast<__nv_bfloat162 *>(d_f_imag_ptr),
            static_cast<__nv_bfloat162 *>(tw_real_ptr),
            static_cast<__nv_bfloat162 *>(tw_imag_ptr),
            static_cast<__nv_bfloat162 *>(out_real_ptr),
            static_cast<__nv_bfloat162 *>(out_imag_ptr),
            B,
            H,
            N);
        break;
    }

    default:
        // Unreachable: N was validated above.
        TORCH_CHECK(false, "butterfly_bf16_cuda: N = ", N, " is not implemented");
    }

    // Surface invalid launch configurations immediately rather than at some later CUDA call.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "butterfly_bf16_cuda: kernel launch failed: ", cudaGetErrorString(launch_status));

    return {out_real, out_imag};
}
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_runtime.h>
|
| 9 |
+
#include <cuda_fp16.h>
|
| 10 |
+
#include <cuda_bf16.h>
|
| 11 |
+
#include "shared.h"
|
| 12 |
+
|
| 13 |
+
using namespace nvcuda;
|
| 14 |
+
|
| 15 |
+
// Butterfly stage for 64-row tiles using bf16 tensor-core WMMA.
//
// The d_f matrices and the twiddle tiles are staged once in dynamic shared memory and kept in
// per-warp fragments. The block then walks over 16 consecutive slabs: for each one it stages the
// x tile (multiplied element-wise by `x_gate` when it is non-null), multiplies d_f (loaded as
// column-major matrix_a operands) with it, applies the complex twiddle factors, and writes the
// results rounded to bf16 into out_real / out_imag.
//
// B is not referenced; H only enters the per-blockIdx.y stride; N sizes the shared-memory regions
// and the fragment leading dimensions (the global index math hard-codes 64).
//
// Dynamic shared memory layout, in order: x tile, d_f real, d_f imag, twiddle real, twiddle imag
// (N * N bf16 each), then fp32 real and imaginary outputs (N * N floats each).
__global__ void butterfly_cuda_kernel_64(
    const __nv_bfloat162 *__restrict__ x,
    const __nv_bfloat162 *__restrict__ x_gate,
    const __nv_bfloat162 *__restrict__ d_f_real,
    const __nv_bfloat162 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_imag,
    uint B,
    uint H,
    int N)
{
    // Position (in bf16 pairs) of this thread's column in global memory and in the twiddle rows.
    const int base = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_col = blockIdx.x * 32 + threadIdx.x;

    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;
    const int tile_elems = N * N;

    extern __shared__ __nv_bfloat16 x_shared[];
    __nv_bfloat16 *dft_re = x_shared + tile_elems;
    __nv_bfloat16 *dft_im = dft_re + tile_elems;
    __nv_bfloat16 *tw_re = dft_im + tile_elems;
    __nv_bfloat16 *tw_im = tw_re + tile_elems;
    float *res_re = reinterpret_cast<float *>(tw_im + tile_elems);
    float *res_im = res_re + tile_elems;

    // Packed (two-element) views used for the staging copies.
    __nv_bfloat162 *x_tile2 = reinterpret_cast<__nv_bfloat162 *>(x_shared);
    __nv_bfloat162 *dft_re2 = reinterpret_cast<__nv_bfloat162 *>(dft_re);
    __nv_bfloat162 *dft_im2 = reinterpret_cast<__nv_bfloat162 *>(dft_im);
    __nv_bfloat162 *tw_re2 = reinterpret_cast<__nv_bfloat162 *>(tw_re);
    __nv_bfloat162 *tw_im2 = reinterpret_cast<__nv_bfloat162 *>(tw_im);
    const float2 *res_re2 = reinterpret_cast<float2 *>(res_re);
    const float2 *res_im2 = reinterpret_cast<float2 *>(res_im);

    // Stage the twiddle tiles and the d_f matrices.
    for (int pass = 0; pass < passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int g = row * 32 * gridDim.x;
        const int s = row * 32 + threadIdx.x;

        tw_re2[s] = twiddle_factors_real[tw_col + g];
        tw_im2[s] = twiddle_factors_imag[tw_col + g];

        dft_re2[s] = d_f_real[s];
        dft_im2[s] = d_f_imag[s];
    }

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_re_frag[4];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_im_frag[4];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_re_frag[4];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_im_frag[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> x_frag[4][4];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_re[4];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_im[4];

    __syncthreads();

    // These fragments do not change across slabs, so load them once.
    for (int i = 0; i < 4; i++)
    {
        wmma::load_matrix_sync(dft_re_frag[i], dft_re + i * N * 16 + threadIdx.y * 16, N);
        wmma::load_matrix_sync(dft_im_frag[i], dft_im + i * N * 16 + threadIdx.y * 16, N);
        wmma::load_matrix_sync(tw_re_frag[i], tw_re + threadIdx.y * N * 16 + i * 16, N);
        wmma::load_matrix_sync(tw_im_frag[i], tw_im + threadIdx.y * N * 16 + i * 16, N);
    }

    float2 re, im, wr, wi;

    for (int t = 0; t < 16; t++)
    {
        const int slab = t * 64 * 32 * gridDim.x;   // offset of slab t in global memory

        // Stage the (optionally gated) x tile of this slab.
        for (int pass = 0; pass < passes; ++pass)
        {
            const int row = threadIdx.y + pass * rows_per_pass;
            const int g = row * 32 * gridDim.x + slab;
            const int s = row * 32 + threadIdx.x;

            __nv_bfloat162 v = x[g + base];
            if (x_gate != nullptr)
            {
                v = __hmul2(v, x_gate[g + base]);
            }
            x_tile2[s] = v;
        }

        __syncthreads();

        for (int i = 0; i < 4; i++)
        {
            for (int j = 0; j < 4; j++)
            {
                wmma::load_matrix_sync(x_frag[i][j], x_shared + i * N * 16 + j * 16, N);
            }
        }

#pragma unroll
        for (int j = 0; j < 4; j++)
        {
            wmma::fill_fragment(acc_re[j], 0.0f);

            for (int k = 0; k < 4; k++)
            {
                wmma::mma_sync(acc_re[j], dft_re_frag[k], x_frag[k][j], acc_re[j]);
            }
        }

#pragma unroll
        for (int j = 0; j < 4; j++)
        {
            wmma::fill_fragment(acc_im[j], 0.0f);

            for (int k = 0; k < 4; k++)
            {
                wmma::mma_sync(acc_im[j], dft_im_frag[k], x_frag[k][j], acc_im[j]);
            }
        }

#pragma unroll
        for (int j = 0; j < 4; j++)
        {
            // Complex multiply by the twiddles, two accumulator elements at a time.
            float2 *acc_re2 = reinterpret_cast<float2 *>(acc_re[j].x);
            float2 *acc_im2 = reinterpret_cast<float2 *>(acc_im[j].x);
            __nv_bfloat162 *tw_re_pairs = reinterpret_cast<__nv_bfloat162 *>(tw_re_frag[j].x);
            __nv_bfloat162 *tw_im_pairs = reinterpret_cast<__nv_bfloat162 *>(tw_im_frag[j].x);

            for (int k = 0; k < acc_re[j].num_elements / 2; k++)
            {
                re = acc_re2[k];
                im = acc_im2[k];
                wr = __bfloat1622float2(tw_re_pairs[k]);
                wi = __bfloat1622float2(tw_im_pairs[k]);

                acc_re2[k] = re * wr - im * wi;
                acc_im2[k] = re * wi + im * wr;
            }

            wmma::store_matrix_sync(res_re + threadIdx.y * N * 16 + j * 16, acc_re[j], N, wmma::mem_row_major);
            wmma::store_matrix_sync(res_im + threadIdx.y * N * 16 + j * 16, acc_im[j], N, wmma::mem_row_major);
        }

        __syncthreads();

        // Round the staged fp32 results to bf16 pairs and write them out.
#pragma unroll
        for (int pass = 0; pass < passes; ++pass)
        {
            const int row = threadIdx.y + pass * rows_per_pass;
            const int dst = base + row * 32 * gridDim.x + slab;
            const int s = row * 32 + threadIdx.x;
            out_real[dst] = __float22bfloat162_rn(res_re2[s]);
            out_imag[dst] = __float22bfloat162_rn(res_im2[s]);
        }

        __syncthreads();
    }
}
|
| 154 |
+
|
| 155 |
+
__global__ void butterfly_cuda_kernel_32(
|
| 156 |
+
const __nv_bfloat162 *__restrict__ x,
|
| 157 |
+
const __nv_bfloat162 *__restrict__ x_gate,
|
| 158 |
+
const __nv_bfloat16 *__restrict__ d_f_real,
|
| 159 |
+
const __nv_bfloat16 *__restrict__ d_f_imag,
|
| 160 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 161 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 162 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 163 |
+
__nv_bfloat162 *__restrict__ out_imag,
|
| 164 |
+
uint B,
|
| 165 |
+
uint H,
|
| 166 |
+
int N)
|
| 167 |
+
{
|
| 168 |
+
const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 169 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 170 |
+
int idx;
|
| 171 |
+
|
| 172 |
+
int shared_offset;
|
| 173 |
+
const int B_Y = blockDim.y;
|
| 174 |
+
const int n = N / B_Y;
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
__shared__ __nv_bfloat16 x_shared[32 * 64];
|
| 178 |
+
__shared__ __nv_bfloat16 d_f_real_shared[32 * 32];
|
| 179 |
+
__shared__ __nv_bfloat16 d_f_imag_shared[32 * 32];
|
| 180 |
+
__shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
|
| 181 |
+
__shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
|
| 182 |
+
__shared__ float out_real_shared[32 * 64];
|
| 183 |
+
__shared__ float out_imag_shared[32 * 64];
|
| 184 |
+
|
| 185 |
+
// #pragma unroll
|
| 186 |
+
for (int i = 0; i < n; i++)
|
| 187 |
+
{
|
| 188 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 189 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 190 |
+
if(x_gate != nullptr){
|
| 191 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 192 |
+
}else{
|
| 193 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 194 |
+
}
|
| 195 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 196 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 197 |
+
|
| 198 |
+
// #pragma unroll
|
| 199 |
+
d_f_real_shared[shared_offset] = d_f_real[shared_offset];
|
| 200 |
+
d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
__syncthreads();
|
| 204 |
+
|
| 205 |
+
if (threadIdx.y < N / 16)
|
| 206 |
+
{
|
| 207 |
+
float2 tmp_real, tmp_imag;
|
| 208 |
+
|
| 209 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
|
| 210 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
|
| 211 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
|
| 212 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
|
| 213 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[2][2];
|
| 214 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
|
| 215 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[2][2];
|
| 216 |
+
|
| 217 |
+
int t = threadIdx.y * 32;
|
| 218 |
+
|
| 219 |
+
for (int i = 0; i < 2; i++)
|
| 220 |
+
{
|
| 221 |
+
for (int j = 0; j < 2; j++)
|
| 222 |
+
{
|
| 223 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
|
| 224 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
|
| 225 |
+
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 226 |
+
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 227 |
+
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
#pragma unroll
|
| 232 |
+
for (int i = 0; i < 2; i++)
|
| 233 |
+
{
|
| 234 |
+
for (int j = 0; j < 2; j++)
|
| 235 |
+
{
|
| 236 |
+
wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
|
| 237 |
+
|
| 238 |
+
for (int k = 0; k < 2; k++)
|
| 239 |
+
{
|
| 240 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
#pragma unroll
|
| 246 |
+
for (int i = 0; i < 2; i++)
|
| 247 |
+
{
|
| 248 |
+
for (int j = 0; j < 2; j++)
|
| 249 |
+
{
|
| 250 |
+
wmma::fill_fragment(acc_frag_imag[i][j], 0.0f);
|
| 251 |
+
|
| 252 |
+
for (int k = 0; k < 2; k++)
|
| 253 |
+
{
|
| 254 |
+
wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
#pragma unroll
|
| 260 |
+
for (int i = 0; i < 2; i++)
|
| 261 |
+
{
|
| 262 |
+
for (int j = 0; j < 2; j++)
|
| 263 |
+
{
|
| 264 |
+
for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
|
| 265 |
+
{
|
| 266 |
+
tmp_real = reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k];
|
| 267 |
+
tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k];
|
| 268 |
+
reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]);
|
| 269 |
+
reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]);
|
| 270 |
+
}
|
| 271 |
+
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 272 |
+
wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
__syncthreads();
|
| 278 |
+
|
| 279 |
+
#pragma unroll
|
| 280 |
+
for (int i = 0; i < n; i++)
|
| 281 |
+
{
|
| 282 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 283 |
+
out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 284 |
+
out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
// Forward butterfly step for N == 128 on bf16 data, using 16x16x16 WMMA tiles.
//
// For each of 16 consecutive 128-row tiles handled by one block, this computes two real
// matrix products (d_f_real and d_f_imag against the staged x tile) and then multiplies
// the complex result element-wise by the twiddle factors:
//     out_real = P_re * tw_re - P_im * tw_im
//     out_imag = P_re * tw_im + P_im * tw_re
// d_f is loaded as a column-major matrix_a operand, i.e. the product uses the transpose
// of the row-major 128x128 buffer.
//
// Parameters (all data pointers are indexed in packed __nv_bfloat162 units):
//   x                      input; row stride is 64 * gridDim.x pairs.
//   x_gate                 optional; when non-null, x is multiplied element-wise by it
//                          while being staged.
//   d_f_real, d_f_imag     128 rows of 64 pairs each.
//   twiddle_factors_*      same row stride as x; indexed by row and this block's columns.
//   out_real, out_imag     outputs, same indexing as x.
//   B                      not used in the kernel body.
//   H                      enters only the blockIdx.y stride of `offset`.
//   N                      row count; rows are processed blockDim.y at a time.
//
// Launch configuration used by butterfly_bf16_cuda: blockDim = (32, 8), 65536 bytes of
// dynamic shared memory, gridDim.z = H / 16.
// NOTE(review): the tile math below addresses 8 row-blocks by threadIdx.y, so it appears
// to require blockDim.y == 8 and N == 128 -- confirm before launching differently.
__global__ void butterfly_cuda_kernel_128(
    const __nv_bfloat162 *__restrict__ x,
    const __nv_bfloat162 *__restrict__ x_gate,
    const __nv_bfloat162 *__restrict__ d_f_real,
    const __nv_bfloat162 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_imag,
    uint B,
    uint H,
    int N)
{
    // Index of this thread's first column in the first tile handled by this block.
    const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
    // Column offset into the twiddle tables (no batch/head component).
    const int tw_offset = blockIdx.x * 64 + threadIdx.x;
    int idx;

    int shared_offset;
    const int B_Y = blockDim.y;
    const int n = N / B_Y;

    // One dynamic shared buffer, reused for every phase: two 128x128 bf16 planes while
    // staging operands, and a single 128x128 float plane while staging results.
    extern __shared__ __nv_bfloat16 shared_real[];
    __nv_bfloat16 *shared_imag = &shared_real[128 * 128];

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[8];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[8][8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[8];

    // Phase 1: stage d_f (real and imaginary planes) into shared memory.
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < 2; j++)
        {
            shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
            reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = d_f_real[shared_offset];
            reinterpret_cast<__nv_bfloat162 *>(shared_imag)[shared_offset] = d_f_imag[shared_offset];
        }
    }

    __syncthreads();

    // Each warp keeps the d_f tiles for its own block of 16 output rows in registers.
    for (int i = 0; i < 8; i++)
    {
        wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
        wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
    }

    __syncthreads();

    // Phase 2: stage the twiddle factors for this block's columns, overwriting d_f.
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < 2; j++)
        {
            idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
            shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
            reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
            reinterpret_cast<__nv_bfloat162 *>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
        }
    }

    __syncthreads();

    // Twiddle tiles matching this warp's 16 output rows, kept in registers for all tiles.
    for (int i = 0; i < 8; i++)
    {
        wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
        wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
    }

    __syncthreads();

    // Phase 3: process 16 consecutive tiles; tile stride is 128 * 64 * gridDim.x pairs.
    for (int t = 0; t < 16; t++)
    {
        // Stage the (optionally gated) x tile into the bf16 plane.
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
                if (x_gate != nullptr)
                {
                    reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
                }
                else
                {
                    reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = x[offset + idx];
                }
            }
        }

        __syncthreads();

        // Every warp loads the whole 128x128 x tile as an 8x8 grid of B fragments.
        for (int i = 0; i < 8; i++)
        {
            for (int j = 0; j < 8; j++)
            {
                wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
            }
        }

        // The x tile must be fully consumed before the buffer is reused for results.
        __syncthreads();

#pragma unroll
        for (int j = 0; j < 8; j++)
        {
            wmma::fill_fragment(acc_frag_real[j], 0.0f);

            for (int k = 0; k < 8; k++)
            {
                wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
            }
        }

#pragma unroll
        for (int j = 0; j < 8; j++)
        {
            wmma::fill_fragment(acc_frag_imag[j], 0.0f);

            for (int k = 0; k < 8; k++)
            {
                wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
            }
        }

        // Element-wise complex multiply by the twiddles, done directly on fragment storage.
        // NOTE(review): this pairs accumulator element k with twiddle-fragment element k,
        // which relies on both fragment kinds sharing a per-thread element layout -- confirm
        // for the target architecture.
        float2 tmp_real, tmp_imag;
#pragma unroll
        for (int j = 0; j < 8; j++)
        {
            for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
            {
                tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
                tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];

                reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
                reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
            }
        }

        // Stage the real part as floats (the buffer is now viewed as 128x128 float).
        for (int j = 0; j < 8; j++)
        {
            wmma::store_matrix_sync(reinterpret_cast<float *>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
        }

        __syncthreads();

#pragma unroll
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
                out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(shared_real)[shared_offset]);
            }
        }

        __syncthreads();

        // Same staging buffer, now for the imaginary part.
        for (int j = 0; j < 8; j++)
        {
            wmma::store_matrix_sync(reinterpret_cast<float *>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
        }

        __syncthreads();

#pragma unroll
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
                out_imag[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(shared_real)[shared_offset]);
            }
        }

        // Fix: without this barrier a warp could start the next iteration and overwrite
        // the staging buffer with a new x tile while another warp is still reading the
        // imaginary results above. The loop is uniform across the block, so this is safe.
        __syncthreads();
    }
}
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
// Forward butterfly step for N == 16 on bf16 data, one 16x16x16 WMMA product per warp.
//
// Stages a 16-row by 64-column slab of x (optionally multiplied by x_gate), the matching
// twiddle factors and the 16x16 d_f matrices into shared memory. Each warp then computes
// d_f_real * x and d_f_imag * x for its own 16-column slice, multiplies the complex result
// element-wise by the twiddles, and the block writes the bf16 results to out_real/out_imag.
// d_f is loaded as a column-major matrix_a operand (transpose of the row-major buffer).
//
// Packed pointers (x, x_gate, twiddles, outputs) are indexed in __nv_bfloat162 units with a
// row stride of 32 * gridDim.x. B is not used in the body; H only enters the blockIdx.y
// stride. N rows are loaded blockDim.y at a time.
// NOTE(review): the fixed 16x64 shared tiles and the `threadIdx.y < 4` guard suggest the
// intended launch is blockDim = (32, 4) with N == 16 -- confirm against the launcher.
__global__ void butterfly_cuda_kernel_16(
    const __nv_bfloat162 *__restrict__ x,
    const __nv_bfloat162 *__restrict__ x_gate,
    const __nv_bfloat16 *__restrict__ d_f_real,
    const __nv_bfloat16 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_imag,
    uint B,
    uint H,
    int N)
{
    const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_offset = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int num_passes = N / rows_per_pass;

    __shared__ __nv_bfloat16 x_tile[16 * 64];
    __shared__ __nv_bfloat16 dft_re[16 * 16];
    __shared__ __nv_bfloat16 dft_im[16 * 16];
    __shared__ __nv_bfloat16 tw_re[16 * 64];
    __shared__ __nv_bfloat16 tw_im[16 * 64];
    __shared__ float stage_re[16 * 64];
    __shared__ float stage_im[16 * 64];

    // Packed (two-element) views of the bf16 tiles used while staging.
    __nv_bfloat162 *x_tile2 = reinterpret_cast<__nv_bfloat162 *>(x_tile);
    __nv_bfloat162 *tw_re2 = reinterpret_cast<__nv_bfloat162 *>(tw_re);
    __nv_bfloat162 *tw_im2 = reinterpret_cast<__nv_bfloat162 *>(tw_im);

    for (int pass = 0; pass < num_passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int row_base = row * 32 * gridDim.x;
        const int pair_slot = row * 32 + threadIdx.x;

        __nv_bfloat162 value = x[row_base + offset];
        if (x_gate != nullptr)
        {
            value = __hmul2(value, x_gate[row_base + offset]);
        }
        x_tile2[pair_slot] = value;

        tw_re2[pair_slot] = twiddle_factors_real[tw_offset + row_base];
        tw_im2[pair_slot] = twiddle_factors_imag[tw_offset + row_base];

        // d_f is only 16 elements wide, so half of each warp copies it.
        if (threadIdx.x < 16)
        {
            const int dft_slot = row * 16 + threadIdx.x;
            dft_re[dft_slot] = d_f_real[dft_slot];
            dft_im[dft_slot] = d_f_imag[dft_slot];
        }
    }

    __syncthreads();

    if (threadIdx.y < 4)
    {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_re_frag;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_im_frag;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_re_frag;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_im_frag;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> x_frag;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> prod_re;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> prod_im;

        // Column slice of the 64-wide slab owned by this warp.
        const int col0 = threadIdx.y * 16;

        wmma::load_matrix_sync(dft_re_frag, dft_re, N);
        wmma::load_matrix_sync(dft_im_frag, dft_im, N);
        wmma::load_matrix_sync(x_frag, x_tile + col0, 64);
        wmma::load_matrix_sync(tw_re_frag, tw_re + col0, 64);
        wmma::load_matrix_sync(tw_im_frag, tw_im + col0, 64);

        wmma::fill_fragment(prod_re, 0.0f);
        wmma::mma_sync(prod_re, dft_re_frag, x_frag, prod_re);

        wmma::fill_fragment(prod_im, 0.0f);
        wmma::mma_sync(prod_im, dft_im_frag, x_frag, prod_im);

        // Complex multiply by the twiddles, two fragment elements at a time.
        float2 *re_pairs = reinterpret_cast<float2 *>(prod_re.x);
        float2 *im_pairs = reinterpret_cast<float2 *>(prod_im.x);
        __nv_bfloat162 *w_re_pairs = reinterpret_cast<__nv_bfloat162 *>(tw_re_frag.x);
        __nv_bfloat162 *w_im_pairs = reinterpret_cast<__nv_bfloat162 *>(tw_im_frag.x);

#pragma unroll
        for (int k = 0; k < prod_re.num_elements / 2; k++)
        {
            const float2 re = re_pairs[k];
            const float2 im = im_pairs[k];
            re_pairs[k] = re * __bfloat1622float2(w_re_pairs[k]) - im * __bfloat1622float2(w_im_pairs[k]);
            im_pairs[k] = re * __bfloat1622float2(w_im_pairs[k]) + im * __bfloat1622float2(w_re_pairs[k]);
        }

        wmma::store_matrix_sync(stage_re + col0, prod_re, 64, wmma::mem_row_major);
        wmma::store_matrix_sync(stage_im + col0, prod_im, 64, wmma::mem_row_major);
    }

    __syncthreads();

    // Convert the float staging tiles to packed bf16 and write them out.
    const float2 *stage_re2 = reinterpret_cast<float2 *>(stage_re);
    const float2 *stage_im2 = reinterpret_cast<float2 *>(stage_im);

#pragma unroll
    for (int pass = 0; pass < num_passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int dst = offset + row * 32 * gridDim.x;
        out_real[dst] = __float22bfloat162_rn(stage_re2[row * 32 + threadIdx.x]);
        out_imag[dst] = __float22bfloat162_rn(stage_im2[row * 32 + threadIdx.x]);
    }
}
|
| 578 |
+
|
| 579 |
+
// Host launcher for the bf16 forward butterfly kernels.
//
// x is indexed as (B, H, N, M); the kernel is chosen by N (16, 32, 64 or 128). Returns
// {out_real, out_imag}, each allocated with x's shape and options. When x_gate is given,
// its data pointer is forwarded so the kernels multiply x by it element-wise on load.
//
// The kernels index rows with a stride of 64 * gridDim.x bf16 elements (128 * gridDim.x
// for N == 128), so gridDim.x is derived from M instead of a fixed table; this matches the
// previous table for M in {4096, 8192, 16384, 32768}. For N == 64 and N == 128 each block
// handles 16 heads, so H must be a multiple of 16.
//
// Raises (via TORCH_CHECK) on unsupported N, incompatible M or H, or a CUDA launch error,
// rather than returning uninitialized output.
// NOTE(review): inputs are assumed to be contiguous bf16 CUDA tensors; this is not
// verified here -- confirm at the call site.
std::vector<torch::Tensor> butterfly_bf16_cuda(
    torch::Tensor x,
    torch::Tensor d_f_real,
    torch::Tensor d_f_imag,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag,
    std::optional<at::Tensor> x_gate = std::nullopt
)
{
    TORCH_CHECK(x.dim() == 4, "butterfly_bf16_cuda: x must be 4-D (B, H, N, M), got ", x.dim(), " dims");

    uint B = x.size(0);
    uint H = x.size(1);
    uint N = x.size(2);
    uint M = x.size(3);

    TORCH_CHECK(N == 16 || N == 32 || N == 64 || N == 128,
                "butterfly_bf16_cuda: N must be 16, 32, 64 or 128, got ", N);

    torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
    torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());

    // Nothing to compute; a launch with a zero-sized grid would be an error.
    if (out_real.numel() == 0)
    {
        return {out_real, out_imag};
    }

    const bool is_128 = (N == 128);
    const bool sixteen_heads_per_block = (N == 64 || N == 128);
    const uint cols_per_block = is_128 ? 128 : 64;

    TORCH_CHECK(M % cols_per_block == 0,
                "butterfly_bf16_cuda: M (", M, ") must be a multiple of ", cols_per_block, " for N = ", N);
    TORCH_CHECK(!sixteen_heads_per_block || H % 16 == 0,
                "butterfly_bf16_cuda: H (", H, ") must be a multiple of 16 for N = ", N);

    dim3 grid;
    dim3 block;

    block.x = 32;
    block.y = is_128 ? 8 : 4;

    grid.x = M / cols_per_block;
    grid.y = B;
    grid.z = sixteen_heads_per_block ? H / 16 : H;

    auto *x_ptr = static_cast<__nv_bfloat162 *>(x.data_ptr());
    auto *gate_ptr = x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr;
    auto *tw_real_ptr = static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr());
    auto *tw_imag_ptr = static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr());
    auto *out_real_ptr = static_cast<__nv_bfloat162 *>(out_real.data_ptr());
    auto *out_imag_ptr = static_cast<__nv_bfloat162 *>(out_imag.data_ptr());

    // The 16/32 kernels take d_f as scalar bf16, the 64/128 kernels as packed pairs.
    void *d_f_real_raw = d_f_real.data_ptr();
    void *d_f_imag_raw = d_f_imag.data_ptr();

    cudaError_t attr_err = cudaSuccess;

    switch (N)
    {
    case 16:
        butterfly_cuda_kernel_16<<<grid, block>>>(
            x_ptr,
            gate_ptr,
            static_cast<__nv_bfloat16 *>(d_f_real_raw),
            static_cast<__nv_bfloat16 *>(d_f_imag_raw),
            tw_real_ptr,
            tw_imag_ptr,
            out_real_ptr,
            out_imag_ptr,
            B,
            H,
            N);
        break;
    case 32:
        butterfly_cuda_kernel_32<<<grid, block>>>(
            x_ptr,
            gate_ptr,
            static_cast<__nv_bfloat16 *>(d_f_real_raw),
            static_cast<__nv_bfloat16 *>(d_f_imag_raw),
            tw_real_ptr,
            tw_imag_ptr,
            out_real_ptr,
            out_imag_ptr,
            B,
            H,
            N);
        break;
    case 64:
        // Opt in to more than the default dynamic shared memory limit.
        attr_err = cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
        TORCH_CHECK(attr_err == cudaSuccess,
                    "butterfly_bf16_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(attr_err));

        butterfly_cuda_kernel_64<<<grid, block, 78000>>>(
            x_ptr,
            gate_ptr,
            static_cast<__nv_bfloat162 *>(d_f_real_raw),
            static_cast<__nv_bfloat162 *>(d_f_imag_raw),
            tw_real_ptr,
            tw_imag_ptr,
            out_real_ptr,
            out_imag_ptr,
            B,
            H,
            N);
        break;
    case 128:
        attr_err = cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
        TORCH_CHECK(attr_err == cudaSuccess,
                    "butterfly_bf16_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(attr_err));

        butterfly_cuda_kernel_128<<<grid, block, 65536>>>(
            x_ptr,
            gate_ptr,
            static_cast<__nv_bfloat162 *>(d_f_real_raw),
            static_cast<__nv_bfloat162 *>(d_f_imag_raw),
            tw_real_ptr,
            tw_imag_ptr,
            out_real_ptr,
            out_imag_ptr,
            B,
            H,
            N);
        break;
    default:
        // Unreachable: N was validated above.
        TORCH_CHECK(false, "butterfly_bf16_cuda: N = ", N, " is not implemented");
    }

    // Surface launch-configuration errors instead of returning uninitialized output.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "butterfly_bf16_cuda: kernel launch failed: ", cudaGetErrorString(launch_err));

    return {out_real, out_imag};
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda.cu
CHANGED
|
@@ -1,723 +1,723 @@
|
|
| 1 |
-
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
-
|
| 3 |
-
#include <torch/extension.h>
|
| 4 |
-
|
| 5 |
-
#include <vector>
|
| 6 |
-
#include <stdio.h>
|
| 7 |
-
#include <mma.h>
|
| 8 |
-
#include <cuda_fp16.h>
|
| 9 |
-
#include <cuda_bf16.h>
|
| 10 |
-
#include "shared.h"
|
| 11 |
-
|
| 12 |
-
using namespace nvcuda;
|
| 13 |
-
|
| 14 |
-
// Inverse butterfly step for N == 64 on fp16 data, using 16x16x16 WMMA tiles.
//
// For each of 16 consecutive tiles handled by one block: the complex input (x_real, x_imag)
// is multiplied element-wise by the twiddle factors, then the real part of the product with
// d_f is formed as  d_f_real * b_real - d_f_imag * b_imag  and written to out_real. When
// out_gate is non-null the result is multiplied element-wise by out_gate before the store.
// d_f is loaded as a column-major matrix_a operand (transpose of the row-major buffer).
//
// Packed pointers are indexed in __half2 units with a row stride of 32 * gridDim.x; the tile
// stride is 64 * 32 * gridDim.x. B is not used in the body; H only enters the blockIdx.y
// stride. Dynamic shared memory holds seven N*N half planes.
// NOTE(review): each warp owns the 16-column slice selected by threadIdx.y, which suggests
// the intended launch is blockDim = (32, 4) -- confirm against the launcher.
__global__ void butterfly_ifft_cuda_kernel_64(
    const __half2 *__restrict__ x_real,
    const __half2 *__restrict__ x_imag,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    __half2 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_offset = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int num_passes = N / rows_per_pass;
    const int plane = N * N;

    // Carve the dynamic shared buffer into seven consecutive N*N planes.
    extern __shared__ half x_real_shared[];
    half *x_imag_shared = x_real_shared + plane;
    half *d_f_real = x_imag_shared + plane;
    half *d_f_imag = d_f_real + plane;
    half *twiddles_real_shared = d_f_imag + plane;
    half *twiddles_imag_shared = twiddles_real_shared + plane;
    half *out_real_shared = twiddles_imag_shared + plane;

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4][4];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4][4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[4];
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];

    // Stage the twiddles for this block's columns and split d_f into real/imag planes.
    for (int pass = 0; pass < num_passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int row_base = row * 32 * gridDim.x;
        const int pair_slot = row * 32 + threadIdx.x;

        reinterpret_cast<__half2 *>(twiddles_real_shared)[pair_slot] = twiddle_factors_real[tw_offset + row_base];
        reinterpret_cast<__half2 *>(twiddles_imag_shared)[pair_slot] = twiddle_factors_imag[tw_offset + row_base];

        // Each thread copies two d_f elements per row: columns tx and tx + blockDim.x.
        const int lo = row * 64 + threadIdx.x;
        const int hi = lo + blockDim.x;
        d_f_real[lo] = d_f[lo].real();
        d_f_imag[lo] = d_f[lo].imag();
        d_f_real[hi] = d_f[hi].real();
        d_f_imag[hi] = d_f[hi].imag();
    }

    __syncthreads();

    // Column slice (in half elements) owned by this warp.
    const int col0 = threadIdx.y * 16;

    // d_f tiles and this warp's twiddle tiles stay in registers for all 16 tiles.
    for (int i = 0; i < 4; i++)
    {
#pragma unroll
        for (int j = 0; j < 4; j++)
        {
            wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
            wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
        }
        wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + col0, N);
        wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + col0, N);
    }

    for (int t = 0; t < 16; t++)
    {
        // Stage tile t of the complex input.
        for (int pass = 0; pass < num_passes; ++pass)
        {
            const int row = threadIdx.y + pass * rows_per_pass;
            const int src = row * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
            const int pair_slot = row * 32 + threadIdx.x;
            reinterpret_cast<__half2 *>(x_real_shared)[pair_slot] = x_real[src + offset];
            reinterpret_cast<__half2 *>(x_imag_shared)[pair_slot] = x_imag[src + offset];
        }

        __syncthreads();

        for (int i = 0; i < 4; i++)
        {
            wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + col0, N);
            wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + col0, N);
        }

        // b <- twiddle .* b (complex, element-wise on fragment storage).
        for (int j = 0; j < 4; j++)
        {
            for (int k = 0; k < tw_frag_real[j].num_elements; k++)
            {
                const half w_re = tw_frag_real[j].x[k];
                const half w_im = tw_frag_imag[j].x[k];
                const half v_re = b_frag_real[j].x[k];
                const half v_im = b_frag_imag[j].x[k];
                const half prod_re = __hsub(__hmul(w_re, v_re), __hmul(w_im, v_im));
                const half prod_im = __hadd(__hmul(w_re, v_im), __hmul(w_im, v_re));
                b_frag_real[j].x[k] = prod_re;
                b_frag_imag[j].x[k] = prod_im;
            }
        }

        for (int i = 0; i < 4; i++)
        {
            wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));

            // Accumulate the imag*imag term first, then flip its sign ...
#pragma unroll
            for (int k = 0; k < 4; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
            }

            for (int k = 0; k < acc_frag_real[i].num_elements; k++)
            {
                acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
            }

            // ... so that adding the real*real term yields real*real - imag*imag.
#pragma unroll
            for (int k = 0; k < 4; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
            }
        }

#pragma unroll
        for (int i = 0; i < 4; i++)
        {
            wmma::store_matrix_sync(out_real_shared + i * N * 16 + col0, acc_frag_real[i], N, wmma::mem_row_major);
        }

        __syncthreads();

        const __half2 *staged = reinterpret_cast<__half2 *>(out_real_shared);

#pragma unroll
        for (int pass = 0; pass < num_passes; ++pass)
        {
            const int row = threadIdx.y + pass * rows_per_pass;
            const int dst = offset + row * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
            const __half2 value = staged[row * 32 + threadIdx.x];
            if (out_gate != nullptr)
            {
                out_real[dst] = __hmul2(value, out_gate[dst]);
            }
            else
            {
                out_real[dst] = value;
            }
        }

        // Keep the staging planes intact until every thread has finished with this tile.
        __syncthreads();
    }
}
|
| 162 |
-
|
| 163 |
-
// Butterfly iFFT kernel for N = 32.
//
// Per block: stages a 32-row tile of the complex input, the matching twiddle
// factors and the 32x32 complex matrix d_f in shared memory, multiplies input
// by twiddles element-wise, then uses 16x16x16 half-precision WMMA tiles to
// form real(d_f * (twiddle .* x)) = d_f_real * re - d_f_imag * im. Only the
// real part is written; when out_gate is non-null the result is multiplied
// by out_gate element-wise before the store.
//
// NOTE(review): the indexing implies x_real / x_imag / out_real / out_gate are
// (gridDim.y, H, 32, 32 * gridDim.x) in __half2 units with blockIdx.z picking
// the head, and that blockDim.x == 32 — confirm against the launcher.
// B is accepted for signature parity but is not read.
__global__ void butterfly_ifft_cuda_kernel_32(
    const __half2 *__restrict__ x_real,
    const __half2 *__restrict__ x_imag,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    __half2 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    // Position of this thread's __half2 column inside the (batch, head) slab.
    const int base = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    // Twiddles are shared across batch and head, so only the column matters.
    const int tw_base = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    __shared__ half x_real_shared[32 * 64];
    __shared__ half x_imag_shared[32 * 64];
    __shared__ half d_f_real[32 * 32];
    __shared__ half d_f_imag[32 * 32];
    __shared__ half twiddles_real_shared[32 * 64];
    __shared__ half twiddles_imag_shared[32 * 64];
    __shared__ half out_real_shared[32 * 64];

    // Stage inputs: each pass copies blockDim.y rows of 32 __half2 values.
    for (int pass = 0; pass < passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int global_row = row * 32 * gridDim.x;
        const int tile_pos = row * 32 + threadIdx.x;

        reinterpret_cast<__half2 *>(x_real_shared)[tile_pos] = x_real[global_row + base];
        reinterpret_cast<__half2 *>(x_imag_shared)[tile_pos] = x_imag[global_row + base];
        reinterpret_cast<__half2 *>(twiddles_real_shared)[tile_pos] = twiddle_factors_real[tw_base + global_row];
        reinterpret_cast<__half2 *>(twiddles_imag_shared)[tile_pos] = twiddle_factors_imag[tw_base + global_row];

        // d_f is 32x32 complex; split it into separate real / imag planes.
        d_f_real[tile_pos] = d_f[tile_pos].real();
        d_f_imag[tile_pos] = d_f[tile_pos].imag();
    }

    __syncthreads();

    // Only the first N / 16 warps (indexed by threadIdx.y) do the matrix work;
    // each owns a 32-half-wide column strip of the staged tile.
    if (threadIdx.y < N / 16)
    {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[2][2];
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];

        const int strip = threadIdx.y * 32;

        // Load every 16x16 tile and apply the twiddle factors in registers:
        // (tw_re + i tw_im) * (x_re + i x_im).
        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                const int tile = r * 2 * N * 16 + c * 16 + strip;

                wmma::load_matrix_sync(a_frag_real[r][c], d_f_real + c * N * 16 + r * 16, N);
                wmma::load_matrix_sync(a_frag_imag[r][c], d_f_imag + c * N * 16 + r * 16, N);
                wmma::load_matrix_sync(b_frag_real[r][c], x_real_shared + tile, 2 * N);
                wmma::load_matrix_sync(b_frag_imag[r][c], x_imag_shared + tile, 2 * N);
                wmma::load_matrix_sync(tw_frag_real[r][c], twiddles_real_shared + tile, 2 * N);
                wmma::load_matrix_sync(tw_frag_imag[r][c], twiddles_imag_shared + tile, 2 * N);

                for (int e = 0; e < tw_frag_real[r][c].num_elements; ++e)
                {
                    const half tw_re = tw_frag_real[r][c].x[e];
                    const half tw_im = tw_frag_imag[r][c].x[e];
                    const half in_re = b_frag_real[r][c].x[e];
                    const half in_im = b_frag_imag[r][c].x[e];

                    const half prod_re = __hsub(__hmul(tw_re, in_re), __hmul(tw_im, in_im));
                    const half prod_im = __hadd(__hmul(tw_re, in_im), __hmul(tw_im, in_re));

                    b_frag_real[r][c].x[e] = prod_re;
                    b_frag_imag[r][c].x[e] = prod_im;
                }
            }
        }

        // Real part of the complex product, one output tile at a time:
        // accumulate imag*imag, negate, then add real*real on top.
        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                wmma::fill_fragment(acc_frag_real[r][c], __float2half(0.0f));

                for (int k = 0; k < 2; ++k)
                {
                    wmma::mma_sync(acc_frag_real[r][c], a_frag_imag[r][k], b_frag_imag[k][c], acc_frag_real[r][c]);
                }

                for (int e = 0; e < acc_frag_real[r][c].num_elements; ++e)
                {
                    acc_frag_real[r][c].x[e] = __hneg(acc_frag_real[r][c].x[e]);
                }

                for (int k = 0; k < 2; ++k)
                {
                    wmma::mma_sync(acc_frag_real[r][c], a_frag_real[r][k], b_frag_real[k][c], acc_frag_real[r][c]);
                }

                wmma::store_matrix_sync(out_real_shared + r * 2 * N * 16 + c * 16 + strip, acc_frag_real[r][c], 2 * N, wmma::mem_row_major);
            }
        }
    }

    __syncthreads();

    // Write the staged result back with the same row/column mapping used for
    // the loads, applying the optional gate on the way out.
#pragma unroll
    for (int pass = 0; pass < passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int dst = base + row * 32 * gridDim.x;
        const __half2 value = reinterpret_cast<__half2 *>(out_real_shared)[row * 32 + threadIdx.x];

        if (out_gate != nullptr)
        {
            out_real[dst] = __hmul2(value, out_gate[dst]);
        }
        else
        {
            out_real[dst] = value;
        }
    }
}
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
// Butterfly iFFT kernel for N = 128.
//
// Uses four 128x128 half planes of dynamic shared memory: two scratch planes
// (real_shared / imag_shared) that first hold the twiddles and are then reused
// for each input tile and its result, and two planes holding the real and
// imaginary parts of d_f. Each block iterates over 16 consecutive slices
// (stride 128 * 64 * gridDim.x __half2), computing for each
// real(d_f * (twiddle .* x)) with 16x16x16 half WMMA tiles and writing the
// real part, optionally multiplied by out_gate.
//
// NOTE(review): row counts are hard-coded (8 rows per pass, 16 passes), so the
// kernel assumes blockDim == (32, 8) and N == 128; the 16 slices per block are
// presumably heads (blockIdx.z advances by 16 slices) — confirm against the
// launcher. B is accepted but not read.
__global__ void butterfly_ifft_cuda_kernel_128(
    const __half2 *__restrict__ x_real,
    const __half2 *__restrict__ x_imag,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    __half2 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    const int base = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
    const int tw_base = blockIdx.x * 64 + threadIdx.x;

    const int rows_per_pass = 8;
    const int passes = 16;

    // Carve the dynamic shared buffer into four 128x128 half planes.
    extern __shared__ half real_shared[];
    half *imag_shared = real_shared + 128 * 128;
    half *real_shared_2 = imag_shared + 128 * 128;
    half *imag_shared_2 = real_shared_2 + 128 * 128;

    // a_frag is reused for the imaginary and then the real plane of d_f.
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag[8][8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];

    // Phase 1: split d_f (128x128 complex) into its real / imag planes.
    for (int pass = 0; pass < passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        for (int quarter = 0; quarter < 4; ++quarter)
        {
            const int pos = row * 128 + threadIdx.x + quarter * blockDim.x;
            real_shared_2[pos] = d_f[pos].real();
            imag_shared_2[pos] = d_f[pos].imag();
        }
    }

    __syncthreads();

    // Phase 2: stage the twiddles in the scratch planes (64 __half2 per row).
    for (int pass = 0; pass < passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        for (int part = 0; part < 2; ++part)
        {
            const int src = row * 32 * 2 * gridDim.x + part * blockDim.x;
            const int pos = row * 64 + threadIdx.x + part * blockDim.x;
            reinterpret_cast<__half2 *>(real_shared)[pos] = twiddle_factors_real[tw_base + src];
            reinterpret_cast<__half2 *>(imag_shared)[pos] = twiddle_factors_imag[tw_base + src];
        }
    }

    __syncthreads();

    // Keep this warp's 16-column strip of twiddles in registers for all slices.
    for (int tile = 0; tile < 8; ++tile)
    {
        wmma::load_matrix_sync(tw_frag_real[tile], real_shared + tile * 128 * 16 + threadIdx.y * 16, 128);
        wmma::load_matrix_sync(tw_frag_imag[tile], imag_shared + tile * 128 * 16 + threadIdx.y * 16, 128);
    }

    // The scratch planes are overwritten below, so every warp must be done
    // reading the twiddles first.
    __syncthreads();

    for (int slice = 0; slice < 16; ++slice)
    {
        const int slice_off = slice * 128 * 32 * 2 * gridDim.x;

        // Stage this slice's input tile.
        for (int pass = 0; pass < passes; ++pass)
        {
            const int row = threadIdx.y + pass * rows_per_pass;
            for (int part = 0; part < 2; ++part)
            {
                const int src = row * 32 * 2 * gridDim.x + part * blockDim.x + slice_off;
                const int pos = row * 64 + threadIdx.x + part * blockDim.x;
                reinterpret_cast<__half2 *>(real_shared)[pos] = x_real[base + src];
                reinterpret_cast<__half2 *>(imag_shared)[pos] = x_imag[base + src];
            }
        }

        __syncthreads();

        for (int tile = 0; tile < 8; ++tile)
        {
            wmma::load_matrix_sync(b_frag_real[tile], real_shared + tile * N * 16 + threadIdx.y * 16, N);
            wmma::load_matrix_sync(b_frag_imag[tile], imag_shared + tile * N * 16 + threadIdx.y * 16, N);
        }

        // Complex multiply by the twiddles, two halves at a time.
        for (int tile = 0; tile < 8; ++tile)
        {
            const __half2 *tw_re = reinterpret_cast<__half2 *>(tw_frag_real[tile].x);
            const __half2 *tw_im = reinterpret_cast<__half2 *>(tw_frag_imag[tile].x);
            __half2 *in_re = reinterpret_cast<__half2 *>(b_frag_real[tile].x);
            __half2 *in_im = reinterpret_cast<__half2 *>(b_frag_imag[tile].x);

            for (int e = 0; e < tw_frag_real[tile].num_elements / 2; ++e)
            {
                const __half2 prod_re = __hsub2(__hmul2(tw_re[e], in_re[e]), __hmul2(tw_im[e], in_im[e]));
                const __half2 prod_im = __hadd2(__hmul2(tw_re[e], in_im[e]), __hmul2(tw_im[e], in_re[e]));
                in_re[e] = prod_re;
                in_im[e] = prod_im;
            }
        }

        // Imaginary plane of d_f first.
        for (int r = 0; r < 8; ++r)
        {
            for (int c = 0; c < 8; ++c)
            {
                wmma::load_matrix_sync(a_frag[r][c], imag_shared_2 + c * 128 * 16 + r * 16, 128);
            }
        }

        __syncthreads();

        // acc = -(d_f_imag * im)
        for (int r = 0; r < 8; ++r)
        {
            wmma::fill_fragment(acc_frag_real[r], __float2half(0.0f));

#pragma unroll
            for (int k = 0; k < 8; ++k)
            {
                wmma::mma_sync(acc_frag_real[r], a_frag[r][k], b_frag_imag[k], acc_frag_real[r]);
            }

            for (int e = 0; e < acc_frag_real[r].num_elements; ++e)
            {
                acc_frag_real[r].x[e] = __hneg(acc_frag_real[r].x[e]);
            }
        }

        // Now the real plane of d_f, reusing the same fragment storage.
        for (int r = 0; r < 8; ++r)
        {
            for (int c = 0; c < 8; ++c)
            {
                wmma::load_matrix_sync(a_frag[r][c], real_shared_2 + c * 128 * 16 + r * 16, 128);
            }
        }

        __syncthreads();

        // acc += d_f_real * re
        for (int r = 0; r < 8; ++r)
        {
#pragma unroll
            for (int k = 0; k < 8; ++k)
            {
                wmma::mma_sync(acc_frag_real[r], a_frag[r][k], b_frag_real[k], acc_frag_real[r]);
            }
        }

        // Park the result in the real scratch plane for a coalesced write-out.
#pragma unroll
        for (int tile = 0; tile < 8; ++tile)
        {
            wmma::store_matrix_sync(real_shared + tile * N * 16 + threadIdx.y * 16, acc_frag_real[tile], N, wmma::mem_row_major);
        }

        __syncthreads();

#pragma unroll
        for (int pass = 0; pass < passes; ++pass)
        {
            const int row = threadIdx.y + pass * rows_per_pass;
            for (int part = 0; part < 2; ++part)
            {
                const int dst = row * 32 * 2 * gridDim.x + part * blockDim.x + slice_off;
                const int pos = row * 64 + threadIdx.x + part * blockDim.x;
                const __half2 value = reinterpret_cast<__half2 *>(real_shared)[pos];

                if (out_gate != nullptr)
                {
                    out_real[base + dst] = __hmul2(value, out_gate[base + dst]);
                }
                else
                {
                    out_real[base + dst] = value;
                }
            }
        }

        // The next slice overwrites the scratch planes.
        __syncthreads();
    }
}
|
| 475 |
-
|
| 476 |
-
// Butterfly iFFT kernel for N = 16.
//
// Stages a 16-row tile of the complex input and its twiddles (64 halves per
// row) plus the 16x16 complex matrix d_f in shared memory. Warps with
// threadIdx.y < 4 each take a 16-column strip, multiply input by twiddles
// element-wise, and compute real(d_f * (twiddle .* x)) with a single
// 16x16x16 half WMMA tile. The real part is written out, multiplied by
// out_gate when it is non-null.
//
// NOTE(review): indexing implies tensors are (gridDim.y, H, 16, 32 * gridDim.x)
// in __half2 units and blockDim.x == 32 — confirm against the launcher.
// B is accepted but not read.
__global__ void butterfly_ifft_cuda_kernel_16(
    const __half2 *__restrict__ x_real,
    const __half2 *__restrict__ x_imag,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    __half2 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    const int base = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_base = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    __shared__ half x_real_shared[16 * 64];
    __shared__ half x_imag_shared[16 * 64];
    __shared__ half d_f_real[16 * 16];
    __shared__ half d_f_imag[16 * 16];
    __shared__ half twiddles_real_shared[16 * 64];
    __shared__ half twiddles_imag_shared[16 * 64];
    __shared__ half out_real_shared[16 * 64];

    for (int pass = 0; pass < passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int global_row = row * 32 * gridDim.x;
        const int tile_pos = row * 32 + threadIdx.x;

        reinterpret_cast<__half2 *>(x_real_shared)[tile_pos] = x_real[global_row + base];
        reinterpret_cast<__half2 *>(x_imag_shared)[tile_pos] = x_imag[global_row + base];
        reinterpret_cast<__half2 *>(twiddles_real_shared)[tile_pos] = twiddle_factors_real[tw_base + global_row];
        reinterpret_cast<__half2 *>(twiddles_imag_shared)[tile_pos] = twiddle_factors_imag[tw_base + global_row];

        // d_f has only 16 columns, so half of each 32-wide row of threads idles.
        if (threadIdx.x < 16)
        {
            const int df_pos = row * 16 + threadIdx.x;
            d_f_real[df_pos] = d_f[df_pos].real();
            d_f_imag[df_pos] = d_f[df_pos].imag();
        }
    }

    __syncthreads();

    // check if it is better to have one warp do all the multiplication or split between warps
    if (threadIdx.y < 4)
    {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;

        // This warp's 16-column strip inside the 64-half-wide staged rows.
        const int strip = threadIdx.y * 16;

        wmma::load_matrix_sync(a_frag_real, d_f_real, N);
        wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
        wmma::load_matrix_sync(b_frag_real, x_real_shared + strip, 64);
        wmma::load_matrix_sync(b_frag_imag, x_imag_shared + strip, 64);
        wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + strip, 64);
        wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + strip, 64);

        // (tw_re + i tw_im) * (x_re + i x_im), element by element.
        for (int e = 0; e < tw_frag_real.num_elements; ++e)
        {
            const half tw_re = tw_frag_real.x[e];
            const half tw_im = tw_frag_imag.x[e];
            const half in_re = b_frag_real.x[e];
            const half in_im = b_frag_imag.x[e];

            const half prod_re = __hsub(__hmul(tw_re, in_re), __hmul(tw_im, in_im));
            const half prod_im = __hadd(__hmul(tw_re, in_im), __hmul(tw_im, in_re));

            b_frag_real.x[e] = prod_re;
            b_frag_imag.x[e] = prod_im;
        }

        // acc = -(d_f_imag * im), then acc += d_f_real * re.
        wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
        wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);

        for (int e = 0; e < acc_frag_real.num_elements; ++e)
        {
            acc_frag_real.x[e] = __hneg(acc_frag_real.x[e]);
        }

        wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);

        wmma::store_matrix_sync(out_real_shared + strip, acc_frag_real, 64, wmma::mem_row_major);
    }

    __syncthreads();

#pragma unroll
    for (int pass = 0; pass < passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int dst = base + row * 32 * gridDim.x;
        const __half2 value = reinterpret_cast<__half2 *>(out_real_shared)[row * 32 + threadIdx.x];

        if (out_gate != nullptr)
        {
            out_real[dst] = __hmul2(value, out_gate[dst]);
        }
        else
        {
            out_real[dst] = value;
        }
    }
}
|
| 582 |
-
|
| 583 |
-
// Host launcher for the butterfly iFFT kernels.
//
// Reads (B, H, N, M) from x_real, allocates an output of the same shape and
// options, and dispatches to the kernel specialised for N in {16, 32, 64, 128}.
//
// Parameters
//   x_real, x_imag              real / imaginary input planes; their data is
//                               reinterpreted as __half2 (NOTE(review): so they
//                               are presumably contiguous fp16 CUDA tensors —
//                               this function does not verify dtype/layout).
//   d_f                         data reinterpreted as complex_half_t.
//   twiddle_factors_real/_imag  twiddle planes, reinterpreted as __half2.
//   out_gate                    optional gate; forwarded to the kernel as a raw
//                               pointer, or nullptr when absent.
//
// Returns the freshly allocated output tensor.
//
// Raises (via TORCH_CHECK) when N is unsupported, when M is not a multiple of
// the per-block column width, when H is not a multiple of 16 for N = 64 / 128
// (those kernels are launched with H / 16 blocks along z), or when the CUDA
// runtime reports a configuration / launch error. Previously these cases
// silently returned uninitialised memory.
torch::Tensor butterfly_ifft_cuda(
    torch::Tensor x_real,
    torch::Tensor x_imag,
    torch::Tensor d_f,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag,
    std::optional<at::Tensor> out_gate = std::nullopt)
{
    uint B = x_real.size(0);
    uint H = x_real.size(1);
    uint N = x_real.size(2);
    uint M = x_real.size(3);

    TORCH_CHECK(N == 16 || N == 32 || N == 64 || N == 128,
                "butterfly_ifft_cuda: N = ", N, " is not implemented (supported: 16, 32, 64, 128)");

    // The N = 64 / 128 kernels are launched with H / 16 blocks along z, so a
    // remainder would leave trailing heads unwritten.
    const bool sixteen_heads_per_block = (N == 64 || N == 128);
    TORCH_CHECK(!sixteen_heads_per_block || H % 16 == 0,
                "butterfly_ifft_cuda: H = ", H, " must be a multiple of 16 when N = ", N);

    // Each block covers 32 __half2 columns (64 halves), or 64 __half2 columns
    // (128 halves) for N = 128. This reproduces the former lookup table
    // (4096 / 8192 / 16384 and the 32768 default) and extends it to any
    // multiple of the block width instead of assuming 32768.
    const uint halves_per_block = (N == 128) ? 128 : 64;
    TORCH_CHECK(M % halves_per_block == 0,
                "butterfly_ifft_cuda: M = ", M, " must be a multiple of ", halves_per_block, " when N = ", N);

    torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());

    // Nothing to compute; a zero-sized grid would be an invalid launch.
    if (B == 0 || H == 0 || M == 0)
    {
        return out;
    }

    dim3 grid;
    dim3 block;

    grid.x = M / halves_per_block;
    grid.y = B;
    grid.z = sixteen_heads_per_block ? H / 16 : H;

    block.x = 32;
    block.y = (N == 128) ? 8 : 4;

    __half2 *x_re = static_cast<__half2 *>(x_real.data_ptr());
    __half2 *x_im = static_cast<__half2 *>(x_imag.data_ptr());
    complex_half_t *df = static_cast<complex_half_t *>(d_f.data_ptr());
    __half2 *tw_re = static_cast<__half2 *>(twiddle_factors_real.data_ptr());
    __half2 *tw_im = static_cast<__half2 *>(twiddle_factors_imag.data_ptr());
    __half2 *dst = static_cast<__half2 *>(out.data_ptr());
    __half2 *gate = out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr;

    switch (N)
    {
    case 16:
        butterfly_ifft_cuda_kernel_16<<<grid, block>>>(x_re, x_im, df, tw_re, tw_im, dst, gate, B, H, N);
        break;

    case 32:
        butterfly_ifft_cuda_kernel_32<<<grid, block>>>(x_re, x_im, df, tw_re, tw_im, dst, gate, B, H, N);
        break;

    case 64:
    {
        // 64 KiB of dynamic shared memory must be opted into explicitly.
        const cudaError_t attr_err = cudaFuncSetAttribute(&butterfly_ifft_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
        TORCH_CHECK(attr_err == cudaSuccess,
                    "butterfly_ifft_cuda: cannot reserve shared memory for N = 64: ", cudaGetErrorString(attr_err));
        butterfly_ifft_cuda_kernel_64<<<grid, block, 8 * N * N * sizeof(half)>>>(x_re, x_im, df, tw_re, tw_im, dst, gate, B, H, N);
        break;
    }

    case 128:
    {
        // 128 KiB of dynamic shared memory must be opted into explicitly.
        const cudaError_t attr_err = cudaFuncSetAttribute(&butterfly_ifft_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
        TORCH_CHECK(attr_err == cudaSuccess,
                    "butterfly_ifft_cuda: cannot reserve shared memory for N = 128: ", cudaGetErrorString(attr_err));
        butterfly_ifft_cuda_kernel_128<<<grid, block, 65536 * 2>>>(x_re, x_im, df, tw_re, tw_im, dst, gate, B, H, N);
        break;
    }

    default:
        // Unreachable: N was validated above.
        break;
    }

    // Surface invalid launch configurations instead of returning garbage.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "butterfly_ifft_cuda: kernel launch failed: ", cudaGetErrorString(launch_err));

    return out;
}
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_fp16.h>
|
| 9 |
+
#include <cuda_bf16.h>
|
| 10 |
+
#include "shared.h"
|
| 11 |
+
|
| 12 |
+
using namespace nvcuda;
|
| 13 |
+
|
| 14 |
+
// Butterfly iFFT kernel for N = 64.
//
// Dynamic shared memory is carved into seven N*N half planes: input real/imag,
// d_f real/imag, twiddle real/imag and the output staging plane. The twiddles
// and d_f are loaded once into WMMA fragments; the block then iterates over 16
// consecutive slices (stride 64 * 32 * gridDim.x __half2), computing for each
// real(d_f * (twiddle .* x)) with 16x16x16 half WMMA tiles and writing the
// real part, optionally multiplied by out_gate.
//
// NOTE(review): each warp (threadIdx.y) owns a 16-column strip, so the kernel
// presumably expects blockDim == (32, 4) and N == 64; the 16 slices per block
// are presumably heads (blockIdx.z advances by 16 slices) — confirm against
// the launcher. B is accepted but not read.
__global__ void butterfly_ifft_cuda_kernel_64(
    const __half2 *__restrict__ x_real,
    const __half2 *__restrict__ x_imag,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    __half2 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    const int base = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_base = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    // Seven consecutive N*N half planes inside the dynamic shared buffer.
    extern __shared__ half x_real_shared[];
    half *x_imag_shared = x_real_shared + N * N;
    half *d_f_real = x_imag_shared + N * N;
    half *d_f_imag = d_f_real + N * N;
    half *twiddles_real_shared = d_f_imag + N * N;
    half *twiddles_imag_shared = twiddles_real_shared + N * N;
    half *out_real_shared = twiddles_imag_shared + N * N;

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4][4];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4][4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[4];
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];

    // Stage the slice-independent operands: twiddles and d_f.
    for (int pass = 0; pass < passes; ++pass)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int tw_src = row * 32 * gridDim.x;
        const int tw_pos = row * 32 + threadIdx.x;

        reinterpret_cast<__half2 *>(twiddles_real_shared)[tw_pos] = twiddle_factors_real[tw_base + tw_src];
        reinterpret_cast<__half2 *>(twiddles_imag_shared)[tw_pos] = twiddle_factors_imag[tw_base + tw_src];

        // d_f rows are 64 wide; each thread copies column x and x + blockDim.x.
        const int df_lo = row * 64 + threadIdx.x;
        const int df_hi = df_lo + blockDim.x;

        d_f_real[df_lo] = d_f[df_lo].real();
        d_f_imag[df_lo] = d_f[df_lo].imag();

        d_f_real[df_hi] = d_f[df_hi].real();
        d_f_imag[df_hi] = d_f[df_hi].imag();
    }

    __syncthreads();

    for (int r = 0; r < 4; ++r)
    {
#pragma unroll
        for (int c = 0; c < 4; ++c)
        {
            wmma::load_matrix_sync(a_frag_real[r][c], d_f_real + c * N * 16 + r * 16, N);
            wmma::load_matrix_sync(a_frag_imag[r][c], d_f_imag + c * N * 16 + r * 16, N);
        }
        wmma::load_matrix_sync(tw_frag_real[r], twiddles_real_shared + r * N * 16 + threadIdx.y * 16, N);
        wmma::load_matrix_sync(tw_frag_imag[r], twiddles_imag_shared + r * N * 16 + threadIdx.y * 16, N);
    }

    for (int slice = 0; slice < 16; ++slice)
    {
        const int slice_off = slice * 64 * 32 * gridDim.x;

        // Stage this slice's input tile.
        for (int pass = 0; pass < passes; ++pass)
        {
            const int row = threadIdx.y + pass * rows_per_pass;
            const int src = row * 32 * gridDim.x + slice_off;
            const int pos = row * 32 + threadIdx.x;

            reinterpret_cast<__half2 *>(x_real_shared)[pos] = x_real[src + base];
            reinterpret_cast<__half2 *>(x_imag_shared)[pos] = x_imag[src + base];
        }

        __syncthreads();

        for (int tile = 0; tile < 4; ++tile)
        {
            wmma::load_matrix_sync(b_frag_real[tile], x_real_shared + tile * N * 16 + threadIdx.y * 16, N);
            wmma::load_matrix_sync(b_frag_imag[tile], x_imag_shared + tile * N * 16 + threadIdx.y * 16, N);
        }

        // (tw_re + i tw_im) * (x_re + i x_im), element by element.
        for (int tile = 0; tile < 4; ++tile)
        {
            for (int e = 0; e < tw_frag_real[tile].num_elements; ++e)
            {
                const half tw_re = tw_frag_real[tile].x[e];
                const half tw_im = tw_frag_imag[tile].x[e];
                const half in_re = b_frag_real[tile].x[e];
                const half in_im = b_frag_imag[tile].x[e];

                const half prod_re = __hsub(__hmul(tw_re, in_re), __hmul(tw_im, in_im));
                const half prod_im = __hadd(__hmul(tw_re, in_im), __hmul(tw_im, in_re));

                b_frag_real[tile].x[e] = prod_re;
                b_frag_imag[tile].x[e] = prod_im;
            }
        }

        // acc = -(d_f_imag * im)
        for (int r = 0; r < 4; ++r)
        {
            wmma::fill_fragment(acc_frag_real[r], __float2half(0.0f));

#pragma unroll
            for (int k = 0; k < 4; ++k)
            {
                wmma::mma_sync(acc_frag_real[r], a_frag_imag[r][k], b_frag_imag[k], acc_frag_real[r]);
            }

            for (int e = 0; e < acc_frag_real[r].num_elements; ++e)
            {
                acc_frag_real[r].x[e] = __hneg(acc_frag_real[r].x[e]);
            }
        }

        // acc += d_f_real * re
        for (int r = 0; r < 4; ++r)
        {
#pragma unroll
            for (int k = 0; k < 4; ++k)
            {
                wmma::mma_sync(acc_frag_real[r], a_frag_real[r][k], b_frag_real[k], acc_frag_real[r]);
            }
        }

#pragma unroll
        for (int tile = 0; tile < 4; ++tile)
        {
            wmma::store_matrix_sync(out_real_shared + tile * N * 16 + threadIdx.y * 16, acc_frag_real[tile], N, wmma::mem_row_major);
        }

        __syncthreads();

#pragma unroll
        for (int pass = 0; pass < passes; ++pass)
        {
            const int row = threadIdx.y + pass * rows_per_pass;
            const int dst = base + row * 32 * gridDim.x + slice_off;
            const __half2 value = reinterpret_cast<__half2 *>(out_real_shared)[row * 32 + threadIdx.x];

            if (out_gate != nullptr)
            {
                out_real[dst] = __hmul2(value, out_gate[dst]);
            }
            else
            {
                out_real[dst] = value;
            }
        }

        // The next slice overwrites the input and output staging planes.
        __syncthreads();
    }
}
|
| 162 |
+
|
| 163 |
+
__global__ void butterfly_ifft_cuda_kernel_32(
|
| 164 |
+
const __half2 *__restrict__ x_real,
|
| 165 |
+
const __half2 *__restrict__ x_imag,
|
| 166 |
+
const complex_half_t *__restrict__ d_f,
|
| 167 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 168 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 169 |
+
__half2 *__restrict__ out_real,
|
| 170 |
+
__half2 *__restrict__ out_gate,
|
| 171 |
+
uint B,
|
| 172 |
+
uint H,
|
| 173 |
+
int N)
|
| 174 |
+
{
|
| 175 |
+
const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 176 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 177 |
+
int idx;
|
| 178 |
+
int shared_offset;
|
| 179 |
+
const int B_Y = blockDim.y;
|
| 180 |
+
const int n = N / B_Y;
|
| 181 |
+
|
| 182 |
+
__shared__ half x_real_shared[32 * 64];
|
| 183 |
+
__shared__ half x_imag_shared[32 * 64];
|
| 184 |
+
__shared__ half d_f_real[32 * 32];
|
| 185 |
+
__shared__ half d_f_imag[32 * 32];
|
| 186 |
+
__shared__ half twiddles_real_shared[32 * 64];
|
| 187 |
+
__shared__ half twiddles_imag_shared[32 * 64];
|
| 188 |
+
__shared__ half out_real_shared[32 * 64];
|
| 189 |
+
|
| 190 |
+
// #pragma unroll
|
| 191 |
+
for (int i = 0; i < n; i++)
|
| 192 |
+
{
|
| 193 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 194 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 195 |
+
reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
|
| 196 |
+
reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
|
| 197 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 198 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 199 |
+
|
| 200 |
+
// #pragma unroll
|
| 201 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 202 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
__syncthreads();
|
| 206 |
+
|
| 207 |
+
if (threadIdx.y < N / 16)
|
| 208 |
+
{
|
| 209 |
+
half tmp_real, tmp_imag;
|
| 210 |
+
|
| 211 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
|
| 212 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
|
| 213 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
|
| 214 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
|
| 215 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[2][2];
|
| 216 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[2][2];
|
| 217 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
|
| 218 |
+
|
| 219 |
+
int t = threadIdx.y * 32;
|
| 220 |
+
|
| 221 |
+
for (int i = 0; i < 2; i++)
|
| 222 |
+
{
|
| 223 |
+
for (int j = 0; j < 2; j++)
|
| 224 |
+
{
|
| 225 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
|
| 226 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
|
| 227 |
+
wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 228 |
+
wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 229 |
+
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 230 |
+
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 231 |
+
}
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
for (int i = 0; i < 2; i++)
|
| 235 |
+
{
|
| 236 |
+
for (int j = 0; j < 2; j++)
|
| 237 |
+
{
|
| 238 |
+
for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
|
| 239 |
+
{
|
| 240 |
+
tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
|
| 241 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
|
| 242 |
+
b_frag_real[i][j].x[k] = tmp_real;
|
| 243 |
+
b_frag_imag[i][j].x[k] = tmp_imag;
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
for (int i = 0; i < 2; i++)
|
| 249 |
+
{
|
| 250 |
+
for (int j = 0; j < 2; j++)
|
| 251 |
+
{
|
| 252 |
+
wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
|
| 253 |
+
|
| 254 |
+
// bd
|
| 255 |
+
for (int k = 0; k < 2; k++)
|
| 256 |
+
{
|
| 257 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
|
| 261 |
+
{
|
| 262 |
+
acc_frag_real[i][j].x[k] = __hneg(acc_frag_real[i][j].x[k]);
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
for (int i = 0; i < 2; i++)
|
| 268 |
+
{
|
| 269 |
+
for (int j = 0; j < 2; j++)
|
| 270 |
+
{
|
| 271 |
+
// ac - bd
|
| 272 |
+
for (int k = 0; k < 2; k++)
|
| 273 |
+
{
|
| 274 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
for (int i = 0; i < 2; i++)
|
| 280 |
+
{
|
| 281 |
+
for (int j = 0; j < 2; j++)
|
| 282 |
+
{
|
| 283 |
+
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
__syncthreads();
|
| 289 |
+
|
| 290 |
+
#pragma unroll
|
| 291 |
+
for (int i = 0; i < n; i++)
|
| 292 |
+
{
|
| 293 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 294 |
+
if(out_gate != nullptr){
|
| 295 |
+
out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
|
| 296 |
+
}
|
| 297 |
+
else{
|
| 298 |
+
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 299 |
+
}
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
/**
 * fp16 butterfly iFFT stage for a 128-wide butterfly; only the real part of
 * the product is produced.
 *
 * Work done by one thread block:
 *   1. d_f is split into a real plane and an imaginary plane in dynamic
 *      shared memory (real_shared_2 / imag_shared_2).
 *   2. The twiddle tile selected by blockIdx.x is staged through shared
 *      memory into per-warp WMMA fragments that stay live for the whole kernel.
 *   3. For 16 consecutive input planes (loop variable t) the tile of
 *      x_real / x_imag is loaded, multiplied element-wise by the twiddles
 *      (complex multiply), and the real output
 *          Re(d_f) * Re(x') - Im(d_f) * Im(x')
 *      is accumulated with 16x16x16 half MMAs, optionally multiplied by
 *      out_gate, and written to out_real.
 *
 * Hard-coded layout assumptions (not checked here): 8 rows are covered per
 * sweep over threadIdx.y and 16 sweeps cover the 128 rows, so the launch is
 * expected to use blockDim.y == 8; the half2 indexing assumes blockDim.x == 32
 * -- TODO confirm against the launcher. Dynamic shared memory must hold four
 * 128 x 128 half planes (131072 bytes).
 *
 * @param x_real, x_imag         input planes, read as __half2 pairs
 * @param d_f                    128 x 128 complex matrix (read via .real()/.imag())
 * @param twiddle_factors_real,
 *        twiddle_factors_imag   twiddle planes, read as __half2 pairs
 * @param out_real               destination, written as __half2 pairs
 * @param out_gate               optional element-wise gate; nullptr disables gating
 * @param B                      not referenced by this kernel
 * @param H                      scales the blockIdx.y stride of the global offset
 * @param N                      leading dimension for the x / output tile
 *                               loads and stores (other strides are the literal 128)
 */
__global__ void butterfly_ifft_cuda_kernel_128(
    const __half2 *__restrict__ x_real,
    const __half2 *__restrict__ x_imag,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    __half2 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    using FragA = wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major>;
    using FragB = wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major>;
    using FragAcc = wmma::fragment<wmma::accumulator, 16, 16, 16, half>;

    constexpr int kRowsPerSweep = 8; // rows covered by one sweep over threadIdx.y
    constexpr int kSweeps = 16;      // 16 sweeps * 8 rows = 128 rows
    constexpr int kPlane = 128 * 128;

    // Per-thread base position inside the global x / out / gate buffers (in half2 units).
    const int base = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
    // Per-thread base position inside the twiddle buffers (in half2 units).
    const int tw_base = blockIdx.x * 64 + threadIdx.x;
    // Each value of threadIdx.y owns one 16-column strip of the tile.
    const int warp_col = threadIdx.y * 16;

    // Dynamic shared memory, carved into four 128 x 128 half planes.
    extern __shared__ half real_shared[];
    half *imag_shared = real_shared + kPlane;
    half *real_shared_2 = imag_shared + kPlane;
    half *imag_shared_2 = real_shared_2 + kPlane;

    FragA a_frag[8][8];
    FragB tw_frag_real[8];
    FragB tw_frag_imag[8];
    FragB b_frag_real[8];
    FragB b_frag_imag[8];
    FragAcc acc_frag_real[8];

    // Split the complex d_f matrix into separate real / imaginary planes.
    for (int sweep = 0; sweep < kSweeps; ++sweep)
    {
        const int row = threadIdx.y + sweep * kRowsPerSweep;
        for (int q = 0; q < 4; ++q)
        {
            const int e = row * 128 + threadIdx.x + q * blockDim.x;
            real_shared_2[e] = d_f[e].real();
            imag_shared_2[e] = d_f[e].imag();
        }
    }

    __syncthreads();

    // Stage this block's twiddle tile in the first two planes.
    for (int sweep = 0; sweep < kSweeps; ++sweep)
    {
        const int row = threadIdx.y + sweep * kRowsPerSweep;
        for (int q = 0; q < 2; ++q)
        {
            const int g = row * 32 * 2 * gridDim.x + q * blockDim.x;
            const int s = row * 64 + threadIdx.x + q * blockDim.x;
            reinterpret_cast<__half2 *>(real_shared)[s] = twiddle_factors_real[tw_base + g];
            reinterpret_cast<__half2 *>(imag_shared)[s] = twiddle_factors_imag[tw_base + g];
        }
    }

    __syncthreads();

    // Pull the twiddles into fragments; they are reused for every plane below.
    for (int r = 0; r < 8; ++r)
    {
        wmma::load_matrix_sync(tw_frag_real[r], real_shared + r * 128 * 16 + warp_col, 128);
        wmma::load_matrix_sync(tw_frag_imag[r], imag_shared + r * 128 * 16 + warp_col, 128);
    }

    // The first two planes are about to be overwritten with input data.
    __syncthreads();

    for (int t = 0; t < 16; ++t)
    {
        // Load the t-th input plane into the first two shared planes.
        for (int sweep = 0; sweep < kSweeps; ++sweep)
        {
            const int row = threadIdx.y + sweep * kRowsPerSweep;
            for (int q = 0; q < 2; ++q)
            {
                const int g = row * 32 * 2 * gridDim.x + q * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                const int s = row * 64 + threadIdx.x + q * blockDim.x;
                reinterpret_cast<__half2 *>(real_shared)[s] = x_real[base + g];
                reinterpret_cast<__half2 *>(imag_shared)[s] = x_imag[base + g];
            }
        }

        __syncthreads();

        for (int r = 0; r < 8; ++r)
        {
            wmma::load_matrix_sync(b_frag_real[r], real_shared + r * N * 16 + warp_col, N);
            wmma::load_matrix_sync(b_frag_imag[r], imag_shared + r * N * 16 + warp_col, N);
        }

        // Element-wise complex multiply x' = twiddle * x, two halves at a time.
        for (int r = 0; r < 8; ++r)
        {
            __half2 *twr = reinterpret_cast<__half2 *>(tw_frag_real[r].x);
            __half2 *twi = reinterpret_cast<__half2 *>(tw_frag_imag[r].x);
            __half2 *xr = reinterpret_cast<__half2 *>(b_frag_real[r].x);
            __half2 *xi = reinterpret_cast<__half2 *>(b_frag_imag[r].x);

            for (int k = 0; k < tw_frag_real[r].num_elements / 2; ++k)
            {
                const __half2 prod_re = __hsub2(__hmul2(twr[k], xr[k]), __hmul2(twi[k], xi[k]));
                const __half2 prod_im = __hadd2(__hmul2(twr[k], xi[k]), __hmul2(twi[k], xr[k]));
                xr[k] = prod_re;
                xi[k] = prod_im;
            }
        }

        // a_frag <- Im(d_f)
        for (int r = 0; r < 8; ++r)
        {
            for (int c = 0; c < 8; ++c)
            {
                wmma::load_matrix_sync(a_frag[r][c], imag_shared_2 + c * 128 * 16 + r * 16, 128);
            }
        }

        __syncthreads();

        for (int r = 0; r < 8; ++r)
        {
            wmma::fill_fragment(acc_frag_real[r], __float2half(0.0f));

            // bd
#pragma unroll
            for (int k = 0; k < 8; ++k)
            {
                wmma::mma_sync(acc_frag_real[r], a_frag[r][k], b_frag_imag[k], acc_frag_real[r]);
            }

            // Negate so that the next accumulation yields ac - bd.
            for (int k = 0; k < acc_frag_real[r].num_elements; ++k)
            {
                acc_frag_real[r].x[k] = __hneg(acc_frag_real[r].x[k]);
            }
        }

        // a_frag <- Re(d_f) (the same fragment storage is reused)
        for (int r = 0; r < 8; ++r)
        {
            for (int c = 0; c < 8; ++c)
            {
                wmma::load_matrix_sync(a_frag[r][c], real_shared_2 + c * 128 * 16 + r * 16, 128);
            }
        }

        __syncthreads();

        for (int r = 0; r < 8; ++r)
        {
            // ac - bd
#pragma unroll
            for (int k = 0; k < 8; ++k)
            {
                wmma::mma_sync(acc_frag_real[r], a_frag[r][k], b_frag_real[k], acc_frag_real[r]);
            }
        }

        // Park the result in the first shared plane so it can be written out as half2.
#pragma unroll
        for (int r = 0; r < 8; ++r)
        {
            wmma::store_matrix_sync(real_shared + r * N * 16 + warp_col, acc_frag_real[r], N, wmma::mem_row_major);
        }

        __syncthreads();

#pragma unroll
        for (int sweep = 0; sweep < kSweeps; ++sweep)
        {
            const int row = threadIdx.y + sweep * kRowsPerSweep;
            for (int q = 0; q < 2; ++q)
            {
                const int g = row * 32 * 2 * gridDim.x + q * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                const int s = row * 64 + threadIdx.x + q * blockDim.x;
                const __half2 value = reinterpret_cast<__half2 *>(real_shared)[s];
                out_real[base + g] = (out_gate != nullptr) ? __hmul2(value, out_gate[base + g]) : value;
            }
        }

        // The shared planes are reused by the next iteration.
        __syncthreads();
    }
}
|
| 475 |
+
|
| 476 |
+
/**
 * fp16 butterfly iFFT stage for a 16-wide butterfly; only the real part of
 * the product is produced.
 *
 * Each block stages a 16-row tile of x and of the twiddles (32 half2 per row)
 * plus the 16 x 16 d_f matrix in static shared memory. Threads with
 * threadIdx.y < 4 then each take one 16-column strip: the strip of x is
 * multiplied element-wise by the twiddles (complex multiply) and
 *     Re(d_f) * Re(x') - Im(d_f) * Im(x')
 * is formed with two 16x16x16 half MMAs. The result goes back through shared
 * memory and is written to out_real, multiplied by out_gate when it is given.
 *
 * The number of sweeps over the rows is N / blockDim.y; the shared buffers are
 * sized for 16 rows, so N is expected to be 16 -- TODO confirm against the
 * launcher.
 *
 * @param x_real, x_imag         input planes, read as __half2 pairs
 * @param d_f                    complex matrix (read via .real()/.imag())
 * @param twiddle_factors_real,
 *        twiddle_factors_imag   twiddle planes, read as __half2 pairs
 * @param out_real               destination, written as __half2 pairs
 * @param out_gate               optional element-wise gate; nullptr disables gating
 * @param B                      not referenced by this kernel
 * @param H                      scales the blockIdx.y stride of the global offset
 * @param N                      row count / leading dimension of d_f
 */
__global__ void butterfly_ifft_cuda_kernel_16(
    const __half2 *__restrict__ x_real,
    const __half2 *__restrict__ x_imag,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    __half2 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    using FragA = wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major>;
    using FragB = wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major>;
    using FragAcc = wmma::fragment<wmma::accumulator, 16, 16, 16, half>;

    // Per-thread base positions (in half2 units) in the data and twiddle buffers.
    const int base = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_base = blockIdx.x * 32 + threadIdx.x;

    const int rows_per_sweep = blockDim.y;
    const int sweeps = N / rows_per_sweep;

    __shared__ half x_real_shared[16 * 64];
    __shared__ half x_imag_shared[16 * 64];
    __shared__ half d_f_real[16 * 16];
    __shared__ half d_f_imag[16 * 16];
    __shared__ half twiddles_real_shared[16 * 64];
    __shared__ half twiddles_imag_shared[16 * 64];
    __shared__ half out_real_shared[16 * 64];

    // Stage inputs, twiddles and the split d_f matrix in shared memory.
    for (int sweep = 0; sweep < sweeps; ++sweep)
    {
        const int row = threadIdx.y + sweep * rows_per_sweep;
        const int g = row * 32 * gridDim.x;
        const int s = row * 32 + threadIdx.x;

        reinterpret_cast<__half2 *>(x_real_shared)[s] = x_real[g + base];
        reinterpret_cast<__half2 *>(x_imag_shared)[s] = x_imag[g + base];
        reinterpret_cast<__half2 *>(twiddles_real_shared)[s] = twiddle_factors_real[tw_base + g];
        reinterpret_cast<__half2 *>(twiddles_imag_shared)[s] = twiddle_factors_imag[tw_base + g];

        // d_f has only 16 columns, so half of the threads in x copy it.
        if (threadIdx.x < 16)
        {
            const int e = row * 16 + threadIdx.x;
            d_f_real[e] = d_f[e].real();
            d_f_imag[e] = d_f[e].imag();
        }
    }

    __syncthreads();

    // check if it is better to have one warp do all the multiplication or split between warps
    if (threadIdx.y < 4)
    {
        const int strip = threadIdx.y * 16; // first column of this thread row's 16-wide strip

        FragA a_frag_real;
        FragA a_frag_imag;
        FragB tw_frag_real;
        FragB tw_frag_imag;
        FragB b_frag_real;
        FragB b_frag_imag;
        FragAcc acc_frag_real;

        wmma::load_matrix_sync(a_frag_real, d_f_real, N);
        wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
        wmma::load_matrix_sync(b_frag_real, x_real_shared + strip, 64);
        wmma::load_matrix_sync(b_frag_imag, x_imag_shared + strip, 64);
        wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + strip, 64);
        wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + strip, 64);

        // Element-wise complex multiply x' = twiddle * x.
        for (int k = 0; k < tw_frag_real.num_elements; ++k)
        {
            const half prod_re = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
            const half prod_im = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
            b_frag_real.x[k] = prod_re;
            b_frag_imag.x[k] = prod_im;
        }

        // acc = -(Im(d_f) * Im(x')) ...
        wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
        wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);

        for (int k = 0; k < acc_frag_real.num_elements; ++k)
        {
            acc_frag_real.x[k] = __hneg(acc_frag_real.x[k]);
        }

        // ... + Re(d_f) * Re(x')
        wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);

        wmma::store_matrix_sync(out_real_shared + strip, acc_frag_real, 64, wmma::mem_row_major);
    }

    __syncthreads();

    // Write the tile back, applying the optional gate.
#pragma unroll
    for (int sweep = 0; sweep < sweeps; ++sweep)
    {
        const int row = threadIdx.y + sweep * rows_per_sweep;
        const int g = base + row * 32 * gridDim.x;
        const __half2 value = reinterpret_cast<__half2 *>(out_real_shared)[row * 32 + threadIdx.x];
        out_real[g] = (out_gate != nullptr) ? __hmul2(value, out_gate[g]) : value;
    }
}
|
| 582 |
+
|
| 583 |
+
/**
 * Host-side dispatcher for the fp16 butterfly iFFT kernels.
 *
 * x_real is read as a 4-D tensor whose sizes are taken as (B, H, N, M); the
 * output is a freshly allocated tensor of the same shape and options. N picks
 * the kernel (16, 32, 64 or 128), M picks gridDim.x from a fixed table.
 *
 * Changes relative to the previous version of this function:
 *   - An unsupported N now raises instead of printing "Not implemented" and
 *     returning an uninitialized tensor.
 *   - For N = 64 / 128 the grid covers H / 16 groups along z, so an H that is
 *     not a multiple of 16 used to leave the trailing heads unwritten (or, for
 *     H < 16, launch an empty grid). That is now rejected up front.
 *   - Failures of cudaFuncSetAttribute and of the kernel launch are reported
 *     instead of being silently dropped.
 *
 * @param x_real, x_imag         real / imaginary input planes (fp16 storage)
 * @param d_f                    complex butterfly matrix (complex-half storage)
 * @param twiddle_factors_real,
 *        twiddle_factors_imag   twiddle planes (fp16 storage)
 * @param out_gate               optional gate multiplied into the output
 * @return tensor of shape {B, H, N, M} holding the real part of the result
 */
torch::Tensor butterfly_ifft_cuda(
    torch::Tensor x_real,
    torch::Tensor x_imag,
    torch::Tensor d_f,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag,
    std::optional<at::Tensor> out_gate = std::nullopt)
{
    const uint B = x_real.size(0);
    const uint H = x_real.size(1);
    const uint N = x_real.size(2);
    const uint M = x_real.size(3);

    TORCH_CHECK(N == 16 || N == 32 || N == 64 || N == 128,
                "butterfly_ifft_cuda: unsupported butterfly size N=", N,
                " (supported: 16, 32, 64, 128)");

    // The 64- and 128-wide kernels are launched with gridDim.z = H / 16.
    const bool grouped_heads = (N == 64 || N == 128);
    TORCH_CHECK(!grouped_heads || (H >= 16 && H % 16 == 0),
                "butterfly_ifft_cuda: H=", H,
                " must be a positive multiple of 16 when N is 64 or 128");

    const bool is_128 = (N == 128);

    dim3 block;
    block.x = 32;
    block.y = is_128 ? 8 : 4;

    dim3 grid;
    grid.y = B;
    grid.z = grouped_heads ? H / 16 : H;

    // gridDim.x lookup, identical to the original table (the 128-wide kernel
    // covers twice as many columns per block as the others).
    switch (M)
    {
    case 16384:
        grid.x = is_128 ? 128 : 256;
        break;
    case 8192:
        grid.x = is_128 ? 64 : 128;
        break;
    case 4096:
        grid.x = is_128 ? 32 : 64;
        break;
    default:
        grid.x = is_128 ? 256 : 512;
        break;
    }

    torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());

    __half2 *x_real_ptr = static_cast<__half2 *>(x_real.data_ptr());
    __half2 *x_imag_ptr = static_cast<__half2 *>(x_imag.data_ptr());
    complex_half_t *d_f_ptr = static_cast<complex_half_t *>(d_f.data_ptr());
    __half2 *tw_real_ptr = static_cast<__half2 *>(twiddle_factors_real.data_ptr());
    __half2 *tw_imag_ptr = static_cast<__half2 *>(twiddle_factors_imag.data_ptr());
    __half2 *out_ptr = static_cast<__half2 *>(out.data_ptr());
    __half2 *gate_ptr = out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr;

    switch (N)
    {
    case 16:
        butterfly_ifft_cuda_kernel_16<<<grid, block>>>(
            x_real_ptr, x_imag_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_ptr, gate_ptr, B, H, N);
        break;
    case 32:
        butterfly_ifft_cuda_kernel_32<<<grid, block>>>(
            x_real_ptr, x_imag_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_ptr, gate_ptr, B, H, N);
        break;
    case 64:
    {
        // Opt in to 64 KiB of dynamic shared memory before launching.
        const cudaError_t attr_err = cudaFuncSetAttribute(
            &butterfly_ifft_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
        TORCH_CHECK(attr_err == cudaSuccess,
                    "butterfly_ifft_cuda: cudaFuncSetAttribute failed for N=64: ",
                    cudaGetErrorString(attr_err));
        butterfly_ifft_cuda_kernel_64<<<grid, block, 8 * N * N * sizeof(half)>>>(
            x_real_ptr, x_imag_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_ptr, gate_ptr, B, H, N);
        break;
    }
    case 128:
    {
        // Opt in to 128 KiB of dynamic shared memory before launching.
        const cudaError_t attr_err = cudaFuncSetAttribute(
            &butterfly_ifft_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
        TORCH_CHECK(attr_err == cudaSuccess,
                    "butterfly_ifft_cuda: cudaFuncSetAttribute failed for N=128: ",
                    cudaGetErrorString(attr_err));
        butterfly_ifft_cuda_kernel_128<<<grid, block, 65536 * 2>>>(
            x_real_ptr, x_imag_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_ptr, gate_ptr, B, H, N);
        break;
    }
    }

    // Surface launch-configuration errors (bad grid, too much shared memory, ...).
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "butterfly_ifft_cuda: kernel launch failed: ", cudaGetErrorString(launch_err));

    return out;
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda_bf16.cu
CHANGED
|
@@ -1,705 +1,705 @@
|
|
| 1 |
-
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
-
|
| 3 |
-
#include <torch/extension.h>
|
| 4 |
-
|
| 5 |
-
#include <vector>
|
| 6 |
-
#include <stdio.h>
|
| 7 |
-
#include <mma.h>
|
| 8 |
-
#include <cuda_fp16.h>
|
| 9 |
-
#include <cuda_bf16.h>
|
| 10 |
-
#include <cuda_runtime.h>
|
| 11 |
-
#include "shared.h"
|
| 12 |
-
|
| 13 |
-
using namespace nvcuda;
|
| 14 |
-
|
| 15 |
-
/**
 * bf16 butterfly iFFT stage for a 64-wide butterfly; only the real part of the
 * product is produced, accumulated in fp32.
 *
 * Per block: the twiddle tile and the d_f real / imaginary planes are staged
 * in dynamic shared memory and pulled into WMMA fragments once. Then, for 16
 * consecutive input planes (loop variable t), the x tile is loaded, multiplied
 * element-wise by the twiddles (complex multiply), and
 *     Re(d_f) * Re(x') - Im(d_f) * Im(x')
 * is accumulated with 16x16x16 bf16 MMAs into float accumulators. The float
 * tile is converted to bf16 pairs, optionally multiplied by out_gate, and
 * written to out_real.
 *
 * Dynamic shared memory layout: six N*N bf16 planes followed by one N*N float
 * plane. The fragment loops are hard-coded to 4 x 16 = 64, so N is expected
 * to be 64 -- TODO confirm against the launcher.
 *
 * @param x_real, x_imag         input planes, read as bf16 pairs
 * @param d_f_real, d_f_imag     real / imaginary planes of the butterfly matrix
 * @param twiddle_factors_real,
 *        twiddle_factors_imag   twiddle planes, read as bf16 pairs
 * @param out_real               destination, written as bf16 pairs
 * @param out_gate               optional element-wise gate; nullptr disables gating
 * @param B                      not referenced by this kernel
 * @param H                      scales the blockIdx.y stride of the global offset
 * @param N                      tile size / leading dimension
 */
__global__ void butterfly_ifft_bf16_cuda_kernel_64(
    const __nv_bfloat162 *__restrict__ x_real,
    const __nv_bfloat162 *__restrict__ x_imag,
    const __nv_bfloat162 *__restrict__ d_f_real,
    const __nv_bfloat162 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    using FragA = wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major>;
    using FragB = wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major>;
    using FragAcc = wmma::fragment<wmma::accumulator, 16, 16, 16, float>;

    // Per-thread base positions (in bf16-pair units) in the data and twiddle buffers.
    const int base = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_base = blockIdx.x * 32 + threadIdx.x;

    const int rows_per_sweep = blockDim.y;
    const int sweeps = N / rows_per_sweep;
    const int plane = N * N;
    // Each value of threadIdx.y owns one 16-column strip of the tile.
    const int strip = threadIdx.y * 16;

    // Dynamic shared memory: six bf16 planes, then one float plane for the result.
    extern __shared__ __nv_bfloat16 x_real_shared[];
    __nv_bfloat16 *x_imag_shared = x_real_shared + plane;
    __nv_bfloat16 *d_f_real_shared = x_imag_shared + plane;
    __nv_bfloat16 *d_f_imag_shared = d_f_real_shared + plane;
    __nv_bfloat16 *twiddles_real_shared = d_f_imag_shared + plane;
    __nv_bfloat16 *twiddles_imag_shared = twiddles_real_shared + plane;
    float *out_real_shared = reinterpret_cast<float *>(twiddles_imag_shared + plane);

    FragA a_frag_real[4][4];
    FragA a_frag_imag[4][4];
    FragB tw_frag_real[4];
    FragB tw_frag_imag[4];
    FragB b_frag_real[4];
    FragB b_frag_imag[4];
    FragAcc acc_frag_real[4];

    // Stage the twiddle tile and the d_f planes.
    for (int sweep = 0; sweep < sweeps; ++sweep)
    {
        const int row = threadIdx.y + sweep * rows_per_sweep;
        const int g = row * 32 * gridDim.x;
        const int s = row * 32 + threadIdx.x;

        reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[s] = twiddle_factors_real[tw_base + g];
        reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[s] = twiddle_factors_imag[tw_base + g];

        reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[s] = d_f_real[s];
        reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[s] = d_f_imag[s];
    }

    __syncthreads();

    // Fragments that stay constant across all 16 planes.
    for (int r = 0; r < 4; ++r)
    {
#pragma unroll
        for (int c = 0; c < 4; ++c)
        {
            wmma::load_matrix_sync(a_frag_real[r][c], d_f_real_shared + c * N * 16 + r * 16, N);
            wmma::load_matrix_sync(a_frag_imag[r][c], d_f_imag_shared + c * N * 16 + r * 16, N);
        }
        wmma::load_matrix_sync(tw_frag_real[r], twiddles_real_shared + r * N * 16 + strip, N);
        wmma::load_matrix_sync(tw_frag_imag[r], twiddles_imag_shared + r * N * 16 + strip, N);
    }

    for (int t = 0; t < 16; ++t)
    {
        // Load the t-th input plane.
        for (int sweep = 0; sweep < sweeps; ++sweep)
        {
            const int row = threadIdx.y + sweep * rows_per_sweep;
            const int g = row * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
            const int s = row * 32 + threadIdx.x;
            reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[s] = x_real[g + base];
            reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[s] = x_imag[g + base];
        }

        __syncthreads();

        for (int r = 0; r < 4; ++r)
        {
            wmma::load_matrix_sync(b_frag_real[r], x_real_shared + r * N * 16 + strip, N);
            wmma::load_matrix_sync(b_frag_imag[r], x_imag_shared + r * N * 16 + strip, N);
        }

        // Element-wise complex multiply x' = twiddle * x.
        for (int r = 0; r < 4; ++r)
        {
            for (int k = 0; k < tw_frag_real[r].num_elements; ++k)
            {
                const __nv_bfloat16 prod_re = __hsub(__hmul(tw_frag_real[r].x[k], b_frag_real[r].x[k]), __hmul(tw_frag_imag[r].x[k], b_frag_imag[r].x[k]));
                const __nv_bfloat16 prod_im = __hadd(__hmul(tw_frag_real[r].x[k], b_frag_imag[r].x[k]), __hmul(tw_frag_imag[r].x[k], b_frag_real[r].x[k]));
                b_frag_real[r].x[k] = prod_re;
                b_frag_imag[r].x[k] = prod_im;
            }
        }

        for (int r = 0; r < 4; ++r)
        {
            wmma::fill_fragment(acc_frag_real[r], 0.0f);

            // bd
#pragma unroll
            for (int k = 0; k < 4; ++k)
            {
                wmma::mma_sync(acc_frag_real[r], a_frag_imag[r][k], b_frag_imag[k], acc_frag_real[r]);
            }

            // Negate so that the next accumulation yields ac - bd.
            for (int k = 0; k < acc_frag_real[r].num_elements; ++k)
            {
                acc_frag_real[r].x[k] = -acc_frag_real[r].x[k];
            }
        }

        for (int r = 0; r < 4; ++r)
        {
            // ac - bd
#pragma unroll
            for (int k = 0; k < 4; ++k)
            {
                wmma::mma_sync(acc_frag_real[r], a_frag_real[r][k], b_frag_real[k], acc_frag_real[r]);
            }
        }

#pragma unroll
        for (int r = 0; r < 4; ++r)
        {
            wmma::store_matrix_sync(out_real_shared + r * N * 16 + strip, acc_frag_real[r], N, wmma::mem_row_major);
        }

        __syncthreads();

        // Convert the float tile to bf16 pairs and write it out, gated if requested.
#pragma unroll
        for (int sweep = 0; sweep < sweeps; ++sweep)
        {
            const int row = threadIdx.y + sweep * rows_per_sweep;
            const int g = base + row * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
            const __nv_bfloat162 value = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[row * 32 + threadIdx.x]);
            out_real[g] = (out_gate != nullptr) ? __hmul2(value, out_gate[g]) : value;
        }

        // The shared planes are reused by the next iteration.
        __syncthreads();
    }
}
|
| 160 |
-
|
| 161 |
-
/**
 * bf16 butterfly iFFT stage for a 32-wide butterfly; only the real part of the
 * product is produced, accumulated in fp32.
 *
 * Each block stages a 32-row tile of x and of the twiddles (32 bf16 pairs per
 * row) in static shared memory. Threads with threadIdx.y < N / 16 then each
 * take a 32-column strip, handled as a 2 x 2 grid of 16 x 16 fragments: x is
 * multiplied element-wise by the twiddles (complex multiply) and
 *     Re(d_f) * Re(x') - Im(d_f) * Im(x')
 * is accumulated with bf16 MMAs into float accumulators. d_f_real / d_f_imag
 * are loaded into fragments straight from the pointers passed in, without a
 * shared-memory copy. The float tile is converted to bf16 pairs, optionally
 * multiplied by out_gate, and written to out_real.
 *
 * The shared buffers are sized for 32 rows, so N is expected to be 32 --
 * TODO confirm against the launcher.
 *
 * @param x_real, x_imag         input planes, read as bf16 pairs
 * @param d_f_real, d_f_imag     real / imaginary planes of the butterfly matrix
 * @param twiddle_factors_real,
 *        twiddle_factors_imag   twiddle planes, read as bf16 pairs
 * @param out_real               destination, written as bf16 pairs
 * @param out_gate               optional element-wise gate; nullptr disables gating
 * @param B                      not referenced by this kernel
 * @param H                      scales the blockIdx.y stride of the global offset
 * @param N                      row count / leading dimension of d_f
 */
__global__ void butterfly_ifft_bf16_cuda_kernel_32(
    const __nv_bfloat162 *__restrict__ x_real,
    const __nv_bfloat162 *__restrict__ x_imag,
    const __nv_bfloat16 *__restrict__ d_f_real,
    const __nv_bfloat16 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    using FragA = wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major>;
    using FragB = wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major>;
    using FragAcc = wmma::fragment<wmma::accumulator, 16, 16, 16, float>;

    // Per-thread base positions (in bf16-pair units) in the data and twiddle buffers.
    const int base = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_base = blockIdx.x * 32 + threadIdx.x;

    const int rows_per_sweep = blockDim.y;
    const int sweeps = N / rows_per_sweep;

    __shared__ __nv_bfloat16 x_real_shared[32 * 64];
    __shared__ __nv_bfloat16 x_imag_shared[32 * 64];
    __shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
    __shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
    __shared__ float out_real_shared[32 * 64];

    // Stage the input tile and the twiddle tile.
    for (int sweep = 0; sweep < sweeps; ++sweep)
    {
        const int row = threadIdx.y + sweep * rows_per_sweep;
        const int g = row * 32 * gridDim.x;
        const int s = row * 32 + threadIdx.x;

        reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[s] = x_real[g + base];
        reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[s] = x_imag[g + base];
        reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[s] = twiddle_factors_real[tw_base + g];
        reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[s] = twiddle_factors_imag[tw_base + g];
    }

    __syncthreads();

    if (threadIdx.y < N / 16)
    {
        FragA a_frag_real[2][2];
        FragA a_frag_imag[2][2];
        FragB tw_frag_real[2][2];
        FragB tw_frag_imag[2][2];
        FragB b_frag_real[2][2];
        FragB b_frag_imag[2][2];
        FragAcc acc_frag_real[2][2];

        // First column of the 32-wide strip owned by this thread row.
        const int strip = threadIdx.y * 32;

        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                const int tile = r * 2 * N * 16 + c * 16 + strip;
                wmma::load_matrix_sync(a_frag_real[r][c], d_f_real + c * N * 16 + r * 16, N);
                wmma::load_matrix_sync(a_frag_imag[r][c], d_f_imag + c * N * 16 + r * 16, N);
                wmma::load_matrix_sync(b_frag_real[r][c], x_real_shared + tile, 2 * N);
                wmma::load_matrix_sync(b_frag_imag[r][c], x_imag_shared + tile, 2 * N);
                wmma::load_matrix_sync(tw_frag_real[r][c], twiddles_real_shared + tile, 2 * N);
                wmma::load_matrix_sync(tw_frag_imag[r][c], twiddles_imag_shared + tile, 2 * N);
            }
        }

        // Element-wise complex multiply x' = twiddle * x.
        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                for (int k = 0; k < tw_frag_real[r][c].num_elements; ++k)
                {
                    const __nv_bfloat16 prod_re = __hsub(__hmul(tw_frag_real[r][c].x[k], b_frag_real[r][c].x[k]), __hmul(tw_frag_imag[r][c].x[k], b_frag_imag[r][c].x[k]));
                    const __nv_bfloat16 prod_im = __hadd(__hmul(tw_frag_real[r][c].x[k], b_frag_imag[r][c].x[k]), __hmul(tw_frag_imag[r][c].x[k], b_frag_real[r][c].x[k]));
                    b_frag_real[r][c].x[k] = prod_re;
                    b_frag_imag[r][c].x[k] = prod_im;
                }
            }
        }

        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                wmma::fill_fragment(acc_frag_real[r][c], 0.0f);

                // bd
                for (int k = 0; k < 2; ++k)
                {
                    wmma::mma_sync(acc_frag_real[r][c], a_frag_imag[r][k], b_frag_imag[k][c], acc_frag_real[r][c]);
                }

                // Negate so that the next accumulation yields ac - bd.
                for (int k = 0; k < acc_frag_real[r][c].num_elements; ++k)
                {
                    acc_frag_real[r][c].x[k] = -acc_frag_real[r][c].x[k];
                }
            }
        }

        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                // ac - bd
                for (int k = 0; k < 2; ++k)
                {
                    wmma::mma_sync(acc_frag_real[r][c], a_frag_real[r][k], b_frag_real[k][c], acc_frag_real[r][c]);
                }
            }
        }

        for (int r = 0; r < 2; ++r)
        {
            for (int c = 0; c < 2; ++c)
            {
                wmma::store_matrix_sync(out_real_shared + r * 2 * N * 16 + c * 16 + strip, acc_frag_real[r][c], 2 * N, wmma::mem_row_major);
            }
        }
    }

    __syncthreads();

    // Convert the float tile to bf16 pairs and write it out, gated if requested.
#pragma unroll
    for (int sweep = 0; sweep < sweeps; ++sweep)
    {
        const int row = threadIdx.y + sweep * rows_per_sweep;
        const int g = base + row * 32 * gridDim.x;
        const __nv_bfloat162 value = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[row * 32 + threadIdx.x]);
        out_real[g] = (out_gate != nullptr) ? __hmul2(value, out_gate[g]) : value;
    }
}
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
// Butterfly iFFT step for 128x128 DFT matrices in bf16; only the real part of
// the product is produced.
//
// For each of 16 consecutive tiles (loop variable t) the block
//   1. stages x_real / x_imag into shared memory,
//   2. multiplies them element-wise, as complex numbers, by the twiddle factors,
//   3. accumulates  d_f_real * b_real - d_f_imag * b_imag  with WMMA in fp32,
//   4. rounds to bf16, optionally multiplies by out_gate, and writes out_real.
//
// Parameters:
//   x_real, x_imag              input, read as __nv_bfloat162 pairs
//   d_f_real, d_f_imag          128x128 matrix, copied once into shared memory
//   twiddle_factors_real/_imag  element-wise factors, loaded once per block
//   out_real                    output buffer, indexed exactly like x_real
//   out_gate                    optional multiplicative gate (may be nullptr)
//   B                           not read by the kernel
//   H                           used only to compute the per-batch stride
//   N                           row count / leading dimension; the shared-memory
//                               carving below is hard-coded for N == 128
//
// Dynamic shared memory: four 128*128 bf16 buffers (131072 bytes).
__global__ void butterfly_ifft_bf16_cuda_kernel_128(
    const __nv_bfloat162 *__restrict__ x_real,
    const __nv_bfloat162 *__restrict__ x_imag,
    const __nv_bfloat162 *__restrict__ d_f_real,
    const __nv_bfloat162 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
    const int tw_offset = blockIdx.x * 64 + threadIdx.x;
    int idx;
    int shared_offset;
    const int B_Y = blockDim.y;
    const int n = N / B_Y;

    // Carve the dynamic shared allocation into four 128x128 bf16 buffers.
    // real_shared / imag_shared first hold the twiddles, then the x tile, and
    // finally (viewed as float) the fp32 result; *_2 hold the d_f matrices.
    extern __shared__ __nv_bfloat16 real_shared[];
    __nv_bfloat16 *imag_shared = &real_shared[128 * 128];
    __nv_bfloat16 *real_shared_2 = &imag_shared[128 * 128];
    __nv_bfloat16 *imag_shared_2 = &real_shared_2[128 * 128];

    __nv_bfloat16 tmp_real, tmp_imag;

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag[8][8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];

    // Copy the d_f matrices into shared memory (64 bfloat162 per row).
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < 2; j++)
        {
            shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
            reinterpret_cast<__nv_bfloat162 *>(real_shared_2)[shared_offset] = d_f_real[shared_offset];
            reinterpret_cast<__nv_bfloat162 *>(imag_shared_2)[shared_offset] = d_f_imag[shared_offset];
        }
    }

    // Stage this block's slice of the twiddle factors.
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < 2; j++)
        {
            idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
            shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
            reinterpret_cast<__nv_bfloat162 *>(real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
            reinterpret_cast<__nv_bfloat162 *>(imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
        }
    }

    __syncthreads();

    // Each warp keeps the twiddles for its own 16-column strip in registers.
    for (int i = 0; i < 8; i++)
    {
        wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
        wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
    }

    // real_shared / imag_shared are about to be reused for the x tiles.
    __syncthreads();

    for (int t = 0; t < 16; t++)
    {
        // a_frag is shared between the imaginary and the real d_f matrix;
        // start each tile with the imaginary part.
        for (int i = 0; i < 8; i++)
        {
            for (int j = 0; j < 8; j++)
            {
                wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
            }
        }

        // Stage tile t of x into shared memory.
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
                reinterpret_cast<__nv_bfloat162 *>(real_shared)[shared_offset] = x_real[offset + idx];
                reinterpret_cast<__nv_bfloat162 *>(imag_shared)[shared_offset] = x_imag[offset + idx];
            }
        }

        __syncthreads();

        for (int i = 0; i < 8; i++)
        {
            wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
            wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
        }

        // b <- twiddle * x (element-wise complex product, in bf16).
        for (int j = 0; j < 8; j++)
        {
            for (int k = 0; k < tw_frag_real[j].num_elements; k++)
            {
                tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
                tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
                b_frag_real[j].x[k] = tmp_real;
                b_frag_imag[j].x[k] = tmp_imag;
            }
        }

        for (int i = 0; i < 8; i++)
        {
            wmma::fill_fragment(acc_frag_real[i], 0.0f);

            // bd
#pragma unroll
            for (int k = 0; k < 8; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
            }

            // Negate so the following accumulation yields ac - bd.
            for (int k = 0; k < acc_frag_real[i].num_elements; k++)
            {
                acc_frag_real[i].x[k] = -acc_frag_real[i].x[k];
            }
        }

        // Switch a_frag to the real part of d_f.
        for (int i = 0; i < 8; i++)
        {
            for (int j = 0; j < 8; j++)
            {
                wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
            }
        }

        for (int i = 0; i < 8; i++)
        {
            // ac - bd
#pragma unroll
            for (int k = 0; k < 8; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
            }
        }

        // The fp32 result below is written through a float view of real_shared,
        // which spans the same bytes as the bf16 real_shared + imag_shared
        // buffers. A warp's float tile overlaps bf16 columns that belong to
        // *other* warps, so every warp must have finished loading its b_frag_*
        // fragments before any warp stores. Without this barrier the store
        // races with those loads.
        __syncthreads();

#pragma unroll
        for (int i = 0; i < 8; i++)
        {
            wmma::store_matrix_sync(reinterpret_cast<float *>(real_shared) + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
        }

        __syncthreads();

        // Round to bf16 pairs, apply the optional gate, and write the tile out.
#pragma unroll
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
                if (out_gate != nullptr)
                {
                    out_real[offset + idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(real_shared)[shared_offset]), out_gate[offset + idx]);
                }
                else
                {
                    out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(real_shared)[shared_offset]);
                }
            }
        }

        // The next iteration overwrites real_shared / imag_shared with new input.
        __syncthreads();
    }
}
|
| 456 |
-
|
| 457 |
-
// Butterfly iFFT step for a 16x16 DFT matrix in bf16; writes only the real
// part of the product.
//
// The block stages a 16-row by 64-column bf16 tile of x and of the twiddle
// factors in shared memory. Each warp (threadIdx.y) then handles one
// 16-column strip:
//   b   = twiddle * x                       (element-wise complex product)
//   out = d_f_real * b_real - d_f_imag * b_imag   (WMMA, fp32 accumulate)
// The result is rounded to bf16, optionally multiplied by out_gate, and stored.
//
// d_f_real / d_f_imag are read straight from global memory with leading
// dimension N. B is not read by the kernel; out_gate may be nullptr.
__global__ void butterfly_ifft_bf16_cuda_kernel_16(
    const __nv_bfloat162 *__restrict__ x_real,
    const __nv_bfloat162 *__restrict__ x_imag,
    const __nv_bfloat16 *__restrict__ d_f_real,
    const __nv_bfloat16 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    // Column handled by this thread (in bfloat162 units) and the full base
    // offset of this block's (batch, head) slice.
    const int tw_offset = blockIdx.x * 32 + threadIdx.x;
    const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;

    __shared__ __nv_bfloat16 x_real_shared[16 * 64];
    __shared__ __nv_bfloat16 x_imag_shared[16 * 64];
    __shared__ __nv_bfloat16 twiddles_real_shared[16 * 64];
    __shared__ __nv_bfloat16 twiddles_imag_shared[16 * 64];
    __shared__ float out_real_shared[16 * 64];

    // Stage inputs and twiddles; every thread moves one bfloat162 per row.
    for (int pass = 0; pass < passes; pass++)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int g_row = row * 32 * gridDim.x;
        const int s_idx = row * 32 + threadIdx.x;

        reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[s_idx] = x_real[g_row + offset];
        reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[s_idx] = x_imag[g_row + offset];
        reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[s_idx] = twiddle_factors_real[tw_offset + g_row];
        reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[s_idx] = twiddle_factors_imag[tw_offset + g_row];
    }

    __syncthreads();

    if (threadIdx.y < 4)
    {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_re;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_im;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_re;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_im;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> in_re;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> in_im;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc;

        // First column of this warp's 16-wide strip inside the 64-wide tile.
        const int strip = threadIdx.y * 16;

        wmma::load_matrix_sync(dft_re, d_f_real, N);
        wmma::load_matrix_sync(dft_im, d_f_imag, N);
        wmma::load_matrix_sync(in_re, x_real_shared + strip, 64);
        wmma::load_matrix_sync(in_im, x_imag_shared + strip, 64);
        wmma::load_matrix_sync(tw_re, twiddles_real_shared + strip, 64);
        wmma::load_matrix_sync(tw_im, twiddles_imag_shared + strip, 64);

        // in <- twiddle * in, as complex numbers, evaluated in bf16.
        for (int e = 0; e < tw_re.num_elements; e++)
        {
            const __nv_bfloat16 prod_re = __hsub(__hmul(tw_re.x[e], in_re.x[e]), __hmul(tw_im.x[e], in_im.x[e]));
            const __nv_bfloat16 prod_im = __hadd(__hmul(tw_re.x[e], in_im.x[e]), __hmul(tw_im.x[e], in_re.x[e]));
            in_re.x[e] = prod_re;
            in_im.x[e] = prod_im;
        }

        // acc = -(dft_im * in_im), then acc += dft_re * in_re.
        wmma::fill_fragment(acc, 0.0f);
        wmma::mma_sync(acc, dft_im, in_im, acc);

        for (int e = 0; e < acc.num_elements; e++)
        {
            acc.x[e] = -acc.x[e];
        }

        wmma::mma_sync(acc, dft_re, in_re, acc);

        wmma::store_matrix_sync(out_real_shared + strip, acc, 64, wmma::mem_row_major);
    }

    __syncthreads();

    // Round pairs of fp32 results to bfloat162 and write them out.
#pragma unroll
    for (int pass = 0; pass < passes; pass++)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int g_idx = offset + row * 32 * gridDim.x;
        const __nv_bfloat162 value = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[row * 32 + threadIdx.x]);

        if (out_gate != nullptr)
        {
            out_real[g_idx] = __hmul2(value, out_gate[g_idx]);
        }
        else
        {
            out_real[g_idx] = value;
        }
    }
}
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
// Host launcher for the bf16 butterfly iFFT kernels.
//
// x_real is read as a 4-D tensor (B, H, N, M); the result has the same shape
// and options as x_real. N selects the kernel and must be 16, 32, 64 or 128.
// The N == 64 and N == 128 kernels process 16 heads per block
// (gridDim.z = H / 16), so H must be a multiple of 16 for them.
//
// out_gate, when provided, is multiplied element-wise into the output inside
// the kernel.
//
// Raises (via TORCH_CHECK) on an unsupported N, on an H that the selected
// kernel cannot cover, or when configuring / launching the kernel fails.
// Previously these cases printed a message or were ignored, and an
// uninitialised tensor was returned.
//
// NOTE(review): kernels are launched on the default stream, as before; confirm
// whether the current PyTorch stream should be used instead.
torch::Tensor butterfly_ifft_bf16_cuda(
    torch::Tensor x_real,
    torch::Tensor x_imag,
    torch::Tensor d_f_real,
    torch::Tensor d_f_imag,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag,
    std::optional<at::Tensor> out_gate = std::nullopt
)
{
    const uint B = x_real.size(0);
    const uint H = x_real.size(1);
    const uint N = x_real.size(2);
    const uint M = x_real.size(3);

    TORCH_CHECK(N == 16 || N == 32 || N == 64 || N == 128,
                "butterfly_ifft_bf16_cuda: Not implemented for N = ", N);
    // The 64/128 kernels loop over 16 heads per block; a remainder would be
    // silently left unwritten in the output.
    TORCH_CHECK(N < 64 || H % 16 == 0,
                "butterfly_ifft_bf16_cuda: H must be a multiple of 16 for N = ", N,
                ", got H = ", H);

    torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());
    if (out.numel() == 0)
    {
        // Nothing to compute; a zero-sized grid would be an invalid launch.
        return out;
    }

    const auto check_cuda = [](cudaError_t err, const char *what) {
        TORCH_CHECK(err == cudaSuccess,
                    "butterfly_ifft_bf16_cuda: ", what, " failed: ", cudaGetErrorString(err));
    };

    dim3 grid;
    dim3 block;

    // One warp per threadIdx.y; the 128-point kernel uses 8 warps, the rest 4.
    block.x = 32;
    block.y = (N == 128) ? 8 : 4;

    // Columns of M covered by one block: 128 for N == 128, 64 otherwise.
    // Sizes other than the three listed keep the original fallback grid
    // (256 blocks for N == 128, 512 otherwise).
    const uint cols_per_block = (N == 128) ? 128 : 64;
    switch (M)
    {
    case 16384:
    case 8192:
    case 4096:
        grid.x = M / cols_per_block;
        break;
    default:
        grid.x = 32768 / cols_per_block;
        break;
    }
    grid.y = B;

    __nv_bfloat162 *gate_ptr =
        out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr;

    switch (N)
    {
    case 16:
        grid.z = H;
        butterfly_ifft_bf16_cuda_kernel_16<<<grid, block>>>(
            static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
            static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
            static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(out.data_ptr()),
            gate_ptr,
            B,
            H,
            N);
        break;

    case 32:
        grid.z = H;
        butterfly_ifft_bf16_cuda_kernel_32<<<grid, block>>>(
            static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
            static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
            static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(out.data_ptr()),
            gate_ptr,
            B,
            H,
            N);
        break;

    case 64:
        grid.z = H / 16;
        // More than the default dynamic shared-memory limit is needed.
        check_cuda(cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_64,
                                        cudaFuncAttributeMaxDynamicSharedMemorySize, 78000),
                   "cudaFuncSetAttribute");
        butterfly_ifft_bf16_cuda_kernel_64<<<grid, block, 78000>>>(
            static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(out.data_ptr()),
            gate_ptr,
            B,
            H,
            N);
        break;

    case 128:
        grid.z = H / 16;
        // Four 128x128 bf16 buffers = 128 KiB of dynamic shared memory.
        check_cuda(cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_128,
                                        cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2),
                   "cudaFuncSetAttribute");
        butterfly_ifft_bf16_cuda_kernel_128<<<grid, block, 65536 * 2>>>(
            static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(out.data_ptr()),
            gate_ptr,
            B,
            H,
            N);
        break;

    default:
        // Unreachable: N was validated above.
        break;
    }

    // Surface launch-configuration errors instead of returning garbage.
    check_cuda(cudaGetLastError(), "kernel launch");

    return out;
}
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_fp16.h>
|
| 9 |
+
#include <cuda_bf16.h>
|
| 10 |
+
#include <cuda_runtime.h>
|
| 11 |
+
#include "shared.h"
|
| 12 |
+
|
| 13 |
+
using namespace nvcuda;
|
| 14 |
+
|
| 15 |
+
// Butterfly iFFT step for 64x64 DFT matrices in bf16; writes only the real
// part of the product.
//
// The d_f matrices and the twiddle factors are staged in shared memory once
// and kept in WMMA fragments. The block then walks over 16 consecutive tiles
// (loop variable t); for each it
//   1. stages x_real / x_imag,
//   2. forms b = twiddle * x (element-wise complex product, bf16),
//   3. accumulates d_f_real * b_real - d_f_imag * b_imag in fp32,
//   4. rounds to bf16, optionally multiplies by out_gate, stores to out_real.
//
// Dynamic shared memory is carved into six N*N bf16 buffers followed by one
// fp32 output buffer. B is not read by the kernel; out_gate may be nullptr.
__global__ void butterfly_ifft_bf16_cuda_kernel_64(
    const __nv_bfloat162 *__restrict__ x_real,
    const __nv_bfloat162 *__restrict__ x_imag,
    const __nv_bfloat162 *__restrict__ d_f_real,
    const __nv_bfloat162 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_offset = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;
    const int plane = N * N; // elements in one N x N buffer

    extern __shared__ __nv_bfloat16 x_real_shared[];
    __nv_bfloat16 *x_imag_shared = x_real_shared + plane;
    __nv_bfloat16 *d_f_real_shared = x_imag_shared + plane;
    __nv_bfloat16 *d_f_imag_shared = d_f_real_shared + plane;
    __nv_bfloat16 *twiddles_real_shared = d_f_imag_shared + plane;
    __nv_bfloat16 *twiddles_imag_shared = twiddles_real_shared + plane;
    float *out_real_shared = reinterpret_cast<float *>(twiddles_imag_shared + plane);

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4][4];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4][4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[4];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];

    // One-time staging of the twiddles (strided in global memory) and of the
    // d_f matrices (read with the same index as the shared destination).
    for (int pass = 0; pass < passes; pass++)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int g_row = row * 32 * gridDim.x;
        const int s_idx = row * 32 + threadIdx.x;

        reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[s_idx] = twiddle_factors_real[tw_offset + g_row];
        reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[s_idx] = twiddle_factors_imag[tw_offset + g_row];

        reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[s_idx] = d_f_real[s_idx];
        reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[s_idx] = d_f_imag[s_idx];
    }

    __syncthreads();

    // First column of this warp's 16-wide strip.
    const int strip = threadIdx.y * 16;

    for (int i = 0; i < 4; i++)
    {
#pragma unroll
        for (int j = 0; j < 4; j++)
        {
            wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
            wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
        }
        wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + strip, N);
        wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + strip, N);
    }

    for (int t = 0; t < 16; t++)
    {
        // Global offset of tile t relative to this block's base offset.
        const int tile_base = t * 64 * 32 * gridDim.x;

        for (int pass = 0; pass < passes; pass++)
        {
            const int row = threadIdx.y + pass * rows_per_pass;
            const int g_idx = row * 32 * gridDim.x + tile_base;
            const int s_idx = row * 32 + threadIdx.x;

            reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[s_idx] = x_real[g_idx + offset];
            reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[s_idx] = x_imag[g_idx + offset];
        }

        __syncthreads();

        for (int i = 0; i < 4; i++)
        {
            wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + strip, N);
            wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + strip, N);
        }

        // b <- twiddle * x, element-wise complex product in bf16.
        for (int j = 0; j < 4; j++)
        {
            for (int e = 0; e < tw_frag_real[j].num_elements; e++)
            {
                const __nv_bfloat16 prod_re = __hsub(__hmul(tw_frag_real[j].x[e], b_frag_real[j].x[e]), __hmul(tw_frag_imag[j].x[e], b_frag_imag[j].x[e]));
                const __nv_bfloat16 prod_im = __hadd(__hmul(tw_frag_real[j].x[e], b_frag_imag[j].x[e]), __hmul(tw_frag_imag[j].x[e], b_frag_real[j].x[e]));
                b_frag_real[j].x[e] = prod_re;
                b_frag_imag[j].x[e] = prod_im;
            }
        }

        for (int i = 0; i < 4; i++)
        {
            wmma::fill_fragment(acc_frag_real[i], 0.0f);

            // imag * imag term ...
#pragma unroll
            for (int k = 0; k < 4; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
            }

            // ... negated, so that adding the real * real term gives ac - bd.
            for (int e = 0; e < acc_frag_real[i].num_elements; e++)
            {
                acc_frag_real[i].x[e] = -acc_frag_real[i].x[e];
            }
        }

        for (int i = 0; i < 4; i++)
        {
            // real * real term
#pragma unroll
            for (int k = 0; k < 4; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
            }
        }

#pragma unroll
        for (int i = 0; i < 4; i++)
        {
            wmma::store_matrix_sync(out_real_shared + i * N * 16 + strip, acc_frag_real[i], N, wmma::mem_row_major);
        }

        __syncthreads();

        // Round fp32 pairs to bfloat162, apply the optional gate, and store.
#pragma unroll
        for (int pass = 0; pass < passes; pass++)
        {
            const int row = threadIdx.y + pass * rows_per_pass;
            const int g_idx = offset + row * 32 * gridDim.x + tile_base;
            const __nv_bfloat162 value = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[row * 32 + threadIdx.x]);

            if (out_gate != nullptr)
            {
                out_real[g_idx] = __hmul2(value, out_gate[g_idx]);
            }
            else
            {
                out_real[g_idx] = value;
            }
        }

        // x_*_shared and out_real_shared are reused by the next tile.
        __syncthreads();
    }
}
|
| 160 |
+
|
| 161 |
+
// Butterfly iFFT step for a 32x32 DFT matrix in bf16; writes only the real
// part of the product.
//
// The block stages a 32-row by 64-column bf16 tile of x and of the twiddle
// factors in shared memory. Warps with threadIdx.y < N / 16 each own a
// 32-column strip, handled as a 2x2 grid of 16x16 WMMA tiles:
//   b   = twiddle * x                       (element-wise complex product)
//   out = d_f_real * b_real - d_f_imag * b_imag   (fp32 accumulate)
// The result is rounded to bf16, optionally multiplied by out_gate, and stored.
//
// d_f_real / d_f_imag are read directly from global memory with leading
// dimension N. B is not read by the kernel; out_gate may be nullptr.
__global__ void butterfly_ifft_bf16_cuda_kernel_32(
    const __nv_bfloat162 *__restrict__ x_real,
    const __nv_bfloat162 *__restrict__ x_imag,
    const __nv_bfloat16 *__restrict__ d_f_real,
    const __nv_bfloat16 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_offset = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    __shared__ __nv_bfloat16 x_real_shared[32 * 64];
    __shared__ __nv_bfloat16 x_imag_shared[32 * 64];
    __shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
    __shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
    __shared__ float out_real_shared[32 * 64];

    // Stage inputs and twiddles; every thread moves one bfloat162 per row.
    for (int pass = 0; pass < passes; pass++)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int g_row = row * 32 * gridDim.x;
        const int s_idx = row * 32 + threadIdx.x;

        reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[s_idx] = x_real[g_row + offset];
        reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[s_idx] = x_imag[g_row + offset];
        reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[s_idx] = twiddle_factors_real[tw_offset + g_row];
        reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[s_idx] = twiddle_factors_imag[tw_offset + g_row];
    }

    __syncthreads();

    if (threadIdx.y < N / 16)
    {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_re[2][2];
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_im[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_re[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_im[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> in_re[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> in_im[2][2];
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc[2][2];

        const int strip = threadIdx.y * 32; // first column of this warp's strip
        const int ld = 2 * N;               // leading dimension of the shared tiles

        for (int i = 0; i < 2; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                const int tile = i * ld * 16 + j * 16 + strip;

                wmma::load_matrix_sync(dft_re[i][j], d_f_real + j * N * 16 + i * 16, N);
                wmma::load_matrix_sync(dft_im[i][j], d_f_imag + j * N * 16 + i * 16, N);
                wmma::load_matrix_sync(in_re[i][j], x_real_shared + tile, ld);
                wmma::load_matrix_sync(in_im[i][j], x_imag_shared + tile, ld);
                wmma::load_matrix_sync(tw_re[i][j], twiddles_real_shared + tile, ld);
                wmma::load_matrix_sync(tw_im[i][j], twiddles_imag_shared + tile, ld);
            }
        }

        // in <- twiddle * in, element-wise complex product in bf16.
        for (int i = 0; i < 2; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                for (int e = 0; e < tw_re[i][j].num_elements; e++)
                {
                    const __nv_bfloat16 prod_re = __hsub(__hmul(tw_re[i][j].x[e], in_re[i][j].x[e]), __hmul(tw_im[i][j].x[e], in_im[i][j].x[e]));
                    const __nv_bfloat16 prod_im = __hadd(__hmul(tw_re[i][j].x[e], in_im[i][j].x[e]), __hmul(tw_im[i][j].x[e], in_re[i][j].x[e]));
                    in_re[i][j].x[e] = prod_re;
                    in_im[i][j].x[e] = prod_im;
                }
            }
        }

        // Each output tile depends only on the (already final) input
        // fragments, so it is computed and stored in a single pass:
        // acc = -(dft_im * in_im) + dft_re * in_re.
        for (int i = 0; i < 2; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                wmma::fill_fragment(acc[i][j], 0.0f);

                for (int k = 0; k < 2; k++)
                {
                    wmma::mma_sync(acc[i][j], dft_im[i][k], in_im[k][j], acc[i][j]);
                }

                for (int e = 0; e < acc[i][j].num_elements; e++)
                {
                    acc[i][j].x[e] = -acc[i][j].x[e];
                }

                for (int k = 0; k < 2; k++)
                {
                    wmma::mma_sync(acc[i][j], dft_re[i][k], in_re[k][j], acc[i][j]);
                }

                wmma::store_matrix_sync(out_real_shared + i * ld * 16 + j * 16 + strip, acc[i][j], ld, wmma::mem_row_major);
            }
        }
    }

    __syncthreads();

    // Round fp32 pairs to bfloat162, apply the optional gate, and store.
#pragma unroll
    for (int pass = 0; pass < passes; pass++)
    {
        const int row = threadIdx.y + pass * rows_per_pass;
        const int g_idx = offset + row * 32 * gridDim.x;
        const __nv_bfloat162 value = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[row * 32 + threadIdx.x]);

        if (out_gate != nullptr)
        {
            out_real[g_idx] = __hmul2(value, out_gate[g_idx]);
        }
        else
        {
            out_real[g_idx] = value;
        }
    }
}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
/*
 * Inverse butterfly stage for N = 128 on bf16 data using WMMA tensor-core tiles.
 *
 * For each of the 16 slices `t` handled by a block, the kernel
 *   1. multiplies the input (x_real + i*x_imag) element-wise by the twiddle
 *      factors (complex product, evaluated in bf16),
 *   2. left-multiplies the result by the 128x128 complex matrix d_f and keeps
 *      only the real part:  out = d_f_real * b_real - d_f_imag * b_imag,
 *      accumulated in fp32,
 *   3. converts to bf16 and, when out_gate is non-null, multiplies by the gate.
 *
 * Parameters
 *   x_real, x_imag          input planes, read as packed bf16 pairs
 *   d_f_real, d_f_imag      128x128 bf16 matrix (presumably the inverse DFT
 *                           matrix -- confirm against the caller)
 *   twiddle_factors_*       twiddle planes, indexed per blockIdx.x column slab
 *   out_real                output plane (real part only)
 *   out_gate                optional element-wise gate, may be nullptr
 *   B                       unused inside the kernel
 *   H                       used only to compute the per-batch stride
 *   N                       row count / leading dimension; the host launcher
 *                           only dispatches here with N == 128
 *
 * Launch layout expected by the indexing below (set by the host launcher in
 * this file): blockDim = (32, 8), i.e. 8 warps that each own a 16-column slab
 * of the 128-column tile; gridDim.z = H / 16; 4 * 128 * 128 bf16 = 128 KiB of
 * dynamic shared memory.
 */
__global__ void butterfly_ifft_bf16_cuda_kernel_128(
    const __nv_bfloat162 *__restrict__ x_real,
    const __nv_bfloat162 *__restrict__ x_imag,
    const __nv_bfloat162 *__restrict__ d_f_real,
    const __nv_bfloat162 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    // All global offsets are in units of __nv_bfloat162 (two bf16 values).
    const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
    const int tw_offset = blockIdx.x * 64 + threadIdx.x;
    int idx;
    int shared_offset;
    const int B_Y = blockDim.y;
    const int n = N / B_Y; // rows each threadIdx.y value has to stage

    // Dynamic shared memory, four consecutive 128x128 bf16 planes:
    //   real_shared / imag_shared      twiddles first, then the x tile of slice t;
    //                                  later reused (as one float plane) for the output
    //   real_shared_2 / imag_shared_2  the d_f matrix, resident for the whole kernel
    extern __shared__ __nv_bfloat16 real_shared[];
    __nv_bfloat16 *imag_shared = &real_shared[128 * 128];
    __nv_bfloat16 *real_shared_2 = &imag_shared[128 * 128];
    __nv_bfloat16 *imag_shared_2 = &real_shared_2[128 * 128];

    __nv_bfloat16 tmp_real, tmp_imag;

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag[8][8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];

    // Stage d_f (128 rows x 64 packed pairs) into the second pair of planes.
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < 2; j++)
        {
            shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
            reinterpret_cast<__nv_bfloat162 *>(real_shared_2)[shared_offset] = d_f_real[shared_offset];
            reinterpret_cast<__nv_bfloat162 *>(imag_shared_2)[shared_offset] = d_f_imag[shared_offset];
        }
    }

    // Stage this block's twiddle slab into the first pair of planes.
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < 2; j++)
        {
            idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
            shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
            reinterpret_cast<__nv_bfloat162 *>(real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
            reinterpret_cast<__nv_bfloat162 *>(imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
        }
    }

    __syncthreads();

    // Each warp keeps the twiddles of its own 16-column slab in registers.
    for (int i = 0; i < 8; i++)
    {
        wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
        wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
    }

    // The twiddle planes are about to be overwritten with x data.
    __syncthreads();

    for (int t = 0; t < 16; t++)
    {
        // a_frag is shared between the imaginary and the real pass, so the
        // imaginary part of d_f has to be reloaded on every slice.
        // col_major load: the row-major plane in shared memory is consumed transposed.
        for (int i = 0; i < 8; i++)
        {
            for (int j = 0; j < 8; j++)
            {
                wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
            }
        }

        // Stage the x tile of slice t.
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
                reinterpret_cast<__nv_bfloat162 *>(real_shared)[shared_offset] = x_real[offset + idx];
                reinterpret_cast<__nv_bfloat162 *>(imag_shared)[shared_offset] = x_imag[offset + idx];
            }
        }

        __syncthreads();

        for (int i = 0; i < 8; i++)
        {
            wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
            wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
        }

        // b <- twiddle * x (complex, element-wise on the fragment registers).
        // tw and b fragments have the same type/layout, so element k matches.
        for (int j = 0; j < 8; j++)
        {
            for (int k = 0; k < tw_frag_real[j].num_elements; k++)
            {
                tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
                tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
                b_frag_real[j].x[k] = tmp_real;
                b_frag_imag[j].x[k] = tmp_imag;
            }
        }

        for (int i = 0; i < 8; i++)
        {
            wmma::fill_fragment(acc_frag_real[i], 0.0f);

            // bd
            #pragma unroll
            for (int k = 0; k < 8; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
            }

            // Negate so that the real pass below accumulates ac - bd.
            for (int k = 0; k < acc_frag_real[i].num_elements; k++)
            {
                acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
            }
        }

        // Swap the real part of d_f into a_frag.
        for (int i = 0; i < 8; i++)
        {
            for (int j = 0; j < 8; j++)
            {
                wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
            }
        }

        for (int i = 0; i < 8; i++)
        {
            // ac - bd
            #pragma unroll
            for (int k = 0; k < 8; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
            }
        }

        // The fp32 result plane (128 x 128 floats = 64 KiB) aliases BOTH
        // real_shared and imag_shared, and a warp's float columns land on
        // bytes that hold other warps' bf16 input columns. Every warp must
        // therefore have finished its b_frag loads above before any warp
        // starts storing; load_matrix_sync only synchronises within a warp.
        __syncthreads();

        #pragma unroll
        for (int i = 0; i < 8; i++)
        {
            wmma::store_matrix_sync(reinterpret_cast<float *>(real_shared) + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
        }

        __syncthreads();

        // Convert fp32 pairs to packed bf16, apply the optional gate, write out.
        #pragma unroll
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
                shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
                if (out_gate != nullptr)
                {
                    out_real[offset + idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(real_shared)[shared_offset]), out_gate[offset + idx]);
                }
                else
                {
                    out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(real_shared)[shared_offset]);
                }
            }
        }

        // The output plane is reused as the x staging area of the next slice.
        __syncthreads();
    }
}
|
| 456 |
+
|
| 457 |
+
/*
 * Inverse butterfly stage for N = 16 on bf16 data.
 *
 * Each block stages a 16 x 64 tile of x and of the twiddle factors in shared
 * memory; every warp (threadIdx.y) then processes one 16 x 16 sub-tile:
 *   b   = twiddle * x                       (complex, element-wise, bf16)
 *   out = d_f_real * b_real - d_f_imag * b_imag   (fp32 accumulation)
 * Only the real part is produced. The result is converted to packed bf16 and,
 * when out_gate is non-null, multiplied element-wise by the gate.
 *
 * d_f_real / d_f_imag are read straight from global memory as a 16 x 16
 * col_major matrix_a operand with leading dimension N. B is not used.
 * The shared arrays are sized for blockDim = (32, 4), which is what the host
 * launcher in this file uses for this kernel.
 */
__global__ void butterfly_ifft_bf16_cuda_kernel_16(
    const __nv_bfloat162 *__restrict__ x_real,
    const __nv_bfloat162 *__restrict__ x_imag,
    const __nv_bfloat16 *__restrict__ d_f_real,
    const __nv_bfloat16 *__restrict__ d_f_imag,
    const __nv_bfloat162 *__restrict__ twiddle_factors_real,
    const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
    __nv_bfloat162 *__restrict__ out_real,
    __nv_bfloat162 *__restrict__ out_gate,
    uint B,
    uint H,
    int N)
{
    // Offsets are expressed in packed __nv_bfloat162 elements.
    const int base = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
    const int tw_base = blockIdx.x * 32 + threadIdx.x;
    const int rows_per_pass = blockDim.y;
    const int passes = N / rows_per_pass;

    __shared__ __nv_bfloat16 x_real_shared[16 * 64];
    __shared__ __nv_bfloat16 x_imag_shared[16 * 64];
    __shared__ __nv_bfloat16 twiddles_real_shared[16 * 64];
    __shared__ __nv_bfloat16 twiddles_imag_shared[16 * 64];
    __shared__ float out_real_shared[16 * 64];

    // Cooperative copy of the x tile and the twiddle tile into shared memory.
    for (int pass = 0; pass < passes; pass++)
    {
        const int g = (threadIdx.y + pass * rows_per_pass) * 32 * gridDim.x;
        const int s = (threadIdx.y + pass * rows_per_pass) * 32 + threadIdx.x;

        reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[s] = x_real[g + base];
        reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[s] = x_imag[g + base];
        reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[s] = twiddle_factors_real[tw_base + g];
        reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[s] = twiddle_factors_imag[tw_base + g];
    }

    __syncthreads();

    // Four 16-column sub-tiles cover the 64-column tile; one warp each.
    if (threadIdx.y < 4)
    {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_re;
        wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> dft_im;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_re;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_im;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> rhs_re;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> rhs_im;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc;

        const int col0 = threadIdx.y * 16;

        wmma::load_matrix_sync(dft_re, d_f_real, N);
        wmma::load_matrix_sync(dft_im, d_f_imag, N);
        wmma::load_matrix_sync(rhs_re, x_real_shared + col0, 64);
        wmma::load_matrix_sync(rhs_im, x_imag_shared + col0, 64);
        wmma::load_matrix_sync(tw_re, twiddles_real_shared + col0, 64);
        wmma::load_matrix_sync(tw_im, twiddles_imag_shared + col0, 64);

        // rhs <- twiddle * x, complex product on the fragment registers.
        for (int e = 0; e < tw_re.num_elements; e++)
        {
            const __nv_bfloat16 prod_re = __hsub(__hmul(tw_re.x[e], rhs_re.x[e]), __hmul(tw_im.x[e], rhs_im.x[e]));
            const __nv_bfloat16 prod_im = __hadd(__hmul(tw_re.x[e], rhs_im.x[e]), __hmul(tw_im.x[e], rhs_re.x[e]));
            rhs_re.x[e] = prod_re;
            rhs_im.x[e] = prod_im;
        }

        // acc = -(dft_im * rhs_im), then acc += dft_re * rhs_re.
        wmma::fill_fragment(acc, 0.0f);
        wmma::mma_sync(acc, dft_im, rhs_im, acc);

        for (int e = 0; e < acc.num_elements; e++)
        {
            acc.x[e] = - acc.x[e];
        }

        wmma::mma_sync(acc, dft_re, rhs_re, acc);

        wmma::store_matrix_sync(out_real_shared + col0, acc, 64, wmma::mem_row_major);
    }

    __syncthreads();

    // Narrow fp32 pairs to packed bf16, gate if requested, and write back.
    #pragma unroll
    for (int pass = 0; pass < passes; pass++)
    {
        const int g = base + (threadIdx.y + pass * rows_per_pass) * 32 * gridDim.x;
        __nv_bfloat162 packed = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + pass * rows_per_pass) * 32 + threadIdx.x]);

        if (out_gate != nullptr)
        {
            packed = __hmul2(packed, out_gate[g]);
        }

        out_real[g] = packed;
    }
}
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
/*
 * Host launcher for the bf16 inverse butterfly kernels.
 *
 * x_real / x_imag are indexed as (B, H, N, M); the result has the same shape
 * and options as x_real and holds only the real part of the transform,
 * optionally multiplied element-wise by out_gate.
 *
 * Supported N: 16, 32, 64, 128. Any other N raises a c10::Error instead of
 * returning an uninitialised tensor.
 *
 * NOTE(review): the tensors are reinterpreted as bf16 without any dtype,
 * device or contiguity validation -- callers are assumed to pass contiguous
 * CUDA bfloat16 tensors. The kernels are launched on the default stream.
 * NOTE(review): for N == 64 / 128 the grid uses H / 16 blocks in z, so H is
 * presumably expected to be a multiple of 16 -- confirm against the caller.
 */
torch::Tensor butterfly_ifft_bf16_cuda(
    torch::Tensor x_real,
    torch::Tensor x_imag,
    torch::Tensor d_f_real,
    torch::Tensor d_f_imag,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag,
    std::optional<at::Tensor> out_gate = std::nullopt
)
{
    uint B = x_real.size(0);
    uint H = x_real.size(1);
    uint N = x_real.size(2);
    uint M = x_real.size(3);

    // Fail loudly up front: previously an unsupported N only printed a
    // message and handed back torch::empty() garbage.
    TORCH_CHECK(N == 16 || N == 32 || N == 64 || N == 128,
                "butterfly_ifft_bf16_cuda: not implemented for N = ", N,
                " (supported: 16, 32, 64, 128)");

    const bool is_128 = (N == 128);

    dim3 gridDim;
    dim3 blockDim;

    gridDim.y = B;

    // One warp per threadIdx.y; the 128-point kernel needs 8 warps per block.
    blockDim.x = 32;
    blockDim.y = is_128 ? 8 : 4;

    // Number of column slabs: the 128-point kernel covers twice as many
    // columns per block as the others, hence half as many blocks in x.
    // The default branch presumably corresponds to M == 32768 -- TODO confirm.
    switch (M)
    {
    case 16384:
        gridDim.x = is_128 ? 128 : 256;
        break;
    case 8192:
        gridDim.x = is_128 ? 64 : 128;
        break;
    case 4096:
        gridDim.x = is_128 ? 32 : 64;
        break;
    default:
        gridDim.x = is_128 ? 256 : 512;
        break;
    }

    torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());

    __nv_bfloat162 *gate_ptr = out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr;
    cudaError_t err = cudaSuccess;

    switch (N)
    {
    case 16:
        gridDim.z = H;
        butterfly_ifft_bf16_cuda_kernel_16<<<gridDim, blockDim>>>(
            static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
            static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
            static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(out.data_ptr()),
            gate_ptr,
            B,
            H,
            N);
        break;

    case 32:
        gridDim.z = H;
        butterfly_ifft_bf16_cuda_kernel_32<<<gridDim, blockDim>>>(
            static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
            static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
            static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(out.data_ptr()),
            gate_ptr,
            B,
            H,
            N);
        break;

    case 64:
        gridDim.z = H / 16;
        // More dynamic shared memory than the default limit: opt in explicitly.
        err = cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
        TORCH_CHECK(err == cudaSuccess,
                    "butterfly_ifft_bf16_cuda: cannot reserve 78000 bytes of shared memory for N = 64: ",
                    cudaGetErrorString(err));
        butterfly_ifft_bf16_cuda_kernel_64<<<gridDim, blockDim, 78000>>>(
            static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(out.data_ptr()),
            gate_ptr,
            B,
            H,
            N);
        break;

    case 128:
        gridDim.z = H / 16;
        // 4 planes of 128 x 128 bf16 = 128 KiB; not available on every GPU.
        err = cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
        TORCH_CHECK(err == cudaSuccess,
                    "butterfly_ifft_bf16_cuda: cannot reserve 131072 bytes of shared memory for N = 128: ",
                    cudaGetErrorString(err));
        butterfly_ifft_bf16_cuda_kernel_128<<<gridDim, blockDim, 65536 * 2>>>(
            static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
            static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
            static_cast<__nv_bfloat162 *>(out.data_ptr()),
            gate_ptr,
            B,
            H,
            N);
        break;

    default:
        // Unreachable: N was validated above.
        break;
    }

    // Surface launch-configuration failures (bad grid, shared memory, ...)
    // instead of silently returning an unwritten tensor.
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "butterfly_ifft_bf16_cuda: kernel launch failed for N = ", N, ": ",
                cudaGetErrorString(err));

    return out;
}
|