| """ |
| HTM torch wrapper around the pyo3 ``htm_rust`` crate. |
| |
| Exposes ``HTMLayer``, a ``torch.nn.Module`` that batches calls to |
| ``htm_rust.HTMRegion.step`` across a ``(B, T, input_bits)`` boolean SDR stream |
| and returns ``(B, T, n_columns + 1)`` where the last channel is the anomaly |
| score. HTM learning is Hebbian (not gradient), so the wrapper runs under |
| ``torch.no_grad()``. Downstream layers carry gradients back to the embedding |
| via their own learnable projection from the binary column output. |
| |
Per-sequence state semantics
----------------------------
Training-time forward passes are independent windows of tokens (re-sampled
every step), so carrying TM state across calls would mix unrelated contexts.
By default (``reset_each_forward=True``) this layer calls ``reset()`` on every
region at the top of ``forward`` / ``forward_async``, so the TM learns
within-window temporal patterns only. Users that want cross-window continuity
(e.g. eval over a long document) should construct the layer with
``reset_each_forward=False`` or drive the regions themselves (a
``step_stream`` helper is not implemented here; the single-forward contract
is sufficient for the autoresearch loop).
| |
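Device handling
---------------
Regions run either on CPU (``htm_rust.HTMRegion``) or on GPU
(``htm_rust.HTMRegionGpu``, available when the wheel was built with
``--features gpu``). When GPU regions receive a CUDA ``sdr`` and the wheel
exposes the ``__cuda_array_interface__`` entry points, inputs and outputs stay
on the device. Otherwise we pay a ``sdr.cpu().numpy()`` round-trip per forward,
and the result is copied back to ``sdr.device``. For expected use (batch<=32,
T<=2048, bits=16384) this copy is small compared to the SP/TM compute.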
| """ |
|
|
| from __future__ import annotations |
|
|
| import time |
| from concurrent.futures import ThreadPoolExecutor |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| import htm_rust |
|
|
|
|
def _resolve_htm_rust_extension():
    """Return the real ``htm_rust`` extension, not the repo source-dir namespace.

    In repo/HF runtime cwd, ``/workspace/feather/htm_rust/`` can shadow the
    installed PyO3 extension wheel. That yields a namespace module with no
    ``HTMRegion``, which previously triggered invalid no-learning paths.

    If the already-imported ``htm_rust`` lacks ``HTMRegion``, this searches
    the site-packages directories reported by ``site`` and ``sysconfig`` for
    ``htm_rust*.so``. It returns the first candidate that loads and exposes
    ``HTMRegion``.

    A candidate that fails to load (e.g. a build for a different Python ABI)
    is skipped rather than aborting the search. Its error is appended to the
    final error message.

    Returns:
        A module object exposing ``HTMRegion``.

    Raises:
        RuntimeError: if no candidate exposing ``HTMRegion`` can be loaded.
    """
    if hasattr(htm_rust, "HTMRegion"):
        return htm_rust

    import importlib.util
    import site
    import sysconfig
    from pathlib import Path

    search_roots = []
    # Deliberately best-effort: any of these lookups may be unavailable in a
    # given environment (e.g. ``site.getsitepackages`` in old virtualenvs).
    # The attribute access happens inside the lambdas so it is guarded by the
    # ``try`` as well.
    for getter in (
        lambda: site.getsitepackages(),
        lambda: [site.getusersitepackages()],
    ):
        try:
            search_roots.extend(Path(p) for p in getter())
        except Exception:
            pass
    try:
        sys_paths = sysconfig.get_paths()
        search_roots.append(Path(sys_paths["purelib"]))
        search_roots.append(Path(sys_paths["platlib"]))
    except Exception:
        pass

    load_errors = []
    seen = set()
    for root in search_roots:
        if root in seen or not root.exists():
            continue
        seen.add(root)
        for so_path in sorted(root.glob("htm_rust*.so")):
            spec = importlib.util.spec_from_file_location("htm_rust", so_path)
            if spec is None or spec.loader is None:
                continue
            try:
                # For extension modules, module_from_spec already loads the
                # shared library, so it can fail as well as exec_module.
                mod = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(mod)
            except (ImportError, OSError) as exc:
                load_errors.append(f"{so_path}: {exc}")
                continue
            if hasattr(mod, "HTMRegion"):
                return mod

    message = (
        "real htm_rust extension is not importable: source directory shadowed "
        "the wheel and no site-packages htm_rust*.so with HTMRegion was found"
    )
    if load_errors:
        message += "; load errors: " + "; ".join(load_errors)
    raise RuntimeError(message)
|
|
|
|
htm_rust = _resolve_htm_rust_extension()

# CPU region class, and whether it offers the whole-sequence ``step_many``.
_HTM_REGION_CLS = getattr(htm_rust, "HTMRegion", None)
_HTM_HAS_STEP_MANY = _HTM_REGION_CLS is not None and hasattr(_HTM_REGION_CLS, "step_many")

# GPU region class; only present when the wheel was built with ``--features gpu``.
_HTM_GPU_REGION_CLS = getattr(htm_rust, "HTMRegionGpu", None)
_HTM_HAS_GPU = _HTM_GPU_REGION_CLS is not None

# Zero-copy path: GPU regions that take ``__cuda_array_interface__`` buffers.
_HTM_HAS_CAI = _HTM_HAS_GPU and hasattr(_HTM_GPU_REGION_CLS, "step_many_cuda")

import os as _os_fused


def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean feature flag from environment variable ``name``.

    Any integer string is accepted, as before (``"0"`` is False, other
    integers are True). ``true``/``yes``/``on`` and ``false``/``no``/``off``
    are also accepted, case-insensitively and ignoring surrounding whitespace.

    Returns:
        ``default`` when the variable is unset, otherwise the parsed value.

    Raises:
        ValueError: if the value is neither an integer nor a known word.
    """
    raw = _os_fused.environ.get(name)
    if raw is None:
        return default
    value = raw.strip().lower()
    if value in ("true", "yes", "on"):
        return True
    if value in ("false", "no", "off"):
        return False
    try:
        return bool(int(value))
    except ValueError:
        raise ValueError(
            f"{name} must be an integer or one of true/false/yes/no/on/off, got {raw!r}"
        ) from None


# Fused (cooperative-kernel) CUDA path. HYDRA_HTM_FUSED=0 disables it.
# ``and`` short-circuits, so the env var is only parsed when the wheel has the kernel.
_HTM_HAS_FUSED = _HTM_HAS_GPU and hasattr(_HTM_GPU_REGION_CLS, "step_many_fused_cuda")
_HTM_USE_FUSED = _HTM_HAS_FUSED and _env_flag("HYDRA_HTM_FUSED", True)
# Single batched fused launch across regions; only considered when fused is on.
_HTM_USE_BATCHED_FUSED = _HTM_USE_FUSED and _env_flag("HYDRA_HTM_BATCHED_FUSED", True)
|
|
|
|
| def _resolve_use_gpu(use_gpu: bool | None, *, cuda_available: bool) -> bool: |
| """Resolve HTMLayer GPU use with HYDRA_HTM_USE_GPU override.""" |
| htm_use_gpu_env = _os_fused.environ.get("HYDRA_HTM_USE_GPU", "auto").lower() |
| if htm_use_gpu_env in {"0", "false", "no", "cpu"}: |
| return False |
| if htm_use_gpu_env in {"1", "true", "yes", "gpu"}: |
| return True |
| if use_gpu is None: |
| return _HTM_HAS_GPU and cuda_available |
| return bool(use_gpu) |
|
|
|
|
class HTMLayer(nn.Module):
    """Batched torch wrapper around ``htm_rust.HTMRegion``.

    One independent region per batch slot so temporal memory learns
    sequence-local patterns without cross-batch bleed. Regions grow
    lazily if a larger batch shows up.

    Output is ``(B, T, n_columns + 1)``: first ``n_columns`` channels are
    the binary active-column mask (float32 0/1) and the last channel is
    the per-timestep anomaly score in [0, 1].
    """

    def __init__(
        self,
        input_bits: int = 16384,
        n_columns: int = 2048,
        cells_per_column: int = 32,
        batch_size: int = 1,
        seed: int = 42,
        learn: bool = True,
        reset_each_forward: bool = True,
        use_gpu: bool | None = None,
    ) -> None:
        """Create ``batch_size`` regions on the CPU or GPU backend.

        Args:
            input_bits: Width ``D`` of every input SDR frame.
            n_columns: Columns per region (width of the column-mask output).
            cells_per_column: Temporal-memory cells per column.
            batch_size: Regions created eagerly; more are added on demand.
            seed: Region ``i`` is constructed with seed ``seed + i``.
            learn: Master switch for HTM learning (see ``_next_learn_flag``).
            reset_each_forward: Call ``reset()`` at the start of every
                ``forward`` / ``forward_async``.
            use_gpu: ``None`` picks the GPU backend when the wheel provides
                ``HTMRegionGpu`` and CUDA is available. ``True``/``False``
                request a backend explicitly.

        Environment:
            HYDRA_HTM_LEARN_EVERY: allow learning only on every N-th forward
                (default 1).
            HYDRA_FORCE_HTM_CPU=1: force the CPU backend (overrides an
                explicit ``use_gpu=True`` when the GPU backend exists).
            HYDRA_STRICT_OPTIMAL_COMPONENTS=1: refuse to build unless the
                batched fused CUDA path (``htm_rust.step_batch_fused_cuda``)
                is available.

        Raises:
            RuntimeError: if the requested backend or the strict-mode
                requirements are unavailable, or ``htm_rust`` exposes no
                region class.
        """
        super().__init__()
        import os as _os

        self.input_bits = input_bits
        self.n_columns = n_columns
        self.cells_per_column = cells_per_column
        self.learn = learn
        # Runtime toggle, separate from ``learn``, read by _next_learn_flag.
        self.learning_enabled = True
        self.reset_each_forward = reset_each_forward
        self._seed_base = seed

        # Learning is allowed only on every N-th forward (N >= 1).
        self._learn_every = max(1, int(_os.environ.get("HYDRA_HTM_LEARN_EVERY", "1")))
        self._forward_counter = 0

        force_cpu = _os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
        strict_optimal = _os.environ.get("HYDRA_STRICT_OPTIMAL_COMPONENTS", "0") == "1"
        if strict_optimal and force_cpu:
            raise RuntimeError("HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires GPU HTM; HYDRA_FORCE_HTM_CPU=1 is not allowed.")
        if use_gpu is None:
            use_gpu = (not force_cpu) and _HTM_HAS_GPU and torch.cuda.is_available()
        elif use_gpu and not _HTM_HAS_GPU:
            raise RuntimeError(
                "HTMLayer(use_gpu=True) but htm_rust was not built with "
                "--features gpu. Re-run `maturin develop --features gpu`."
            )
        elif use_gpu and force_cpu:
            use_gpu = False
        self._use_gpu = bool(use_gpu)

        if strict_optimal:
            if not self._use_gpu:
                raise RuntimeError(
                    "HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires GPU HTM; "
                    "htm_rust GPU backend and CUDA must be available."
                )
            if not (_HTM_USE_FUSED and _HTM_USE_BATCHED_FUSED):
                raise RuntimeError(
                    "HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires fused batched CUDA HTM; "
                    "set HYDRA_HTM_FUSED=1 and HYDRA_HTM_BATCHED_FUSED=1."
                )
            if not hasattr(htm_rust, "step_batch_fused_cuda"):
                raise RuntimeError(
                    "HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires htm_rust.step_batch_fused_cuda; "
                    "the current htm_rust wheel would silently fall back to per-region fused HTM."
                )

        cls = _HTM_GPU_REGION_CLS if self._use_gpu else _HTM_REGION_CLS
        if cls is None:
            raise RuntimeError(
                "htm_rust does not expose HTMRegion; install/build htm_rust before constructing HTMLayer"
            )
        self._region_cls = cls
        self._regions = [
            cls(input_bits, n_columns, cells_per_column, seed + i)
            for i in range(batch_size)
        ]
        # Non-persistent (not saved in state_dict).
        # NOTE(review): purpose not visible here -- possibly device tracking.
        self.register_buffer("_dummy", torch.zeros(1), persistent=False)
        # Worker pool for the host (numpy) path.
        self._htm_pool = ThreadPoolExecutor(max_workers=min(_os.cpu_count() or 4, 16))

    def _ensure_regions(self, B: int) -> None:
        """Append regions (seeded ``seed_base + index``) until there are at least ``B``."""
        while len(self._regions) < B:
            idx = len(self._regions)
            self._regions.append(
                self._region_cls(
                    self.input_bits,
                    self.n_columns,
                    self.cells_per_column,
                    self._seed_base + idx,
                )
            )

    def reset(self) -> None:
        """Clear TM predictive state on every region (keeps SP synapses)."""
        for r in self._regions:
            r.reset()

    def _next_learn_flag(self) -> bool:
        """Return whether this forward may mutate HTM state.

        Both synchronous and async HTM paths must use this same gate. Eval,
        validation, and factual-probe forwards therefore cannot update the
        persistent Hebbian state even if the layer was constructed with
        ``learn=True``. The forward counter advances on every call, whether or
        not learning is allowed.
        """
        self._forward_counter += 1
        return bool(
            self.learn
            and getattr(self, "learning_enabled", True)
            and self.training
            and (self._forward_counter % self._learn_every == 0)
        )

    def _prepare(self, sdr: torch.Tensor) -> tuple[int, int, bool]:
        """Validate ``sdr``, size and (optionally) reset regions, draw the learn flag.

        Shared preamble of ``forward`` and ``forward_async``.

        Returns:
            ``(B, T, learn)``.

        Raises:
            ValueError: if the last dimension of ``sdr`` is not ``input_bits``.
        """
        B, T, D = sdr.shape
        if D != self.input_bits:
            raise ValueError(f"expected input_bits={self.input_bits}, got {D}")
        self._ensure_regions(B)
        if self.reset_each_forward:
            self.reset()
        return B, T, self._next_learn_flag()

    def _alloc_cuda_buffers(self, sdr: torch.Tensor):
        """Return ``(sdr_u8, cols_out, anom_out)`` device buffers for the CAI path.

        ``sdr_u8`` is ``sdr`` made contiguous and cast to uint8 (no cast when
        it already is uint8). ``cols_out`` is ``(B, T, n_columns)`` uint8 and
        ``anom_out`` is ``(B, T)`` float32. Both are uninitialized and meant to
        be filled by the kernels.
        """
        B, T, _ = sdr.shape
        if sdr.dtype == torch.uint8:
            sdr_u8 = sdr.contiguous()
        else:
            sdr_u8 = sdr.contiguous().to(torch.uint8)
        cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
        anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)
        return sdr_u8, cols_out, anom_out

    def _launch_cuda(
        self,
        B: int,
        sdr_u8: torch.Tensor,
        cols_out: torch.Tensor,
        anom_out: torch.Tensor,
        learn: bool,
    ) -> None:
        """Enqueue the zero-copy CUDA HTM kernels for batch slots ``0..B-1``.

        Kernel selection:

        - Prefer one batched fused launch (``htm_rust.step_batch_fused_cuda``)
          when it is enabled and available.
        - If that raises ``COOPERATIVE_LAUNCH_TOO_LARGE``, retry with one
          fused launch per region.
        - Otherwise use per-region ``step_many_fused_cuda`` or
          ``step_many_cuda``, depending on ``HYDRA_HTM_FUSED``.

        No synchronization is done here; callers must call ``_sync_cuda``
        before reading ``cols_out`` / ``anom_out``.

        NOTE(review): ``sdr_u8`` may have just been produced on the current
        torch stream. Confirm the extension's stream is ordered after it.
        """
        if B == 0:
            return
        if _HTM_USE_BATCHED_FUSED and hasattr(htm_rust, "step_batch_fused_cuda"):
            try:
                htm_rust.step_batch_fused_cuda(
                    self._regions[:B],
                    [sdr_u8[b].__cuda_array_interface__ for b in range(B)],
                    [cols_out[b].__cuda_array_interface__ for b in range(B)],
                    [anom_out[b].__cuda_array_interface__ for b in range(B)],
                    learn,
                )
            except RuntimeError as exc:
                if "COOPERATIVE_LAUNCH_TOO_LARGE" not in str(exc):
                    raise
                # Too many blocks for one cooperative launch: fall through to
                # one fused launch per region below.
            else:
                return
        # _HTM_USE_BATCHED_FUSED implies _HTM_USE_FUSED, so the fallback above
        # lands on step_many_fused_cuda here.
        for b in range(B):
            region = self._regions[b]
            step = region.step_many_fused_cuda if _HTM_USE_FUSED else region.step_many_cuda
            step(
                sdr_u8[b].__cuda_array_interface__,
                cols_out[b].__cuda_array_interface__,
                anom_out[b].__cuda_array_interface__,
                learn,
            )

    @staticmethod
    def _sync_cuda(region) -> None:
        """Wait for launched HTM kernels.

        Uses ``region.device_sync()`` when the region provides it, otherwise
        ``torch.cuda.synchronize()``. ``region`` may be ``None``, which takes
        the fallback.
        """
        if hasattr(region, "device_sync"):
            region.device_sync()
        else:
            torch.cuda.synchronize()

    @staticmethod
    def _assemble(cols_out: torch.Tensor, anom_out: torch.Tensor) -> torch.Tensor:
        """Concatenate the column mask (as float32) and anomaly into ``(B, T, n_columns + 1)``."""
        return torch.cat((cols_out.to(torch.float32), anom_out.unsqueeze(-1)), dim=-1)

    def _step_region_host(self, b: int, sdr_np: np.ndarray, out: np.ndarray, learn: bool) -> None:
        """Run region ``b`` over host frames ``sdr_np[b]`` and write into ``out[b]``.

        Uses ``step_many_gpu`` for GPU regions, ``step_many`` when the CPU
        class has it, and otherwise a per-timestep ``step`` loop.
        """
        region = self._regions[b]
        n = self.n_columns
        if self._use_gpu:
            cols, anom = region.step_many_gpu(sdr_np[b], learn)
        elif _HTM_HAS_STEP_MANY:
            cols, anom = region.step_many(sdr_np[b], learn)
        else:
            for t in range(sdr_np.shape[1]):
                active_cols, _ac, _pc, anomaly = region.step(sdr_np[b, t], learn)
                out[b, t, :n] = active_cols
                out[b, t, n] = float(anomaly)
            return
        out[b, :, :n] = cols
        out[b, :, n] = anom

    @torch.no_grad()
    def forward(self, sdr: torch.Tensor) -> torch.Tensor:
        """Step every batch slot through its region and return the result.

        Args:
            sdr: ``(B, T, input_bits)`` SDR stream (bool or 0/1 values).

        Returns:
            ``(B, T, n_columns + 1)`` float32 tensor on ``sdr.device``: the
            active-column mask followed by the anomaly score.

        Raises:
            ValueError: if the last dimension is not ``input_bits``.
        """
        B, T, learn = self._prepare(sdr)

        if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
            sdr_u8, cols_out, anom_out = self._alloc_cuda_buffers(sdr)
            self._launch_cuda(B, sdr_u8, cols_out, anom_out, learn)
            # The kernels are launched on the extension's own stream (see
            # forward_async). Sync before torch ops read the outputs and
            # before sdr_u8 can be freed while a kernel may still read it.
            self._sync_cuda(self._regions[0] if self._regions else None)
            return self._assemble(cols_out, anom_out)

        sdr_np = sdr.detach().cpu().contiguous().to(torch.bool).numpy()
        out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)
        if B == 1 or self._use_gpu:
            # Single slot, or GPU regions fed from host memory: run serially.
            for b in range(B):
                self._step_region_host(b, sdr_np, out, learn)
        else:
            # Independent CPU regions: fan out over the thread pool. list()
            # drains the iterator so worker exceptions propagate here.
            list(self._htm_pool.map(
                lambda b: self._step_region_host(b, sdr_np, out, learn), range(B)
            ))
        return torch.from_numpy(out).to(sdr.device)

    def forward_async(self, sdr: torch.Tensor) -> dict:
        """Submit HTM work and return a handle awaitable via ``forward_await``.

        On the CAI zero-copy path (GPU tensor in, GPU region), the Rust
        CUDA kernels are launched on cudarc's internal stream and control
        returns **immediately** -- no device synchronization. The caller's
        next GPU ops (embedding lookup, Mamba forward, etc.) are enqueued
        on PyTorch's default stream and can execute while HTM kernels run
        on the cudarc stream. ``forward_await`` performs the cross-stream
        sync (via ``device_sync``) and assembles the output tensor only
        when the result is actually consumed.

        For cooperative kernels (``step_many_fused_cuda``) the GPU can only
        run one cooperative launch at a time, so kernel-level overlap with
        default-stream work is limited. The win is **CPU-side launch
        overlap**: instead of the CPU blocking ~10 ms waiting for HTM
        before it can even enqueue wte/mamba, it enqueues everything up
        front and the GPU executes back-to-back without CPU stalls.

        On the CAI path the handle also keeps the uint8 input buffer alive.
        Torch's caching allocator does not know about the extension's stream,
        so dropping the last reference before the sync would let that memory
        be reused while a kernel may still be reading it.

        On the legacy CPU/numpy path, work is dispatched to a thread pool
        as before.

        Returns:
            An opaque handle dict to pass to ``forward_await``.

        Raises:
            ValueError: if the last dimension is not ``input_bits``.
        """
        B, T, learn = self._prepare(sdr)

        if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
            sdr_u8, cols_out, anom_out = self._alloc_cuda_buffers(sdr)
            self._launch_cuda(B, sdr_u8, cols_out, anom_out, learn)
            return {
                'cuda_deferred': True,
                'cols_out': cols_out,
                'anom_out': anom_out,
                'region0': self._regions[0] if self._regions else None,
                'sdr_u8': sdr_u8,
            }

        sdr_np = sdr.detach().cpu().contiguous().to(torch.bool).numpy()
        out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)

        def _run_all() -> None:
            # All slots run serially inside one pool task.
            for b in range(B):
                self._step_region_host(b, sdr_np, out, learn)

        fut = self._htm_pool.submit(_run_all)
        return {'fut': fut, 'out': out, 'device': sdr.device}

    def forward_await(self, handle: dict) -> torch.Tensor:
        """Wait for work submitted by ``forward_async`` and return its output.

        Args:
            handle: The dict returned by ``forward_async``.

        Returns:
            ``(B, T, n_columns + 1)`` float32 tensor in the same layout as
            ``forward``, on the input's device.
        """
        if handle.get('cuda_deferred'):
            # Cross-stream sync: the kernels ran on the extension's stream.
            self._sync_cuda(handle['region0'])
            # The kernels are done with the input buffer; release it.
            handle.pop('sdr_u8', None)
            return self._assemble(handle['cols_out'], handle['anom_out'])
        if 'cuda_result' in handle:
            # NOTE(review): nothing in this class produces this key; kept for
            # compatibility with handles built elsewhere.
            return handle['cuda_result']
        # Re-raises any exception from the worker thread.
        handle['fut'].result()
        return torch.from_numpy(handle['out']).to(handle['device'])
|
|
|
|
if __name__ == "__main__":
    # Smoke test (only when run as a script). It checks:
    # - the output shape/dtype contract,
    # - ~2% column sparsity,
    # - anomaly at t=0 after the per-forward reset,
    # - anomaly decay on a repeated input.
    # Then it prints timings.
    torch.manual_seed(0)

    B, T, D = 2, 4, 16384
    n_columns = 2048
    target_active_in = int(D * 0.02)  # ~2% of input bits active per frame

    layer = HTMLayer(
        input_bits=D,
        n_columns=n_columns,
        cells_per_column=32,
        batch_size=B,
        seed=42,
        learn=True,
    )
    layer.train()  # learning is gated on self.training

    # Random sparse SDR stream: target_active_in distinct bits per (b, t).
    rng = np.random.default_rng(0)
    sdr = np.zeros((B, T, D), dtype=bool)
    for b in range(B):
        for t in range(T):
            idx = rng.choice(D, size=target_active_in, replace=False)
            sdr[b, t, idx] = True
    sdr_t = torch.from_numpy(sdr)

    t0 = time.perf_counter()
    out = layer(sdr_t)
    dt_first = time.perf_counter() - t0

    assert out.shape == (B, T, n_columns + 1), f"shape {out.shape}"
    assert out.dtype == torch.float32, f"dtype {out.dtype}"

    active_cols = out[..., :n_columns]
    anomaly = out[..., n_columns]

    # The mask is already float32, so the per-step sum needs no cast.
    mean_active = active_cols.sum(dim=-1).mean().item()
    expected = n_columns * 0.02
    assert 20 <= mean_active <= 60, (
        f"active columns per step out of 2% band: {mean_active:.1f} (expected ~{expected:.1f})"
    )

    # Regions are reset at the top of forward (reset_each_forward defaults to
    # True), so t=0 has no temporal context and is expected to score 1.0.
    assert torch.allclose(anomaly[:, 0], torch.ones(B)), f"t=0 anomaly {anomaly[:, 0]}"

    # Second call on the same input, timed for comparison with the first.
    t0 = time.perf_counter()
    out2 = layer(sdr_t)
    dt_second = time.perf_counter() - t0
    assert out2.shape == out.shape

    # One fixed SDR repeated 16 times: anomaly should fall as the TM learns it.
    rep_layer = HTMLayer(
        input_bits=D,
        n_columns=n_columns,
        batch_size=1,
        seed=7,
        learn=True,
    )
    rep_layer.train()
    base = torch.zeros(D, dtype=torch.bool)
    idx = rng.choice(D, size=target_active_in, replace=False)
    base[idx] = True
    rep = base.unsqueeze(0).unsqueeze(0).expand(1, 16, D).clone()
    rep_out = rep_layer(rep)
    rep_anom = rep_out[0, :, n_columns]
    first_anom = rep_anom[0].item()
    last_anom = rep_anom[-1].item()
    assert first_anom > 0.5, f"anomaly at t=0 should be high, got {first_anom:.3f}"
    assert last_anom < first_anom, (
        f"anomaly should decay on repeats: first={first_anom:.3f} last={last_anom:.3f}"
    )

    print("[OK] shape:", tuple(out.shape))
    print(f"[OK] mean active cols/step: {mean_active:.2f} (target ~{expected:.1f})")
    print("[OK] t=0 anomaly = 1.0 on all batch slots")
    print(f"[OK] repeating-sequence anomaly: first={first_anom:.3f} -> last={last_anom:.3f}")
    print(f"[OK] forward wall-clock: first={dt_first*1000:.1f}ms second={dt_second*1000:.1f}ms "
          f"on (B={B}, T={T}, D={D})")
|
|