"""HTM torch wrapper around the pyo3 ``htm_rust`` crate.

Exposes ``HTMLayer``, a ``torch.nn.Module`` that batches calls to the
``htm_rust`` region objects across a ``(B, T, input_bits)`` boolean SDR stream
and returns ``(B, T, n_columns + 1)``, where the last channel is the anomaly
score. HTM learning is Hebbian (not gradient-based), so ``forward`` runs under
``torch.no_grad()``. Downstream layers carry gradients back to the embedding
through their own learnable projection of the binary column output.

Per-sequence state semantics
----------------------------
Training-time forward passes are independent windows of tokens (re-sampled
every step), so carrying TM state across calls would mix unrelated contexts.
By default (``reset_each_forward=True``) the layer calls ``reset()`` on every
region at the start of each forward, so the TM learns within-window temporal
patterns only. Callers that want cross-window continuity (e.g. eval over a
long document) can construct the layer with ``reset_each_forward=False`` and
call ``reset()`` themselves at document boundaries.

Device handling
---------------
Two backends are supported:

* CPU (``htm_rust.HTMRegion``): consumes numpy arrays. If ``sdr`` lives on
  CUDA we pay an ``sdr.cpu().numpy()`` round-trip per forward, and the result
  is moved back to ``sdr.device``.
* GPU (``htm_rust.HTMRegionGpu``, built with ``--features gpu``): when the
  build also exposes the ``__cuda_array_interface__`` entry points and ``sdr``
  is already a CUDA tensor, the SDR is consumed zero-copy and outputs are
  written directly into torch-owned CUDA tensors. Otherwise it falls back to
  the numpy round-trip.

``forward_async`` / ``forward_await`` split launch from consumption so the
caller can enqueue other GPU work while HTM runs.
"""

from __future__ import annotations

import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
import torch.nn as nn

import htm_rust

# step_many releases the GIL for the whole pass, so multiple threads can
# truly run regions in parallel — wall-clock scales with B up to CPU cores.
# Older htm_rust builds lack it; the wrapper then falls back to per-step calls.
_HTM_HAS_STEP_MANY = hasattr(htm_rust.HTMRegion, "step_many")

# GPU backend: built with `maturin develop --features gpu`. One CUDA region
# per batch slot, persistent device state for SP synapses. Transparent
# fallback to CPU when not available.
_HTM_HAS_GPU = hasattr(htm_rust, "HTMRegionGpu")

# Zero-copy CUDA path: consumes torch CUDA tensors directly via the
# __cuda_array_interface__ protocol, skipping the sdr.cpu()/numpy round-trip
# and the D2H copy of outputs. Large win when the input SDR already lives on
# GPU (the train.py hot path — the retina is a device buffer).
_HTM_HAS_CAI = _HTM_HAS_GPU and hasattr(htm_rust.HTMRegionGpu, "step_many_cuda")

# Fused megakernel path: collapses all T timesteps + SP + TM into a single
# CUDA launch per forward. Replaces global top-K with per-column threshold
# inhibition (see htm_rust/docs/GPU_HTM.md §Fused Kernel).
# Controlled by HYDRA_HTM_FUSED / HYDRA_HTM_BATCHED_FUSED (default on when
# available; set to 0 to opt out).
import os as _os_fused

_HTM_HAS_FUSED = _HTM_HAS_GPU and hasattr(htm_rust.HTMRegionGpu, "step_many_fused_cuda")
_HTM_USE_FUSED = _HTM_HAS_FUSED and bool(int(_os_fused.environ.get("HYDRA_HTM_FUSED", "1")))
_HTM_USE_BATCHED_FUSED = _HTM_USE_FUSED and bool(
    int(_os_fused.environ.get("HYDRA_HTM_BATCHED_FUSED", "1"))
)


