"""HTM torch wrapper around the pyo3 ``htm_rust`` crate.

Exposes ``HTMLayer``, a ``torch.nn.Module`` that batches calls to the
``htm_rust`` region step functions (``HTMRegion.step`` / ``step_many`` and the
CUDA variants) across a ``(B, T, input_bits)`` boolean SDR stream and returns
``(B, T, n_columns + 1)`` where the last channel is the anomaly score.

HTM learning is Hebbian (not gradient), so the wrapper runs under
``torch.no_grad()``. Downstream layers carry gradients back to the embedding
via their own learnable projection from the binary column output.

Per-sequence state semantics
----------------------------
Training-time forward passes are independent windows of tokens (re-sampled
every step), so carrying TM state across calls would mix unrelated contexts.
With the default ``reset_each_forward=True`` this layer calls ``reset()`` on
every region at the top of ``forward``; the TM learns within-window temporal
patterns only. Users that want cross-window continuity (e.g. eval over a long
document fed window by window) can construct the layer with
``reset_each_forward=False`` so TM state carries over between calls.

Device handling
---------------
The host path (CPU ``HTMRegion``, or any CPU-resident input) pays a
``sdr.cpu().numpy()`` round-trip per forward when ``sdr`` lives on CUDA, and
the return tensor is cast back to ``sdr.device``. For expected use
(batch<=32, T<=2048, bits=16384) this copy is small compared to the SP/TM
compute. When ``htm_rust`` is built with ``--features gpu`` and ``sdr`` is
already a CUDA tensor, a zero-copy path hands device pointers straight to the
Rust kernels instead.
"""
from __future__ import annotations

import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
import torch.nn as nn

import htm_rust


def _resolve_htm_rust_extension():
    """Return the real htm_rust extension, not the repo source-dir namespace.

    In the repo/HF runtime cwd, ``/workspace/feather/htm_rust/`` can shadow the
    installed PyO3 extension wheel. That yields a namespace module with no
    ``HTMRegion``, which previously triggered invalid no-learning paths. If the
    first import only found the source tree, search the site-packages roots
    explicitly and load the compiled extension from there.

    Candidates are files named ``htm_rust<suffix>`` where ``<suffix>`` is one
    of this interpreter's extension-module suffixes
    (``importlib.machinery.EXTENSION_SUFFIXES``, e.g. ``.cpython-311-…so``,
    ``.abi3.so``, ``.so``, ``.pyd``). They may sit either directly in a
    site-packages root or inside an ``htm_rust/`` package directory there.

    Returns:
        A module object exposing ``HTMRegion``.

    Raises:
        RuntimeError: if no loadable candidate exposes ``HTMRegion``.
    """
    if hasattr(htm_rust, "HTMRegion"):
        return htm_rust

    import importlib.machinery
    import importlib.util
    import site
    import sysconfig
    from pathlib import Path

    search_roots = []
    # The site helpers can be missing (some virtualenv site.py variants) or
    # fail at call time; the search is deliberately best-effort, so the
    # attribute lookups happen inside the guarded call as well.
    for getter in (
        lambda: site.getsitepackages(),
        lambda: [site.getusersitepackages()],
    ):
        try:
            search_roots.extend(Path(p) for p in getter())
        except Exception:
            pass
    try:
        install_paths = sysconfig.get_paths()
        search_roots.append(Path(install_paths["purelib"]))
        search_roots.append(Path(install_paths["platlib"]))
    except Exception:
        pass

    # Only file names this interpreter can actually import, tried in the same
    # priority order the import system uses (version-specific suffix first).
    filenames = ["htm_rust" + suffix for suffix in importlib.machinery.EXTENSION_SUFFIXES]

    seen = set()
    for root in search_roots:
        if root in seen or not root.exists():
            continue
        seen.add(root)
        # Flat layout (site-packages/htm_rust*.so) and package layout
        # (site-packages/htm_rust/htm_rust*.so).
        for directory in (root, root / "htm_rust"):
            for filename in filenames:
                so_path = directory / filename
                if not so_path.is_file():
                    continue
                spec = importlib.util.spec_from_file_location("htm_rust", so_path)
                if spec is None or spec.loader is None:
                    continue
                try:
                    mod = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(mod)  # type: ignore[union-attr]
                except ImportError:
                    # Stale/incompatible build (wrong ABI, missing symbols,
                    # already-initialised PyO3 module): try the next candidate
                    # instead of aborting the whole lookup.
                    continue
                if hasattr(mod, "HTMRegion"):
                    return mod
    raise RuntimeError(
        "real htm_rust extension is not importable: source directory shadowed "
        "the wheel and no site-packages htm_rust*.so with HTMRegion was found"
    )


