icarus112's picture
Upload folder using huggingface_hub
7875879 verified
"""
HTM torch wrapper around the pyo3 ``htm_rust`` crate.
Exposes ``HTMLayer``, a ``torch.nn.Module`` that batches calls to
``htm_rust.HTMRegion.step`` across a ``(B, T, input_bits)`` boolean SDR stream
and returns ``(B, T, n_columns + 1)`` where the last channel is the anomaly
score. HTM learning is Hebbian (not gradient), so the wrapper runs under
``torch.no_grad()``. Downstream layers carry gradients back to the embedding
via their own learnable projection from the binary column output.
Per-sequence state semantics
---------------------------
Training-time forward passes are independent windows of tokens (re-sampled
every step), so carrying TM state across calls would mix unrelated contexts.
With the default ``reset_each_forward=True`` this layer calls ``reset()`` on
every region at the top of ``forward`` (and ``forward_async``), so the TM
learns within-window temporal patterns only. Users that want cross-window
continuity (e.g. eval over a long document) can pass
``reset_each_forward=False`` so TM state carries across calls; a dedicated
``step_stream`` API is not implemented here because the single-forward
contract is sufficient for the autoresearch loop.
Device handling
---------------
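The baseline ``htm_rust.HTMRegion`` backend runs on the CPU: if ``sdr`` lives
on CUDA and the zero-copy CUDA path is unavailable, we pay a
``sdr.cpu().numpy()`` round-trip per forward and the return tensor is cast
back to ``sdr.device``. When htm_rust is built with ``--features gpu`` and
exposes ``step_many_cuda``, CUDA inputs are consumed in place via
``__cuda_array_interface__`` and the round-trip is skipped. For expected use
(batch<=32, T<=2048, bits=16384) the copy is small compared to the SP/TM
compute.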
"""
from __future__ import annotations
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import torch
import torch.nn as nn
import htm_rust
def _resolve_htm_rust_extension():
    """Return the real compiled htm_rust extension, not a source-dir namespace.

    When the process cwd is the repo / HF runtime root,
    ``/workspace/feather/htm_rust/`` can shadow the installed PyO3 extension
    wheel. ``import htm_rust`` then yields a namespace module with no
    ``HTMRegion``, which previously triggered invalid no-learning paths.

    If the already-imported module exposes ``HTMRegion`` it is returned as-is.
    Otherwise the site-packages roots (global, user, ``purelib``, ``platlib``)
    are searched explicitly for a compiled ``htm_rust`` extension file with
    any of this interpreter's extension suffixes. Both a flat layout
    (``<root>/htm_rust*.so``) and a package-directory layout
    (``<root>/htm_rust/htm_rust*.so``) are checked. The first candidate that
    loads and exposes ``HTMRegion`` is returned.

    Raises:
        RuntimeError: if no loadable extension exposing ``HTMRegion`` exists.
    """
    if hasattr(htm_rust, "HTMRegion"):
        return htm_rust
    import importlib.machinery
    import importlib.util
    import site
    import sysconfig
    from pathlib import Path

    search_roots = []
    # Lookup is best-effort: some interpreters/virtualenvs lack these APIs or
    # raise from them, which should only shrink the search space.
    for getter in (site.getsitepackages, lambda: [site.getusersitepackages()]):
        try:
            search_roots.extend(Path(p) for p in getter())
        except Exception:
            pass
    try:
        paths = sysconfig.get_paths()
        search_roots.append(Path(paths["purelib"]))
        search_roots.append(Path(paths["platlib"]))
    except Exception:
        pass

    # Every suffix the interpreter can import as an extension module
    # (e.g. ``.cpython-311-x86_64-linux-gnu.so``/``.abi3.so``/``.so`` on
    # POSIX, ``.pyd`` variants on Windows) instead of a hard-coded ``.so``.
    suffixes = tuple(importlib.machinery.EXTENSION_SUFFIXES)
    seen = set()
    for root in search_roots:
        if root in seen or not root.exists():
            continue
        seen.add(root)
        # Depending on the build tool/version the extension may sit flat in
        # site-packages or inside an ``htm_rust/`` package directory.
        for base in (root, root / "htm_rust"):
            if not base.is_dir():
                continue
            for so_path in sorted(base.glob("htm_rust*")):
                if not (so_path.is_file() and so_path.name.endswith(suffixes)):
                    continue
                spec = importlib.util.spec_from_file_location("htm_rust", so_path)
                if spec is None or spec.loader is None:
                    continue
                try:
                    mod = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(mod)  # type: ignore[union-attr]
                except ImportError:
                    # Stale or ABI-incompatible build: try the next candidate
                    # instead of aborting the whole search.
                    continue
                if hasattr(mod, "HTMRegion"):
                    return mod
    raise RuntimeError(
        "real htm_rust extension is not importable: source directory shadowed "
        "the wheel and no site-packages htm_rust*.so with HTMRegion was found"
    )
htm_rust = _resolve_htm_rust_extension()

# step_many releases the GIL for the whole pass, so multiple threads can
# truly run regions in parallel — wall-clock scales with B up to CPU cores.
_HTM_REGION_CLS = getattr(htm_rust, "HTMRegion", None)
_HTM_HAS_STEP_MANY = _HTM_REGION_CLS is not None and hasattr(_HTM_REGION_CLS, "step_many")

# GPU backend: built with `maturin develop --features gpu`. One CUDA region
# per batch slot, persistent device state for SP synapses. Transparent
# fallback to CPU when not available.
_HTM_GPU_REGION_CLS = getattr(htm_rust, "HTMRegionGpu", None)
_HTM_HAS_GPU = _HTM_GPU_REGION_CLS is not None

# Zero-copy CUDA path: consumes torch CUDA tensors directly via the
# __cuda_array_interface__ protocol, skipping the sdr.cpu()/numpy round-trip
# and the D2H of outputs. Huge win when the input SDR already lives on GPU
# (which is the train.py hot path — retina is a device buffer).
_HTM_HAS_CAI = _HTM_HAS_GPU and hasattr(_HTM_GPU_REGION_CLS, "step_many_cuda")

import os as _os_fused


def _env_flag(name: str, default: str) -> bool:
    """Parse a boolean feature flag from environment variable ``name``.

    Integers are accepted via ``int()`` (``"0"`` is False, any other integer
    True). The case-insensitive words ``true/yes/on`` and ``false/no/off``
    are also accepted, similar to the words ``_resolve_use_gpu`` accepts for
    ``HYDRA_HTM_USE_GPU``. ``default`` is used when the variable is unset.

    Raises:
        ValueError: for any other value, naming the offending variable.
    """
    raw = _os_fused.environ.get(name, default).strip().lower()
    if raw in {"true", "yes", "on"}:
        return True
    if raw in {"false", "no", "off"}:
        return False
    try:
        return bool(int(raw))
    except ValueError:
        raise ValueError(
            f"environment variable {name} must be an integer or a boolean word, got {raw!r}"
        ) from None


# Fused megakernel path: collapses all T timesteps + SP + TM into a single
# CUDA launch per forward. Replaces global top-K with per-column threshold
# inhibition (see htm_rust/docs/GPU_HTM.md §Fused Kernel).
# Enabled by default when available; set HYDRA_HTM_FUSED=0 to opt out.
_HTM_HAS_FUSED = _HTM_HAS_GPU and hasattr(_HTM_GPU_REGION_CLS, "step_many_fused_cuda")
_HTM_USE_FUSED = _HTM_HAS_FUSED and _env_flag("HYDRA_HTM_FUSED", "1")
# Batched variant (one launch for all batch slots); only meaningful on top of
# the fused path. Enabled by default; HYDRA_HTM_BATCHED_FUSED=0 opts out.
_HTM_USE_BATCHED_FUSED = _HTM_USE_FUSED and _env_flag("HYDRA_HTM_BATCHED_FUSED", "1")
def _resolve_use_gpu(use_gpu: bool | None, *, cuda_available: bool) -> bool:
    """Decide whether HTMLayer should use the GPU backend.

    The ``HYDRA_HTM_USE_GPU`` environment variable (compared lower-cased)
    takes precedence over the argument: ``0``/``false``/``no``/``cpu`` force
    CPU and ``1``/``true``/``yes``/``gpu`` force GPU (without checking that
    the GPU backend exists). Any other value, including the default
    ``auto``, defers to ``use_gpu``. When ``use_gpu`` is ``None``, the GPU is
    chosen only if htm_rust exposes ``HTMRegionGpu`` and ``cuda_available``
    is true.
    """
    override = _os_fused.environ.get("HYDRA_HTM_USE_GPU", "auto").lower()
    forced_off = override in {"0", "false", "no", "cpu"}
    forced_on = override in {"1", "true", "yes", "gpu"}
    if forced_off or forced_on:
        # The two word sets are disjoint, so this is True only for forced_on.
        return forced_on
    if use_gpu is not None:
        return bool(use_gpu)
    return _HTM_HAS_GPU and cuda_available
class HTMLayer(nn.Module):
    """Batched torch wrapper around ``htm_rust.HTMRegion``.

    One independent region per batch slot so temporal memory learns
    sequence-local patterns without cross-batch bleed. Regions grow
    lazily if a larger batch shows up.

    Output is ``(B, T, n_columns + 1)``: first ``n_columns`` channels are
    the binary active-column mask (float32 0/1) and the last channel is
    the per-timestep anomaly score in [0, 1].

    ``forward`` / ``forward_async`` use one of two execution paths:

    * CUDA zero-copy: when the GPU backend exposes ``step_many_cuda`` and the
      input is a CUDA tensor, kernels read/write torch CUDA tensors directly
      through ``__cuda_array_interface__``.
    * numpy: the SDR is copied to host memory and each region is stepped via
      ``step_many_gpu`` (GPU regions), ``step_many`` (CPU regions exposing
      it) or a per-timestep ``step`` loop.
    """

    def __init__(
        self,
        input_bits: int = 16384,
        n_columns: int = 2048,
        cells_per_column: int = 32,
        batch_size: int = 1,
        seed: int = 42,
        learn: bool = True,
        reset_each_forward: bool = True,
        use_gpu: bool | None = None,
    ) -> None:
        """Create ``batch_size`` regions and resolve the CPU/GPU backend.

        Args:
            input_bits: width ``D`` of each input SDR; ``forward`` rejects
                inputs whose last dimension differs.
            n_columns: number of columns per region (output channels minus one).
            cells_per_column: TM cells per column, passed to each region.
            batch_size: number of regions created up front; more are created
                on demand by ``_ensure_regions``.
            seed: region ``i`` is constructed with seed ``seed + i``.
            learn: allow Hebbian updates. Learning additionally requires
                ``learning_enabled``, ``self.training`` and the
                ``HYDRA_HTM_LEARN_EVERY`` gate (see ``_next_learn_flag``).
            reset_each_forward: call ``reset()`` on every region at the start
                of each ``forward`` / ``forward_async``.
            use_gpu: ``None`` auto-selects the GPU backend when htm_rust
                exposes ``HTMRegionGpu`` and CUDA is available; ``True``
                requires it; ``False`` forces CPU regions.

        Raises:
            RuntimeError: if ``use_gpu=True`` but the GPU backend is missing,
                if ``HYDRA_STRICT_OPTIMAL_COMPONENTS=1`` requirements are not
                met, or if htm_rust exposes no usable region class.
        """
        super().__init__()
        import os as _os

        self.input_bits = input_bits
        self.n_columns = n_columns
        self.cells_per_column = cells_per_column
        self.learn = learn
        self.learning_enabled = True
        self.reset_each_forward = reset_each_forward
        self._seed_base = seed

        # Learn gating: HTM learn kernels (tm_punish, tm_learn_reinforce, tm_grow)
        # are 56% of total HTM CUDA time. Gating them to run every N forwards
        # instead of every forward cuts HTM cost ~2x. Hebbian learning still
        # converges since the EMA accumulates over many calls. Env:
        # HYDRA_HTM_LEARN_EVERY=N (default 1 = every forward; 0 or any value
        # below 1 disables HTM learning entirely).
        learn_every = int(_os.environ.get("HYDRA_HTM_LEARN_EVERY", "1"))
        # _learn_every stays >= 1 so the modulo in _next_learn_flag is always
        # valid; the "disabled" case is tracked by its own flag.
        self._learn_every = max(1, learn_every)
        self._learn_every_disabled = learn_every <= 0
        self._forward_counter = 0

        # GPU backend gate. Default: auto-detect — use GPU when the pyo3
        # module was built with --features gpu AND CUDA is actually usable.
        # HYDRA_FORCE_HTM_CPU=1 is an operational safety valve for paid remote
        # canaries when the compiled CUDA HTM backend is present but unstable on
        # a specific hardware/runtime combination.
        force_cpu = _os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
        strict_optimal = _os.environ.get("HYDRA_STRICT_OPTIMAL_COMPONENTS", "0") == "1"
        if strict_optimal and force_cpu:
            raise RuntimeError("HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires GPU HTM; HYDRA_FORCE_HTM_CPU=1 is not allowed.")
        if use_gpu is None:
            use_gpu = (not force_cpu) and _HTM_HAS_GPU and torch.cuda.is_available()
        elif use_gpu and not _HTM_HAS_GPU:
            raise RuntimeError(
                "HTMLayer(use_gpu=True) but htm_rust was not built with "
                "--features gpu. Re-run `maturin develop --features gpu`."
            )
        elif use_gpu and force_cpu:
            use_gpu = False
        self._use_gpu = bool(use_gpu)

        if strict_optimal:
            if not self._use_gpu:
                raise RuntimeError(
                    "HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires GPU HTM; "
                    "htm_rust GPU backend and CUDA must be available."
                )
            if not (_HTM_USE_FUSED and _HTM_USE_BATCHED_FUSED):
                raise RuntimeError(
                    "HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires fused batched CUDA HTM; "
                    "set HYDRA_HTM_FUSED=1 and HYDRA_HTM_BATCHED_FUSED=1."
                )
            if not hasattr(htm_rust, "step_batch_fused_cuda"):
                raise RuntimeError(
                    "HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires htm_rust.step_batch_fused_cuda; "
                    "the current htm_rust wheel would silently fall back to per-region fused HTM."
                )

        cls = _HTM_GPU_REGION_CLS if self._use_gpu else _HTM_REGION_CLS
        if cls is None:
            raise RuntimeError(
                "htm_rust does not expose HTMRegion; install/build htm_rust before constructing HTMLayer"
            )
        self._region_cls = cls
        self._regions = [
            cls(input_bits, n_columns, cells_per_column, seed + i)
            for i in range(batch_size)
        ]
        # Non-persistent placeholder buffer; presumably lets callers infer the
        # module's device after .to()/.cuda() — TODO confirm.
        self.register_buffer("_dummy", torch.zeros(1), persistent=False)
        self._htm_pool = ThreadPoolExecutor(max_workers=min(_os.cpu_count() or 4, 16))

    def _ensure_regions(self, B: int) -> None:
        """Grow ``self._regions`` to at least ``B`` regions (never shrinks).

        Region ``i`` is always seeded ``seed + i``, so lazily created regions
        get the same seeds they would have had if built in ``__init__``.
        """
        while len(self._regions) < B:
            idx = len(self._regions)
            self._regions.append(
                self._region_cls(
                    self.input_bits,
                    self.n_columns,
                    self.cells_per_column,
                    self._seed_base + idx,
                )
            )

    def reset(self) -> None:
        """Clear TM predictive state on every region (keeps SP synapses)."""
        for r in self._regions:
            r.reset()

    def _next_learn_flag(self) -> bool:
        """Return whether this forward may mutate HTM state.

        Both synchronous and async HTM paths must use this same gate. Eval,
        validation, and factual-probe forwards therefore cannot update the
        persistent Hebbian state even if the layer was constructed with
        ``learn=True``. ``HYDRA_HTM_LEARN_EVERY`` < 1 disables learning.
        The forward counter advances on every call regardless.
        """
        self._forward_counter += 1
        if getattr(self, "_learn_every_disabled", False):
            return False
        return bool(
            self.learn
            and getattr(self, "learning_enabled", True)
            and self.training
            and (self._forward_counter % self._learn_every == 0)
        )

    def _prepare(self, sdr: torch.Tensor) -> tuple[int, int, bool]:
        """Validate ``sdr``, grow/reset regions and draw this call's learn flag.

        Shared prologue of ``forward`` and ``forward_async`` so both apply the
        same validation, reset and learn gating. Returns ``(B, T, learn)``.

        Raises:
            ValueError: if ``sdr.shape[-1] != input_bits`` (or ``sdr`` is not 3-D).
        """
        B, T, D = sdr.shape
        if D != self.input_bits:
            raise ValueError(f"expected input_bits={self.input_bits}, got {D}")
        self._ensure_regions(B)
        if self.reset_each_forward:
            self.reset()
        # Learn-gate: run learn kernels only every N forwards (skips 56% of
        # HTM CUDA time on skip-forwards; Hebbian EMA still converges) and
        # never during eval/validation/factual probes.
        return B, T, self._next_learn_flag()

    def _launch_cuda(self, sdr: torch.Tensor, B: int, T: int, learn: bool, *, batched: bool) -> dict:
        """Launch the zero-copy CUDA HTM kernels and return a deferred handle.

        Kernels may still be in flight when this returns; the handle must be
        passed to ``forward_await`` before its outputs are read.

        Args:
            sdr: CUDA input of shape ``(B, T, input_bits)``.
            batched: try the single cooperative launch
                ``htm_rust.step_batch_fused_cuda`` first (``forward_async``
                only; ``forward`` keeps the per-region launches).
        """
        sdr_u8 = sdr.contiguous() if sdr.dtype == torch.uint8 else sdr.contiguous().to(torch.uint8)
        cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
        anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)

        def _per_region(method_name: str) -> None:
            # One launch per region: fused (1 launch) or legacy (12*T launches).
            for b in range(B):
                getattr(self._regions[b], method_name)(
                    sdr_u8[b].__cuda_array_interface__,
                    cols_out[b].__cuda_array_interface__,
                    anom_out[b].__cuda_array_interface__,
                    learn,
                )

        if batched and _HTM_USE_BATCHED_FUSED and hasattr(htm_rust, "step_batch_fused_cuda"):
            # ONE cooperative kernel launch for all B regions (grid.y = B),
            # avoiding device-level serialization of per-region cooperative
            # launches. Slice _regions to B: _ensure_regions may have
            # allocated more regions than this batch needs (e.g. factual
            # eval uses smaller batches than training).
            try:
                htm_rust.step_batch_fused_cuda(
                    self._regions[:B],
                    [sdr_u8[b].__cuda_array_interface__ for b in range(B)],
                    [cols_out[b].__cuda_array_interface__ for b in range(B)],
                    [anom_out[b].__cuda_array_interface__ for b in range(B)],
                    learn,
                )
            except RuntimeError as exc:
                if "COOPERATIVE_LAUNCH_TOO_LARGE" not in str(exc):
                    raise
                # Batch too large for a cooperative grid: fall back to
                # sequential per-region fused launches.
                _per_region("step_many_fused_cuda")
        elif _HTM_USE_FUSED:
            _per_region("step_many_fused_cuda")
        else:
            _per_region("step_many_cuda")

        return {
            'cuda_deferred': True,
            'cols_out': cols_out,
            'anom_out': anom_out,
            'region0': self._regions[0] if self._regions else None,
            # Keep the uint8 input referenced until forward_await has synced.
            # The HTM kernels run on a stream other than PyTorch's (see
            # forward_async), and freeing this tensor early would let
            # PyTorch's caching allocator reuse its memory while they read it.
            'sdr_u8': sdr_u8,
        }

    def _step_region_numpy(self, b: int, sdr_np: np.ndarray, out: np.ndarray, learn: bool) -> None:
        """Step region ``b`` over ``sdr_np[b]`` (``(T, D)`` bool) and fill ``out[b]`` in place."""
        region = self._regions[b]
        n = self.n_columns
        if self._use_gpu:
            cols, anom = region.step_many_gpu(sdr_np[b], learn)
        elif _HTM_HAS_STEP_MANY:
            # Single Rust call: T steps with GIL released for the whole pass.
            cols, anom = region.step_many(sdr_np[b], learn)  # cols (T, n_cols), anom (T,)
        else:
            for t in range(sdr_np.shape[1]):
                active_cols, _ac, _pc, anomaly = region.step(sdr_np[b, t], learn)
                out[b, t, :n] = active_cols
                out[b, t, n] = float(anomaly)
            return
        out[b, :, :n] = cols
        out[b, :, n] = anom

    @torch.no_grad()
    def forward(self, sdr: torch.Tensor) -> torch.Tensor:
        """Run every region over ``sdr`` and return ``(B, T, n_columns + 1)``.

        Args:
            sdr: ``(B, T, input_bits)`` SDR. Inputs that are not uint8 (CUDA
                zero-copy path) or bool (numpy path) are cast to that dtype.

        Returns:
            float32 tensor on ``sdr.device``: active-column channels followed
            by the anomaly channel.

        Raises:
            ValueError: if the last dimension is not ``input_bits``.

        On the CUDA zero-copy path this blocks until the HTM kernels finish
        (the same sync ``forward_await`` performs) before reading the outputs;
        otherwise the default-stream ``torch.cat`` could read buffers the HTM
        stream is still writing.
        """
        B, T, learn = self._prepare(sdr)

        # Zero-copy CUDA hot path. SDR already lives on GPU (retina buffer),
        # so we skip sdr.cpu()/numpy round-trip AND the output D2H. The Rust
        # kernel writes directly into torch-owned CUDA tensors via CAI.
        if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
            return self.forward_await(self._launch_cuda(sdr, B, T, learn, batched=False))

        # Fallback: CPU / numpy path. Kept for CPU-input case and for
        # builds without CAI support.
        sdr_np = sdr.detach().cpu().contiguous().to(torch.bool).numpy()
        out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)
        if B == 1:
            self._step_region_numpy(0, sdr_np, out, learn)
        elif self._use_gpu:
            # GPU regions share the CUDA context; serialise to avoid contention
            # for stream 0. Per-region latency is dominated by kernel compute,
            # not threadable on a single stream cheaply — future work: one
            # CUDA stream per region.
            for b in range(B):
                self._step_region_numpy(b, sdr_np, out, learn)
        else:
            # Each thread runs in pure Rust under py.allow_threads, so they
            # parallelise to wall-clock min(B, CPU_cores).
            list(self._htm_pool.map(lambda b: self._step_region_numpy(b, sdr_np, out, learn), range(B)))
        return torch.from_numpy(out).to(sdr.device)

    def forward_async(self, sdr: torch.Tensor):
        """Submit HTM work and return a handle awaitable via ``forward_await``.

        On the CAI zero-copy path (GPU tensor in, GPU region), the Rust
        CUDA kernels are launched on cudarc's internal stream and control
        returns **immediately** — no device synchronization. The caller's
        next GPU ops (embedding lookup, Mamba forward, etc.) are enqueued
        on PyTorch's default stream and can execute while HTM kernels run
        on the cudarc stream. ``forward_await`` performs the cross-stream
        sync (via ``device_sync``) and assembles the output tensor only
        when the result is actually consumed. The handle also keeps the
        uint8 input copy alive until then.

        For cooperative kernels (``step_many_fused_cuda``) the GPU can only
        run one cooperative launch at a time, so kernel-level overlap with
        default-stream work is limited. The win is **CPU-side launch
        overlap**: instead of the CPU blocking ~10 ms waiting for HTM
        before it can even enqueue wte/mamba, it enqueues everything up
        front and the GPU executes back-to-back without CPU stalls.

        On the CPU/numpy path, all regions are stepped sequentially in one
        thread-pool task. Validation, reset and learn gating happen
        synchronously before this returns.
        """
        B, T, learn = self._prepare(sdr)
        if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
            # NO sync here — kernels are in-flight on cudarc's stream.
            return self._launch_cuda(sdr, B, T, learn, batched=True)

        sdr_np = sdr.detach().cpu().contiguous().to(torch.bool).numpy()
        out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)
        fut = self._htm_pool.submit(
            lambda: [self._step_region_numpy(b, sdr_np, out, learn) for b in range(B)]
        )
        return {'fut': fut, 'out': out, 'device': sdr.device}

    def forward_await(self, handle) -> torch.Tensor:
        """Wait for a ``forward_async`` handle and return ``(B, T, n_columns + 1)``.

        For CUDA handles, blocks until the HTM kernels finish (``device_sync``
        on the first region when available, else ``torch.cuda.synchronize``)
        before the default stream reads the outputs. For numpy handles,
        waits on the thread-pool future and moves the result to the input's
        original device.
        """
        if handle.get('cuda_deferred'):
            region0 = handle.get('region0')
            if region0 is not None and hasattr(region0, "device_sync"):
                region0.device_sync()
            else:
                torch.cuda.synchronize()
            cols_out = handle['cols_out']
            anom_out = handle['anom_out']
            # Assemble (B, T, n_cols+1) — keep bf16-friendly float32.
            return torch.cat(
                (cols_out.to(torch.float32), anom_out.unsqueeze(-1)), dim=-1
            )
        # Not produced by forward_async in this class; presumably kept for
        # handles built elsewhere — TODO confirm before removing.
        if 'cuda_result' in handle:
            return handle['cuda_result']
        handle['fut'].result()
        return torch.from_numpy(handle['out']).to(handle['device'])
if __name__ == "__main__":
    # Smoke test for HTMLayer; requires a working htm_rust build.
    torch.manual_seed(0)

    # (B=2, T=4, D=16384) random 2%-sparse SDR.
    B, T, D = 2, 4, 16384
    n_columns = 2048
    target_active_in = int(D * 0.02)  # 327
    layer = HTMLayer(
        input_bits=D,
        n_columns=n_columns,
        cells_per_column=32,
        batch_size=B,
        seed=42,
        learn=True,
    )
    layer.train()
    rng = np.random.default_rng(0)
    sdr = np.zeros((B, T, D), dtype=bool)
    for b in range(B):
        for t in range(T):
            idx = rng.choice(D, size=target_active_in, replace=False)
            sdr[b, t, idx] = True
    sdr_t = torch.from_numpy(sdr)

    t0 = time.perf_counter()
    out = layer(sdr_t)
    dt_first = time.perf_counter() - t0
    assert out.shape == (B, T, n_columns + 1), f"shape {out.shape}"
    assert out.dtype == torch.float32, f"dtype {out.dtype}"

    active_cols = out[..., :n_columns]
    anomaly = out[..., n_columns]
    col_sums = active_cols.sum(dim=-1)  # (B, T)
    mean_active = col_sums.mean().item()
    expected = n_columns * 0.02  # ≈ 40.96
    assert 20 <= mean_active <= 60, (
        f"active columns per step out of 2% band: {mean_active:.1f} (expected ~{expected:.1f})"
    )
    # t=0 has no TM prediction → anomaly = 1.0 on every batch slot.
    assert torch.allclose(anomaly[:, 0], torch.ones(B)), f"t=0 anomaly {anomaly[:, 0]}"

    # Second forward on the same (reset) layer: only the shape is checked.
    t0 = time.perf_counter()
    out2 = layer(sdr_t)
    dt_second = time.perf_counter() - t0
    assert out2.shape == out.shape

    # Repeating-sequence anomaly decay check — one region, T=16 repeats of the same pattern.
    rep_layer = HTMLayer(
        input_bits=D,
        n_columns=n_columns,
        batch_size=1,
        seed=7,
        learn=True,
    )
    rep_layer.train()
    base = torch.zeros(D, dtype=torch.bool)
    idx = rng.choice(D, size=target_active_in, replace=False)
    base[idx] = True
    rep = base.unsqueeze(0).unsqueeze(0).expand(1, 16, D).clone()
    rep_out = rep_layer(rep)
    rep_anom = rep_out[0, :, n_columns]
    assert rep_anom[0].item() > 0.5, f"anomaly at t=0 should be high, got {rep_anom[0]:.3f}"
    assert rep_anom[-1].item() < rep_anom[0].item(), (
        f"anomaly should decay on repeats: first={rep_anom[0]:.3f} last={rep_anom[-1]:.3f}"
    )

    print("[OK] shape:", tuple(out.shape))
    print(f"[OK] mean active cols/step: {mean_active:.2f} (target ~{expected:.1f})")
    print("[OK] t=0 anomaly = 1.0 on all batch slots")
    print(f"[OK] repeating-sequence anomaly: first={rep_anom[0]:.3f} -> last={rep_anom[-1]:.3f}")
    print(f"[OK] forward wall-clock: first={dt_first*1000:.1f}ms second={dt_second*1000:.1f}ms "
          f"on (B={B}, T={T}, D={D})")