class HTMLayer(nn.Module):
    """Batched torch wrapper around ``htm_rust.HTMRegion``.

    One independent region per batch slot so temporal memory learns
    sequence-local patterns without cross-batch bleed. Regions grow lazily if
    a larger batch shows up.

    Output is ``(B, T, n_columns + 1)``: first ``n_columns`` channels are the
    binary active-column mask (float32 0/1) and the last channel is the
    per-timestep anomaly score in [0, 1].
    """

    def __init__(
        self,
        input_bits: int = 16384,
        n_columns: int = 2048,
        cells_per_column: int = 32,
        batch_size: int = 1,
        seed: int = 42,
        learn: bool = True,
        reset_each_forward: bool = True,
        use_gpu: bool | None = None,
    ) -> None:
        """Create ``batch_size`` regions up front.

        Args:
            input_bits: Width ``D`` of each SDR timestep; ``forward`` rejects
                inputs whose last dim differs.
            n_columns: Number of columns per region (output channels minus one).
            cells_per_column: Passed straight to the region constructor.
            batch_size: Number of regions to pre-allocate; more are created
                lazily when a larger batch arrives.
            seed: Region ``i`` is seeded with ``seed + i``.
            learn: Master switch for Hebbian learning (also requires
                ``self.training``).
            reset_each_forward: Call ``reset()`` at the start of every forward.
            use_gpu: ``None`` auto-detects the GPU backend; ``True`` requires it
                (unless ``HYDRA_FORCE_HTM_CPU=1``); ``False`` forces CPU.

        Raises:
            RuntimeError: ``use_gpu=True`` but ``htm_rust`` lacks the GPU build
                and CPU is not forced via ``HYDRA_FORCE_HTM_CPU``.
        """
        super().__init__()
        import os as _os

        self.input_bits = input_bits
        self.n_columns = n_columns
        self.cells_per_column = cells_per_column
        self.learn = learn
        self.reset_each_forward = reset_each_forward
        self._seed_base = seed

        # Learn gating: HTM learn kernels (tm_punish, tm_learn_reinforce,
        # tm_grow) are 56% of total HTM CUDA time. Gating them to run every N
        # forwards instead of every forward cuts HTM cost ~2x; Hebbian
        # learning still converges since the EMA accumulates over many calls.
        # Env: HYDRA_HTM_LEARN_EVERY=N (default 1 = every forward; 0 or a
        # negative value disables learning). Clamped at 0, not 1, so that 0
        # really means "disabled".
        self._learn_every = max(0, int(_os.environ.get("HYDRA_HTM_LEARN_EVERY", "1")))
        self._forward_counter = 0

        # GPU backend gate. Default: auto-detect — use GPU when the pyo3
        # module was built with --features gpu AND CUDA is actually usable.
        # HYDRA_FORCE_HTM_CPU=1 is an operational safety valve for paid remote
        # canaries when the compiled CUDA HTM backend is present but unstable
        # on a specific hardware/runtime combination; it takes precedence over
        # every other setting (including an explicit use_gpu=True).
        force_cpu = _os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
        if force_cpu:
            use_gpu = False
        elif use_gpu is None:
            use_gpu = _HTM_HAS_GPU and torch.cuda.is_available()
        elif use_gpu and not _HTM_HAS_GPU:
            raise RuntimeError(
                "HTMLayer(use_gpu=True) but htm_rust was not built with "
                "--features gpu. Re-run `maturin develop --features gpu`."
            )
        self._use_gpu = bool(use_gpu)

        cls = htm_rust.HTMRegionGpu if self._use_gpu else htm_rust.HTMRegion
        self._region_cls = cls
        self._regions = [
            cls(input_bits, n_columns, cells_per_column, seed + i)
            for i in range(batch_size)
        ]
        # Non-persistent placeholder buffer (not saved in state_dict).
        self.register_buffer("_dummy", torch.zeros(1), persistent=False)
        self._htm_pool = ThreadPoolExecutor(max_workers=min(_os.cpu_count() or 4, 16))

    def _ensure_regions(self, B: int) -> None:
        """Grow ``self._regions`` to at least ``B`` entries (never shrinks)."""
        while len(self._regions) < B:
            idx = len(self._regions)
            self._regions.append(
                self._region_cls(
                    self.input_bits,
                    self.n_columns,
                    self.cells_per_column,
                    self._seed_base + idx,
                )
            )

    def reset(self) -> None:
        """Clear TM predictive state on every region (keeps SP synapses)."""
        for r in self._regions:
            r.reset()

    # ------------------------------------------------------------------ #
    # Shared helpers for forward / forward_async
    # ------------------------------------------------------------------ #

    def _prepare(self, sdr: torch.Tensor) -> tuple[int, int]:
        """Validate ``sdr`` (B, T, D), grow regions, optionally reset; return (B, T)."""
        B, T, D = sdr.shape
        if D != self.input_bits:
            raise ValueError(f"expected input_bits={self.input_bits}, got {D}")
        self._ensure_regions(B)
        if self.reset_each_forward:
            self.reset()
        return B, T

    def _next_learn_flag(self) -> bool:
        """Advance the forward counter and decide whether this call learns.

        Learning requires ``self.learn`` and ``self.training``, and then runs
        only on every ``_learn_every``-th call (never when it is 0).
        """
        self._forward_counter += 1
        if not (self.learn and self.training) or self._learn_every <= 0:
            return False
        return self._forward_counter % self._learn_every == 0

    def _launch_cuda(
        self, sdr: torch.Tensor, B: int, T: int, learn: bool, *, batched: bool
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Launch the zero-copy CUDA kernels; return ``(sdr_u8, cols_out, anom_out)``.

        Does NOT synchronize: the caller must sync before reading the outputs,
        and must keep ``sdr_u8`` alive until then (it may be a fresh uint8 copy
        that the kernels are still reading).
        NOTE(review): the input conversion is enqueued on torch's current
        stream and the kernels run on cudarc's stream; confirm the Rust side
        orders the two (otherwise an input-side sync is also needed).
        """
        sdr_u8 = sdr.contiguous() if sdr.dtype == torch.uint8 else sdr.contiguous().to(torch.uint8)
        cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
        anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)

        if batched and _HTM_USE_BATCHED_FUSED and hasattr(htm_rust, "step_batch_fused_cuda"):
            # ONE cooperative kernel launch for all B regions (grid.y = B),
            # which sidesteps the device-level serialization of cooperative
            # kernels. self._regions is sliced to B because _ensure_regions
            # may have allocated more regions than this batch needs (e.g.
            # factual eval uses smaller batches than training).
            try:
                htm_rust.step_batch_fused_cuda(
                    self._regions[:B],
                    [sdr_u8[b].__cuda_array_interface__ for b in range(B)],
                    [cols_out[b].__cuda_array_interface__ for b in range(B)],
                    [anom_out[b].__cuda_array_interface__ for b in range(B)],
                    learn,
                )
            except RuntimeError as err:
                # Batch too large for the cooperative grid: fall through to
                # sequential per-region fused launches (each B=1).
                if "COOPERATIVE_LAUNCH_TOO_LARGE" not in str(err):
                    raise
            else:
                return sdr_u8, cols_out, anom_out

        # Fused (1 launch per region) or legacy (12*T launches per region).
        step_name = "step_many_fused_cuda" if _HTM_USE_FUSED else "step_many_cuda"
        for b in range(B):
            getattr(self._regions[b], step_name)(
                sdr_u8[b].__cuda_array_interface__,
                cols_out[b].__cuda_array_interface__,
                anom_out[b].__cuda_array_interface__,
                learn,
            )
        return sdr_u8, cols_out, anom_out

    @staticmethod
    def _sync_region(region) -> None:
        """Block until ``region``'s CUDA work is done (device-wide fallback)."""
        if hasattr(region, "device_sync"):
            region.device_sync()
        else:
            torch.cuda.synchronize()

    @staticmethod
    def _assemble_cuda(cols_out: torch.Tensor, anom_out: torch.Tensor) -> torch.Tensor:
        """Concatenate column mask and anomaly into ``(B, T, n_cols + 1)`` float32."""
        return torch.cat((cols_out.to(torch.float32), anom_out.unsqueeze(-1)), dim=-1)

    def _process_region_np(self, b: int, sdr_np: np.ndarray, out: np.ndarray, learn: bool) -> None:
        """Run region ``b`` over ``sdr_np[b]`` (T, D) and write into ``out[b]``."""
        region = self._regions[b]
        n_cols = self.n_columns
        if self._use_gpu:
            cols, anom = region.step_many_gpu(sdr_np[b], learn)
        elif _HTM_HAS_STEP_MANY:
            # Single Rust call: T steps with the GIL released for the whole pass.
            cols, anom = region.step_many(sdr_np[b], learn)  # cols (T, n_cols), anom (T,)
        else:
            for t in range(sdr_np.shape[1]):
                active_cols, _ac, _pc, anomaly = region.step(sdr_np[b, t], learn)
                out[b, t, :n_cols] = active_cols
                out[b, t, n_cols] = float(anomaly)
            return
        out[b, :, :n_cols] = cols
        out[b, :, n_cols] = anom

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    @torch.no_grad()
    def forward(self, sdr: torch.Tensor) -> torch.Tensor:
        """Run HTM over ``sdr`` (B, T, input_bits) and return (B, T, n_columns + 1).

        Raises:
            ValueError: last dim of ``sdr`` differs from ``input_bits``.
        """
        B, T = self._prepare(sdr)
        # Learn-gate: run learn kernels only every N forwards (skips 56% of
        # HTM CUDA time on skip-forwards; Hebbian EMA still converges).
        learn = self._next_learn_flag()

        # Zero-copy CUDA hot path. SDR already lives on GPU (retina buffer),
        # so we skip the sdr.cpu()/numpy round-trip AND the output D2H. The
        # Rust kernels write directly into torch-owned CUDA tensors via CAI.
        if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
            sdr_u8, cols_out, anom_out = self._launch_cuda(sdr, B, T, learn, batched=False)
            # The kernels run asynchronously on cudarc's stream; wait for them
            # before torch reads the outputs on its own stream.
            self._sync_region(self._regions[0])
            del sdr_u8  # safe to release only after the sync above
            return self._assemble_cuda(cols_out, anom_out)

        # Fallback: CPU / numpy path. Kept for the CPU-input case and for
        # builds without CAI support.
        sdr_np = sdr.detach().cpu().contiguous().to(torch.bool).numpy()
        out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)

        def _run(b: int) -> None:
            self._process_region_np(b, sdr_np, out, learn)

        if B == 1:
            _run(0)
        elif self._use_gpu:
            # GPU regions share the CUDA context; serialise to avoid contention
            # for stream 0. Future work: one CUDA stream per region.
            for b in range(B):
                _run(b)
        else:
            # Each thread runs in pure Rust under py.allow_threads, so they
            # parallelise to wall-clock min(B, CPU_cores).
            list(self._htm_pool.map(_run, range(B)))
        return torch.from_numpy(out).to(sdr.device)

    def forward_async(self, sdr: torch.Tensor):
        """Submit HTM work and return an opaque handle for ``forward_await``.

        On the CAI zero-copy path (GPU tensor in, GPU region), the Rust CUDA
        kernels are launched on cudarc's internal stream and control returns
        **immediately** — no device synchronization. The caller's next GPU ops
        (embedding lookup, Mamba forward, etc.) are enqueued on PyTorch's
        default stream and can execute while HTM kernels run on the cudarc
        stream. ``forward_await`` performs the cross-stream sync (via
        ``device_sync``) and assembles the output tensor only when the result
        is actually consumed. The handle keeps the uint8 input alive until
        then, because the in-flight kernels may still be reading it.

        For cooperative kernels (``step_many_fused_cuda``) the GPU can only
        run one cooperative launch at a time, so kernel-level overlap with
        default-stream work is limited. The win is **CPU-side launch
        overlap**: instead of the CPU blocking ~10 ms waiting for HTM before
        it can even enqueue wte/mamba, it enqueues everything up front and the
        GPU executes back-to-back without CPU stalls.

        On the CPU/numpy path, work is dispatched to the thread pool: one task
        per region for CPU regions (parallel), a single serial task for GPU
        regions. Learning follows the same ``HYDRA_HTM_LEARN_EVERY`` gate as
        ``forward``.

        Raises:
            ValueError: last dim of ``sdr`` differs from ``input_bits``.
        """
        B, T = self._prepare(sdr)
        learn = self._next_learn_flag()

        if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
            sdr_u8, cols_out, anom_out = self._launch_cuda(sdr, B, T, learn, batched=True)
            # NO sync here — kernels are in flight on cudarc's stream.
            # forward_await() syncs before the output is consumed.
            return {
                'cuda_deferred': True,
                'cols_out': cols_out,
                'anom_out': anom_out,
                'region0': self._regions[0],
                'sdr_u8': sdr_u8,
            }

        sdr_np = sdr.detach().cpu().contiguous().to(torch.bool).numpy()
        out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)
        if self._use_gpu:
            # GPU regions share the CUDA context: keep them serial in one task.
            def _run_all() -> None:
                for b in range(B):
                    self._process_region_np(b, sdr_np, out, learn)

            futs = [self._htm_pool.submit(_run_all)]
        else:
            futs = [
                self._htm_pool.submit(self._process_region_np, b, sdr_np, out, learn)
                for b in range(B)
            ]
        return {'futs': futs, 'out': out, 'device': sdr.device}

    def forward_await(self, handle) -> torch.Tensor:
        """Wait for a ``forward_async`` handle and return (B, T, n_columns + 1)."""
        if handle.get('cuda_deferred'):
            # Cross-stream sync: block until the cudarc stream finishes HTM
            # kernels so the output tensors are safe to read on torch's stream.
            self._sync_region(handle['region0'])
            return self._assemble_cuda(handle['cols_out'], handle['anom_out'])
        if 'cuda_result' in handle:
            return handle['cuda_result']
        # Accept both the per-region future list and a single-future handle.
        futs = handle['futs'] if 'futs' in handle else [handle['fut']]
        for fut in futs:
            fut.result()  # re-raises any worker exception
        return torch.from_numpy(handle['out']).to(handle['device'])