htm_rust = _resolve_htm_rust_extension()

# step_many releases the GIL for the whole pass, so multiple threads can
# truly run regions in parallel — wall-clock scales with B up to CPU cores.
_HTM_REGION_CLS = getattr(htm_rust, "HTMRegion", None)
_HTM_HAS_STEP_MANY = _HTM_REGION_CLS is not None and hasattr(_HTM_REGION_CLS, "step_many")

# GPU backend: built with `maturin develop --features gpu`. One CUDA region
# per batch slot, persistent device state for SP synapses. Transparent
# fallback to CPU when not available.
_HTM_GPU_REGION_CLS = getattr(htm_rust, "HTMRegionGpu", None)
_HTM_HAS_GPU = _HTM_GPU_REGION_CLS is not None

# Zero-copy CUDA path: consumes torch CUDA tensors directly via the
# __cuda_array_interface__ protocol, skipping the sdr.cpu()/numpy round-trip
# and the D2H of outputs. Huge win when the input SDR already lives on GPU
# (which is the train.py hot path — retina is a device buffer).
_HTM_HAS_CAI = _HTM_HAS_GPU and hasattr(_HTM_GPU_REGION_CLS, "step_many_cuda")

# Fused megakernel path: collapses all T timesteps + SP + TM into a single
# CUDA launch per forward. Replaces global top-K with per-column threshold
# inhibition (see htm_rust/docs/GPU_HTM.md §Fused Kernel).
# Enabled by default when the wheel exposes it: HYDRA_HTM_FUSED=0 falls back
# to the legacy step_many_cuda path, HYDRA_HTM_BATCHED_FUSED=0 keeps fused
# kernels but launches them one region at a time.
import os as _os_fused

_HTM_HAS_FUSED = _HTM_HAS_GPU and hasattr(_HTM_GPU_REGION_CLS, "step_many_fused_cuda")
_HTM_USE_FUSED = _HTM_HAS_FUSED and bool(int(_os_fused.environ.get("HYDRA_HTM_FUSED", "1")))
_HTM_USE_BATCHED_FUSED = _HTM_USE_FUSED and bool(
    int(_os_fused.environ.get("HYDRA_HTM_BATCHED_FUSED", "1"))
)


def _resolve_use_gpu(use_gpu: bool | None, *, cuda_available: bool) -> bool:
    """Resolve HTMLayer GPU use with HYDRA_HTM_USE_GPU override.

    ``HYDRA_HTM_USE_GPU`` (case-insensitive) wins when set to ``0/false/no/cpu``
    (force CPU) or ``1/true/yes/gpu`` (force GPU). Any other value (default
    ``auto``) defers to ``use_gpu``; ``use_gpu=None`` resolves to "GPU iff the
    htm_rust GPU backend is present and ``cuda_available``".
    """
    htm_use_gpu_env = _os_fused.environ.get("HYDRA_HTM_USE_GPU", "auto").lower()
    if htm_use_gpu_env in {"0", "false", "no", "cpu"}:
        return False
    if htm_use_gpu_env in {"1", "true", "yes", "gpu"}:
        return True
    if use_gpu is None:
        return _HTM_HAS_GPU and cuda_available
    return bool(use_gpu)