if __name__ == "__main__":
    torch.manual_seed(0)
    # Smoke test: (B=2, T=4, D=16384) random 2%-sparse SDR
    B, T, D = 2, 4, 16384
    n_columns = 2048
    target_active_in = int(D * 0.02)  # 327
    layer = HTMLayer(
        input_bits=D,
        n_columns=n_columns,
        cells_per_column=32,
        batch_size=B,
        seed=42,
        learn=True,
    )
    layer.train()
    rng = np.random.default_rng(0)
    sdr = np.zeros((B, T, D), dtype=bool)
    for b in range(B):
        for t in range(T):
            idx = rng.choice(D, size=target_active_in, replace=False)
            sdr[b, t, idx] = True
    sdr_t = torch.from_numpy(sdr)

    t0 = time.perf_counter()
    out = layer(sdr_t)
    dt_first = time.perf_counter() - t0
    assert out.shape == (B, T, n_columns + 1), f"shape {out.shape}"
    assert out.dtype == torch.float32, f"dtype {out.dtype}"

    active_cols = out[..., :n_columns]
    anomaly = out[..., n_columns]
    col_sums = active_cols.sum(dim=-1)  # (B, T)
    mean_active = col_sums.float().mean().item()
    expected = n_columns * 0.02  # ≈ 40.96
    assert 20 <= mean_active <= 60, (
        f"active columns per step out of 2% band: {mean_active:.1f} (expected ~{expected:.1f})"
    )
    # t=0 has no TM prediction → anomaly = 1.0 on every batch slot.
    assert torch.allclose(anomaly[:, 0], torch.ones(B)), f"t=0 anomaly {anomaly[:, 0]}"

    # Second forward on same (reset) layer: identical shapes, deterministic re-run possible.
    t0 = time.perf_counter()
    out2 = layer(sdr_t)
    dt_second = time.perf_counter() - t0
    assert out2.shape == out.shape

    # Repeating-sequence anomaly decay check — one region, T=16 repeats of same pattern.
    rep_layer = HTMLayer(
        input_bits=D,
        n_columns=n_columns,
        batch_size=1,
        seed=7,
        learn=True,
    )
    rep_layer.train()
    base = torch.zeros(D, dtype=torch.bool)
    idx = rng.choice(D, size=target_active_in, replace=False)
    base[idx] = True
    rep = base.unsqueeze(0).unsqueeze(0).expand(1, 16, D).clone()
    rep_out = rep_layer(rep)
    rep_anom = rep_out[0, :, n_columns]
    assert rep_anom[0].item() > 0.5, f"anomaly at t=0 should be high, got {rep_anom[0]:.3f}"
    assert rep_anom[-1].item() < rep_anom[0].item(), (
        f"anomaly should decay on repeats: first={rep_anom[0]:.3f} last={rep_anom[-1]:.3f}"
    )

    print("[OK] shape:", tuple(out.shape))
    print(f"[OK] mean active cols/step: {mean_active:.2f} (target ~{expected:.1f})")
    print("[OK] t=0 anomaly = 1.0 on all batch slots")
    print(f"[OK] repeating-sequence anomaly: first={rep_anom[0]:.3f} -> last={rep_anom[-1]:.3f}")
    print(f"[OK] forward wall-clock: first={dt_first*1000:.1f}ms second={dt_second*1000:.1f}ms "
          f"on (B={B}, T={T}, D={D})")