class HTMLayer(nn.Module):
    """Batched torch wrapper around ``htm_rust.HTMRegion``.

    One independent region per batch slot so temporal memory learns
    sequence-local patterns without cross-batch bleed. Regions grow lazily if
    a larger batch shows up.

    Output is ``(B, T, n_columns + 1)``: first ``n_columns`` channels are the
    binary active-column mask (float32 0/1) and the last channel is the
    per-timestep anomaly score in [0, 1].
    """

    def __init__(
        self,
        input_bits: int = 16384,
        n_columns: int = 2048,
        cells_per_column: int = 32,
        batch_size: int = 1,
        seed: int = 42,
        learn: bool = True,
        reset_each_forward: bool = True,
        use_gpu: bool | None = None,
    ) -> None:
        import os as _os

        super().__init__()
        self.input_bits = input_bits
        self.n_columns = n_columns
        self.cells_per_column = cells_per_column
        self.learn = learn
        self.learning_enabled = True
        self.reset_each_forward = reset_each_forward
        self._seed_base = seed

        # Learn gating: HTM learn kernels (tm_punish, tm_learn_reinforce, tm_grow)
        # are 56% of total HTM CUDA time. Gating them to run every N forwards
        # instead of every forward cuts HTM cost ~2x. Hebbian learning still
        # converges since the EMA accumulates over many calls. Env:
        # HYDRA_HTM_LEARN_EVERY=N (default 1 = every forward; N <= 0 disables
        # learning entirely — see _next_learn_flag).
        self._learn_every = int(_os.environ.get("HYDRA_HTM_LEARN_EVERY", "1"))
        self._forward_counter = 0

        # GPU backend gate. Default: auto-detect — use GPU when the pyo3
        # module was built with --features gpu AND CUDA is actually usable.
        # HYDRA_FORCE_HTM_CPU=1 is an operational safety valve for paid remote
        # canaries when the compiled CUDA HTM backend is present but unstable on
        # a specific hardware/runtime combination.
        force_cpu = _os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
        strict_optimal = _os.environ.get("HYDRA_STRICT_OPTIMAL_COMPONENTS", "0") == "1"
        if strict_optimal and force_cpu:
            raise RuntimeError("HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires GPU HTM; HYDRA_FORCE_HTM_CPU=1 is not allowed.")
        if use_gpu is None:
            use_gpu = (not force_cpu) and _HTM_HAS_GPU and torch.cuda.is_available()
        elif use_gpu and not _HTM_HAS_GPU:
            raise RuntimeError(
                "HTMLayer(use_gpu=True) but htm_rust was not built with "
                "--features gpu. Re-run `maturin develop --features gpu`."
            )
        elif use_gpu and force_cpu:
            use_gpu = False
        self._use_gpu = bool(use_gpu)

        if strict_optimal:
            if not self._use_gpu:
                raise RuntimeError(
                    "HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires GPU HTM; "
                    "htm_rust GPU backend and CUDA must be available."
                )
            if not (_HTM_USE_FUSED and _HTM_USE_BATCHED_FUSED):
                raise RuntimeError(
                    "HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires fused batched CUDA HTM; "
                    "set HYDRA_HTM_FUSED=1 and HYDRA_HTM_BATCHED_FUSED=1."
                )
            if not hasattr(htm_rust, "step_batch_fused_cuda"):
                raise RuntimeError(
                    "HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires htm_rust.step_batch_fused_cuda; "
                    "the current htm_rust wheel would silently fall back to per-region fused HTM."
                )

        cls = _HTM_GPU_REGION_CLS if self._use_gpu else _HTM_REGION_CLS
        if cls is None:
            raise RuntimeError(
                "htm_rust does not expose HTMRegion; install/build htm_rust before constructing HTMLayer"
            )
        self._region_cls = cls
        self._regions = [
            cls(input_bits, n_columns, cells_per_column, seed + i)
            for i in range(batch_size)
        ]
        self.register_buffer("_dummy", torch.zeros(1), persistent=False)
        self._htm_pool = ThreadPoolExecutor(max_workers=min(_os.cpu_count() or 4, 16))

    def _ensure_regions(self, B: int) -> None:
        """Append regions (seeded ``seed + index``) until at least ``B`` exist."""
        while len(self._regions) < B:
            idx = len(self._regions)
            self._regions.append(
                self._region_cls(
                    self.input_bits,
                    self.n_columns,
                    self.cells_per_column,
                    self._seed_base + idx,
                )
            )

    def reset(self) -> None:
        """Clear TM predictive state on every region (keeps SP synapses)."""
        for r in self._regions:
            r.reset()

    def _next_learn_flag(self) -> bool:
        """Return whether this forward may mutate HTM state.

        Both synchronous and async HTM paths must use this same gate. Eval,
        validation, and factual-probe forwards therefore cannot update the
        persistent Hebbian state even if the layer was constructed with
        ``learn=True``. The forward counter advances on every call; learning
        happens on every ``_learn_every``-th call, and never when
        ``_learn_every <= 0`` (``HYDRA_HTM_LEARN_EVERY=0`` = disabled).
        """
        self._forward_counter += 1
        learn_every = self._learn_every
        return bool(
            self.learn
            and getattr(self, "learning_enabled", True)
            and self.training
            and learn_every > 0
            and self._forward_counter % learn_every == 0
        )

    def _prepare_forward(self, sdr: torch.Tensor) -> tuple[int, int, bool]:
        """Validate ``sdr``, grow/reset regions and consume one learn-gate tick.

        Shared by ``forward`` and ``forward_async`` so both apply identical
        validation, reset policy and learn gating. Returns ``(B, T, learn)``.
        """
        B, T, D = sdr.shape
        if D != self.input_bits:
            raise ValueError(f"expected input_bits={self.input_bits}, got {D}")
        self._ensure_regions(B)
        if self.reset_each_forward:
            self.reset()
        # Learn-gate: run learn kernels only every N forwards (skips 56% of
        # HTM CUDA time on skip-forwards; Hebbian EMA still converges) and
        # never during eval/validation/factual probes.
        return B, T, self._next_learn_flag()

    def _use_cuda_path(self, sdr: torch.Tensor) -> bool:
        """True when the zero-copy CUDA path applies (GPU regions, CUDA input, CAI build)."""
        return bool(_HTM_HAS_CAI and self._use_gpu and sdr.is_cuda)

    def _launch_cuda(self, sdr: torch.Tensor, B: int, T: int, learn: bool):
        """Enqueue HTM kernels for all ``B`` regions on the zero-copy CUDA path.

        The Rust kernels write directly into torch-owned CUDA tensors via
        ``__cuda_array_interface__``. This does NOT synchronise: per the
        ``forward_async`` contract the kernels run on cudarc's stream, so the
        caller must sync (``_cuda_sync``) before reading the outputs and must
        keep the returned ``sdr_u8`` alive until then.

        NOTE(review): the kernels read ``sdr_u8`` on cudarc's stream, while the
        uint8 cast / contiguous copy (or whatever produced ``sdr``) runs on
        PyTorch's stream. Nothing visible here orders the two streams;
        confirm the Rust side waits for the torch stream.

        Returns:
            ``(sdr_u8, cols_out, anom_out)``; ``cols_out`` is uint8
            ``(B, T, n_columns)``, ``anom_out`` float32 ``(B, T)``.
        """
        sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()
        cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
        anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)
        if B == 0:
            # Nothing to launch; avoid handing empty lists to the batched kernel.
            return sdr_u8, cols_out, anom_out

        def _per_region(method_name: str) -> None:
            # Sequential per-region launches of `step_many_fused_cuda` or
            # the legacy `step_many_cuda` (12*T launches per region).
            for b in range(B):
                getattr(self._regions[b], method_name)(
                    sdr_u8[b].__cuda_array_interface__,
                    cols_out[b].__cuda_array_interface__,
                    anom_out[b].__cuda_array_interface__,
                    learn,
                )

        # ONE cooperative kernel launch for all B regions. Breaks past the
        # CUDA cooperative-kernel device-level serialization (only one
        # cooperative kernel runs at a time). A single launch with grid.y = B
        # processes all regions concurrently — ~B× speedup. Falls back to
        # sequential dispatch if the batched entry isn't available (older
        # htm_rust wheel).
        if _HTM_USE_BATCHED_FUSED and hasattr(htm_rust, "step_batch_fused_cuda"):
            # Slice self._regions to match B: _ensure_regions may have
            # allocated more regions than the current batch size needs
            # (e.g. factual eval uses smaller batches than training).
            try:
                htm_rust.step_batch_fused_cuda(
                    self._regions[:B],
                    [sdr_u8[b].__cuda_array_interface__ for b in range(B)],
                    [cols_out[b].__cuda_array_interface__ for b in range(B)],
                    [anom_out[b].__cuda_array_interface__ for b in range(B)],
                    learn,
                )
            except RuntimeError as exc:
                if "COOPERATIVE_LAUNCH_TOO_LARGE" not in str(exc):
                    raise
                # Batch too large for cooperative grid. Fall back to
                # sequential per-region fused launches (each B=1).
                _per_region("step_many_fused_cuda")
        elif _HTM_USE_FUSED:
            _per_region("step_many_fused_cuda")
        else:
            _per_region("step_many_cuda")
        return sdr_u8, cols_out, anom_out

    @staticmethod
    def _cuda_sync(region0) -> None:
        """Block until in-flight HTM kernels finish.

        Uses the region's ``device_sync`` when the binding provides it,
        otherwise a full ``torch.cuda.synchronize()``.
        """
        if hasattr(region0, "device_sync"):
            region0.device_sync()
        else:
            torch.cuda.synchronize()

    def _step_region_np(self, b: int, sdr_np: np.ndarray, out: np.ndarray, learn: bool) -> None:
        """Run all T steps of region ``b`` on host data, writing into ``out[b]``.

        ``sdr_np`` is the bool ``(B, T, input_bits)`` host copy; ``out`` is the
        float32 ``(B, T, n_columns + 1)`` result buffer.
        """
        region = self._regions[b]
        n_cols = self.n_columns
        if self._use_gpu:
            cols, anom = region.step_many_gpu(sdr_np[b], learn)
        elif _HTM_HAS_STEP_MANY:
            # Single Rust call: T steps with GIL released for the whole pass.
            cols, anom = region.step_many(sdr_np[b], learn)  # cols (T, n_cols), anom (T,)
        else:
            for t in range(sdr_np.shape[1]):
                active_cols, _ac, _pc, anomaly = region.step(sdr_np[b, t], learn)
                out[b, t, :n_cols] = active_cols
                out[b, t, n_cols] = float(anomaly)
            return
        out[b, :, :n_cols] = cols
        out[b, :, n_cols] = anom

    @torch.no_grad()
    def forward(self, sdr: torch.Tensor) -> torch.Tensor:
        """Run every batch slot's HTM region over its T-step SDR window.

        Args:
            sdr: ``(B, T, input_bits)`` SDR tensor — presumably 0/1 valued
                (it is cast to bool on the host path and uint8 on the CUDA
                path).

        Returns:
            float32 ``(B, T, n_columns + 1)`` on ``sdr.device``: the
            active-column mask followed by the anomaly score.

        Raises:
            ValueError: if ``sdr.shape[-1] != input_bits``.
        """
        B, T, learn = self._prepare_forward(sdr)

        # Zero-copy CUDA hot path. SDR already lives on GPU (retina buffer),
        # so we skip sdr.cpu()/numpy round-trip AND the output D2H.
        # Gives 5-10x tok/s on train.py vs the numpy path below.
        if self._use_cuda_path(sdr):
            sdr_u8, cols_out, anom_out = self._launch_cuda(sdr, B, T, learn)
            # Kernels run on cudarc's stream: sync before the default stream
            # reads the outputs (and before sdr_u8 can be freed).
            self._cuda_sync(self._regions[0] if self._regions else None)
            del sdr_u8
            # Assemble (B, T, n_cols+1) — keep bf16-friendly float32.
            return torch.cat((cols_out.to(torch.float32), anom_out.unsqueeze(-1)), dim=-1)

        # Fallback: CPU / numpy path. Kept for CPU-input case and for
        # builds without CAI support.
        sdr_np = sdr.detach().cpu().contiguous().to(torch.bool).numpy()
        out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)
        if B == 1 or self._use_gpu:
            # GPU regions share the CUDA context; serialise to avoid contention
            # for stream 0. Per-region latency is dominated by kernel compute,
            # not threadable on a single stream cheaply — future work: one
            # CUDA stream per region.
            for b in range(B):
                self._step_region_np(b, sdr_np, out, learn)
        else:
            # Each thread runs in pure Rust under py.allow_threads, so they
            # parallelise to wall-clock min(B, CPU_cores).
            list(self._htm_pool.map(lambda b: self._step_region_np(b, sdr_np, out, learn), range(B)))
        return torch.from_numpy(out).to(sdr.device)

    @torch.no_grad()
    def forward_async(self, sdr: torch.Tensor):
        """Submit HTM work and return a handle awaitable via ``forward_await``.

        On the CAI zero-copy path (GPU tensor in, GPU region), the Rust CUDA
        kernels are launched on cudarc's internal stream and control returns
        **immediately** — no device synchronization. The caller's next GPU ops
        (embedding lookup, Mamba forward, etc.) are enqueued on PyTorch's
        default stream and can execute while HTM kernels run on the cudarc
        stream. ``forward_await`` performs the cross-stream sync (via
        ``device_sync``) and assembles the output tensor only when the result
        is actually consumed. The handle also holds the kernel input buffer so
        PyTorch's caching allocator cannot recycle it while kernels still
        read it.

        For cooperative kernels (``step_many_fused_cuda``) the GPU can only run
        one cooperative launch at a time, so kernel-level overlap with
        default-stream work is limited. The win is **CPU-side launch
        overlap**: instead of the CPU blocking ~10 ms waiting for HTM before it
        can even enqueue wte/mamba, it enqueues everything up front and the GPU
        executes back-to-back without CPU stalls.

        On the legacy CPU/numpy path, work is dispatched to the thread pool
        (one task stepping all B regions in order).
        """
        B, T, learn = self._prepare_forward(sdr)

        if self._use_cuda_path(sdr):
            sdr_u8, cols_out, anom_out = self._launch_cuda(sdr, B, T, learn)
            # NO sync here — kernels are in-flight on cudarc's stream.
            # forward_await() will sync before the output is consumed.
            return {
                'cuda_deferred': True,
                'cols_out': cols_out,
                'anom_out': anom_out,
                'region0': self._regions[0] if self._regions else None,
                # Keeps the (possibly temporary) uint8 input alive until the
                # kernels reading it have been synced in forward_await.
                'sdr_u8': sdr_u8,
            }

        sdr_np = sdr.detach().cpu().contiguous().to(torch.bool).numpy()
        out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)

        def _run_all() -> None:
            for b in range(B):
                self._step_region_np(b, sdr_np, out, learn)

        fut = self._htm_pool.submit(_run_all)
        return {'fut': fut, 'out': out, 'device': sdr.device}

    def forward_await(self, handle) -> torch.Tensor:
        """Wait for a ``forward_async`` handle and return ``(B, T, n_columns + 1)``."""
        if handle.get('cuda_deferred'):
            # Cross-stream sync: block until cudarc stream finishes HTM
            # kernels so the output tensors are safe to read on the
            # default stream.
            self._cuda_sync(handle.get('region0'))
            cols_out = handle['cols_out']
            anom_out = handle['anom_out']
            return torch.cat(
                (cols_out.to(torch.float32), anom_out.unsqueeze(-1)), dim=-1
            )
        if 'cuda_result' in handle:
            return handle['cuda_result']
        handle['fut'].result()
        return torch.from_numpy(handle['out']).to(handle['device'])


if __name__ == "__main__":
    torch.manual_seed(0)

    # Smoke test: (B=2, T=4, D=16384) random 2%-sparse SDR
    B, T, D = 2, 4, 16384
    n_columns = 2048
    target_active_in = int(D * 0.02)  # 327

    layer = HTMLayer(
        input_bits=D,
        n_columns=n_columns,
        cells_per_column=32,
        batch_size=B,
        seed=42,
        learn=True,
    )
    layer.train()

    rng = np.random.default_rng(0)
    sdr = np.zeros((B, T, D), dtype=bool)
    for b in range(B):
        for t in range(T):
            idx = rng.choice(D, size=target_active_in, replace=False)
            sdr[b, t, idx] = True
    sdr_t = torch.from_numpy(sdr)

    t0 = time.perf_counter()
    out = layer(sdr_t)
    dt_first = time.perf_counter() - t0

    assert out.shape == (B, T, n_columns + 1), f"shape {out.shape}"
    assert out.dtype == torch.float32, f"dtype {out.dtype}"

    active_cols = out[..., :n_columns]
    anomaly = out[..., n_columns]
    col_sums = active_cols.sum(dim=-1)  # (B, T)
    mean_active = col_sums.float().mean().item()
    expected = n_columns * 0.02  # ≈ 40.96
    assert 20 <= mean_active <= 60, (
        f"active columns per step out of 2% band: {mean_active:.1f} (expected ~{expected:.1f})"
    )

    # t=0 has no TM prediction → anomaly = 1.0 on every batch slot.
    assert torch.allclose(anomaly[:, 0], torch.ones(B)), f"t=0 anomaly {anomaly[:, 0]}"

    # Second forward on same (reset) layer: identical shapes, deterministic re-run possible.
    t0 = time.perf_counter()
    out2 = layer(sdr_t)
    dt_second = time.perf_counter() - t0
    assert out2.shape == out.shape

    # Repeating-sequence anomaly decay check — one region, T=16 repeats of same pattern.
    rep_layer = HTMLayer(
        input_bits=D,
        n_columns=n_columns,
        batch_size=1,
        seed=7,
        learn=True,
    )
    rep_layer.train()
    base = torch.zeros(D, dtype=torch.bool)
    idx = rng.choice(D, size=target_active_in, replace=False)
    base[idx] = True
    rep = base.unsqueeze(0).unsqueeze(0).expand(1, 16, D).clone()
    rep_out = rep_layer(rep)
    rep_anom = rep_out[0, :, n_columns]
    assert rep_anom[0].item() > 0.5, f"anomaly at t=0 should be high, got {rep_anom[0]:.3f}"
    assert rep_anom[-1].item() < rep_anom[0].item(), (
        f"anomaly should decay on repeats: first={rep_anom[0]:.3f} last={rep_anom[-1]:.3f}"
    )

    print("[OK] shape:", tuple(out.shape))
    print(f"[OK] mean active cols/step: {mean_active:.2f} (target ~{expected:.1f})")
    print("[OK] t=0 anomaly = 1.0 on all batch slots")
    print(f"[OK] repeating-sequence anomaly: first={rep_anom[0]:.3f} -> last={rep_anom[-1]:.3f}")
    print(
        f"[OK] forward wall-clock: first={dt_first*1000:.1f}ms second={dt_second*1000:.1f}ms "
        f"on (B={B}, T={T}, D={D})"
    )