Spaces:
Runtime error
Runtime error
| """ | |
| HTM torch wrapper around the pyo3 ``htm_rust`` crate. | |
| Exposes ``HTMLayer``, a ``torch.nn.Module`` that batches calls to | |
| ``htm_rust.HTMRegion.step`` across a ``(B, T, input_bits)`` boolean SDR stream | |
| and returns ``(B, T, n_columns + 1)`` where the last channel is the anomaly | |
| score. HTM learning is Hebbian (not gradient), so the wrapper runs under | |
| ``torch.no_grad()``. Downstream layers carry gradients back to the embedding | |
| via their own learnable projection from the binary column output. | |
| Per-sequence state semantics | |
| --------------------------- | |
| Training-time forward passes are independent windows of tokens (re-sampled | |
| every step), so carrying TM state across calls would mix unrelated contexts. | |
| By default (``reset_each_forward=True``) this layer calls ``reset()`` on every | |
| region at the top of ``forward``, so the TM learns within-window temporal | |
| patterns only. Users that want cross-window continuity (e.g. eval over a long | |
| document) can pass ``reset_each_forward=False`` and call ``reset()`` themselves | |
| at document boundaries; a dedicated ``step_stream`` API is not implemented | |
| here because the single-forward contract is sufficient for the autoresearch loop. | |
| Device handling | |
| --------------- | |
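| ``htm_rust`` runs on CPU unless it was built with the ``gpu`` feature. On the | |
| CPU/numpy path, if ``sdr`` lives on CUDA we pay a ``sdr.cpu().numpy()`` | |
| round-trip per forward and the return tensor is moved back to ``sdr.device``. | |
| For expected use (batch<=32, T<=2048, bits=16384) this copy is small compared | |
| to the SP/TM compute. With a GPU build and a CUDA ``sdr``, the zero-copy path | |
| hands the tensors to the Rust kernels directly and skips that round-trip. | |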
| """ | |
| from __future__ import annotations | |
| import time | |
| from concurrent.futures import ThreadPoolExecutor | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import htm_rust | |
# ---------------------------------------------------------------------------
# Capability probes for the installed ``htm_rust`` build, plus env-var opt-outs.
# Everything here is evaluated once, at import time.
# ---------------------------------------------------------------------------
import os as _os_fused


def _gpu_region_has(method: str) -> bool:
    """Return True when the GPU region class exists and exposes ``method``."""
    return _HTM_HAS_GPU and hasattr(htm_rust.HTMRegionGpu, method)


def _env_enabled(var: str) -> bool:
    """Read integer env var ``var`` (default ``"1"``); any non-zero value is True."""
    return bool(int(_os_fused.environ.get(var, "1")))


# CPU bulk entry point. ``step_many`` releases the GIL for the whole pass, so
# several threads can run regions truly in parallel — wall-clock scales with B
# up to the number of CPU cores.
_HTM_HAS_STEP_MANY = hasattr(htm_rust.HTMRegion, "step_many")

# GPU backend (built with `maturin develop --features gpu`): one CUDA region per
# batch slot with persistent device state for SP synapses. Callers fall back to
# the CPU region transparently when it is absent.
_HTM_HAS_GPU = hasattr(htm_rust, "HTMRegionGpu")

# Zero-copy CUDA path: torch CUDA tensors are handed over through the
# __cuda_array_interface__ protocol, which avoids the sdr.cpu()/numpy round-trip
# and the D2H copy of the outputs. A large win when the input SDR already lives
# on the GPU (the train.py hot path — retina is a device buffer).
_HTM_HAS_CAI = _gpu_region_has("step_many_cuda")

# Fused megakernel path: all T timesteps + SP + TM in a single CUDA launch per
# forward. Uses per-column threshold inhibition in place of global top-K
# (see htm_rust/docs/GPU_HTM.md §Fused Kernel).
_HTM_HAS_FUSED = _gpu_region_has("step_many_fused_cuda")

# Env opt-outs; both default to on whenever the capability is present. The env
# var is only parsed when the left-hand capability flag is already true.
_HTM_USE_FUSED = _HTM_HAS_FUSED and _env_enabled("HYDRA_HTM_FUSED")
_HTM_USE_BATCHED_FUSED = _HTM_USE_FUSED and _env_enabled("HYDRA_HTM_BATCHED_FUSED")
class HTMLayer(nn.Module):
    """Batched torch wrapper around ``htm_rust.HTMRegion``.

    One independent region per batch slot so temporal memory learns
    sequence-local patterns without cross-batch bleed. Regions grow
    lazily if a larger batch shows up.

    Output is ``(B, T, n_columns + 1)``: first ``n_columns`` channels are
    the binary active-column mask (float32 0/1) and the last channel is
    the per-timestep anomaly score in [0, 1].
    """

    def __init__(
        self,
        input_bits: int = 16384,
        n_columns: int = 2048,
        cells_per_column: int = 32,
        batch_size: int = 1,
        seed: int = 42,
        learn: bool = True,
        reset_each_forward: bool = True,
        use_gpu: bool | None = None,
    ) -> None:
        """Create the layer and its initial ``batch_size`` regions.

        Args:
            input_bits: Width ``D`` of the input SDR; ``forward`` rejects other widths.
            n_columns: Number of columns per region (output has ``n_columns + 1`` channels).
            cells_per_column: Passed through to the region constructor.
            batch_size: Number of regions created up front; more are added lazily.
            seed: Base seed; region ``i`` is constructed with ``seed + i``.
            learn: Master switch for learning (also requires ``self.training``).
            reset_each_forward: Call ``reset()`` at the top of every forward.
            use_gpu: ``None`` auto-detects; ``True`` requires a GPU build of
                ``htm_rust``; ``False`` forces the CPU region.

        Raises:
            RuntimeError: ``use_gpu=True`` but ``htm_rust`` has no ``HTMRegionGpu``.
        """
        import os as _os  # local import: the module top only binds os as ``_os_fused``

        super().__init__()
        self.input_bits = input_bits
        self.n_columns = n_columns
        self.cells_per_column = cells_per_column
        self.learn = learn
        self.reset_each_forward = reset_each_forward
        self._seed_base = seed

        # Learn gating: HTM learn kernels (tm_punish, tm_learn_reinforce, tm_grow)
        # are 56% of total HTM CUDA time. Running them every N forwards instead
        # of every forward cuts HTM cost ~2x; Hebbian learning still converges
        # since the EMA accumulates over many calls.
        # Env: HYDRA_HTM_LEARN_EVERY=N. Default 1 = learn on every forward.
        # Values below 1 are clamped to 1, i.e. they switch the *gating* off —
        # they do not switch learning off (use ``learn=False`` for that).
        self._learn_every = max(1, int(_os.environ.get("HYDRA_HTM_LEARN_EVERY", "1")))
        self._forward_counter = 0

        # GPU backend gate. Default: auto-detect — use GPU when the pyo3
        # module was built with --features gpu AND CUDA is actually usable.
        # HYDRA_FORCE_HTM_CPU=1 is an operational safety valve for paid remote
        # canaries when the compiled CUDA HTM backend is present but unstable on
        # a specific hardware/runtime combination.
        force_cpu = _os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
        if use_gpu is None:
            use_gpu = (not force_cpu) and _HTM_HAS_GPU and torch.cuda.is_available()
        elif use_gpu and not _HTM_HAS_GPU:
            raise RuntimeError(
                "HTMLayer(use_gpu=True) but htm_rust was not built with "
                "--features gpu. Re-run `maturin develop --features gpu`."
            )
        elif use_gpu and force_cpu:
            use_gpu = False
        self._use_gpu = bool(use_gpu)

        self._region_cls = htm_rust.HTMRegionGpu if self._use_gpu else htm_rust.HTMRegion
        # Same construction (and the same ``seed + i`` seeding) as lazy growth.
        self._regions = []
        self._ensure_regions(batch_size)

        self.register_buffer("_dummy", torch.zeros(1), persistent=False)
        self._htm_pool = ThreadPoolExecutor(max_workers=min(_os.cpu_count() or 4, 16))

    # ------------------------------------------------------------------
    # Region management
    # ------------------------------------------------------------------
    def _ensure_regions(self, B: int) -> None:
        """Grow ``self._regions`` until it holds at least ``B`` regions."""
        while len(self._regions) < B:
            idx = len(self._regions)
            self._regions.append(
                self._region_cls(
                    self.input_bits,
                    self.n_columns,
                    self.cells_per_column,
                    self._seed_base + idx,
                )
            )

    def reset(self) -> None:
        """Clear TM predictive state on every region (keeps SP synapses)."""
        for r in self._regions:
            r.reset()

    # ------------------------------------------------------------------
    # Shared helpers for forward / forward_async
    # ------------------------------------------------------------------
    def _next_learn_flag(self) -> bool:
        """Advance the forward counter and return whether this pass learns.

        Learning needs ``self.learn``, training mode, and the counter landing
        on a multiple of ``self._learn_every``.
        """
        self._forward_counter += 1
        return bool(
            self.learn
            and self.training
            and (self._forward_counter % self._learn_every == 0)
        )

    def _prepare(self, sdr: torch.Tensor) -> tuple[int, int, bool]:
        """Validate ``sdr``, size/reset the regions, and return ``(B, T, learn)``."""
        B, T, D = sdr.shape
        if D != self.input_bits:
            raise ValueError(f"expected input_bits={self.input_bits}, got {D}")
        self._ensure_regions(B)
        if self.reset_each_forward:
            self.reset()
        return B, T, self._next_learn_flag()

    def _uses_cuda_fast_path(self, sdr: torch.Tensor) -> bool:
        """True when the zero-copy ``__cuda_array_interface__`` path applies."""
        return bool(_HTM_HAS_CAI and self._use_gpu and sdr.is_cuda)

    def _launch_cuda(self, sdr: torch.Tensor, B: int, T: int, learn: bool, allow_batched: bool):
        """Launch the HTM CUDA kernels for all ``B`` regions without syncing.

        Returns ``(sdr_u8, cols_out, anom_out)``. The caller must keep all
        three tensors alive until the HTM stream has been synchronised:
        ``sdr_u8`` may be a fresh conversion that nothing else references.
        """
        # ``.to`` is a no-op when the dtype is already uint8.
        # NOTE(review): the conversion is enqueued on torch's current stream;
        # presumably the Rust side orders its launch after it — confirm.
        sdr_u8 = sdr.contiguous().to(torch.uint8)
        cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
        anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)

        def _per_region(method_name: str) -> None:
            # One launch sequence per batch slot on that slot's own region.
            for b in range(B):
                getattr(self._regions[b], method_name)(
                    sdr_u8[b].__cuda_array_interface__,
                    cols_out[b].__cuda_array_interface__,
                    anom_out[b].__cuda_array_interface__,
                    learn,
                )

        if allow_batched and _HTM_USE_BATCHED_FUSED and hasattr(htm_rust, "step_batch_fused_cuda"):
            # ONE cooperative kernel launch for all B regions. Only one
            # cooperative kernel runs on the device at a time, so a single
            # launch with grid.y = B processes the regions concurrently.
            # Slice self._regions to B: _ensure_regions may have allocated
            # more regions than the current batch needs (e.g. smaller eval
            # batches after training).
            try:
                htm_rust.step_batch_fused_cuda(
                    self._regions[:B],
                    [sdr_u8[b].__cuda_array_interface__ for b in range(B)],
                    [cols_out[b].__cuda_array_interface__ for b in range(B)],
                    [anom_out[b].__cuda_array_interface__ for b in range(B)],
                    learn,
                )
            except RuntimeError as err:
                if "COOPERATIVE_LAUNCH_TOO_LARGE" not in str(err):
                    raise
                # Batch too large for one cooperative grid: fall back to
                # sequential per-region fused launches.
                _per_region("step_many_fused_cuda")
        elif _HTM_USE_FUSED:
            # Fused path: one launch per region.
            _per_region("step_many_fused_cuda")
        else:
            # Legacy path (12*T launches per region).
            _per_region("step_many_cuda")
        return sdr_u8, cols_out, anom_out

    @staticmethod
    def _sync_htm_stream(region) -> None:
        """Block until in-flight HTM kernels finish.

        Uses ``region.device_sync()`` when the region exposes it, otherwise a
        full ``torch.cuda.synchronize()``. Syncing one region is treated as
        covering all of them, as the original ``forward_await`` did —
        NOTE(review): confirm the regions share one stream.
        """
        if hasattr(region, "device_sync"):
            region.device_sync()
        else:
            torch.cuda.synchronize()

    @staticmethod
    def _assemble(cols_out: torch.Tensor, anom_out: torch.Tensor) -> torch.Tensor:
        """Concatenate columns and anomaly into ``(B, T, n_cols + 1)`` float32."""
        return torch.cat((cols_out.to(torch.float32), anom_out.unsqueeze(-1)), dim=-1)

    def _numpy_buffers(self, sdr: torch.Tensor, B: int, T: int):
        """Return ``(sdr_np, out)``: boolean host copy of ``sdr`` and a zeroed output."""
        sdr_np = sdr.detach().cpu().contiguous().to(torch.bool).numpy()
        out = np.zeros((B, T, self.n_columns + 1), dtype=np.float32)
        return sdr_np, out

    def _process_region_numpy(self, b: int, sdr_np, out, learn: bool) -> None:
        """Run region ``b`` over its whole window on the numpy path, filling ``out[b]``."""
        region = self._regions[b]
        n_cols = self.n_columns
        if self._use_gpu:
            cols, anom = region.step_many_gpu(sdr_np[b], learn)
        elif _HTM_HAS_STEP_MANY:
            # Single Rust call: T steps with the GIL released for the whole
            # pass. cols is (T, n_cols), anom is (T,).
            cols, anom = region.step_many(sdr_np[b], learn)
        else:
            # Oldest API: one Python-level call per timestep.
            for t in range(sdr_np.shape[1]):
                active_cols, _ac, _pc, anomaly = region.step(sdr_np[b, t], learn)
                out[b, t, :n_cols] = active_cols
                out[b, t, n_cols] = float(anomaly)
            return
        out[b, :, :n_cols] = cols
        out[b, :, n_cols] = anom

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------
    def forward(self, sdr: torch.Tensor) -> torch.Tensor:
        """Run every batch slot through its region and return the result.

        Args:
            sdr: ``(B, T, input_bits)`` SDR tensor; converted to uint8 on the
                zero-copy CUDA path and to bool on the numpy path.

        Returns:
            ``(B, T, n_columns + 1)`` float32 tensor on ``sdr.device``.

        Raises:
            ValueError: the last dimension of ``sdr`` is not ``input_bits``.
        """
        B, T, learn = self._prepare(sdr)

        # Zero-copy CUDA hot path. SDR already lives on GPU (retina buffer),
        # so we skip the sdr.cpu()/numpy round-trip AND the output D2H. The
        # Rust kernel writes directly into torch-owned CUDA tensors via CAI.
        if self._uses_cuda_fast_path(sdr):
            # Per-region dispatch only, as before; the batched launch stays
            # specific to forward_async.
            sdr_u8, cols_out, anom_out = self._launch_cuda(
                sdr, B, T, learn, allow_batched=False
            )
            # forward_async documents these launches as asynchronous on a
            # non-torch stream, so the outputs must be synchronised before
            # torch reads them. ``sdr_u8`` stays referenced until here.
            self._sync_htm_stream(self._regions[0] if self._regions else None)
            del sdr_u8
            return self._assemble(cols_out, anom_out)

        # Fallback: CPU / numpy path. Kept for CPU inputs and for builds
        # without CAI support.
        sdr_np, out = self._numpy_buffers(sdr, B, T)
        if B == 1 or self._use_gpu:
            # GPU regions share the CUDA context; serialise to avoid
            # contention for stream 0. Future work: one stream per region.
            for b in range(B):
                self._process_region_numpy(b, sdr_np, out, learn)
        else:
            # Each thread runs in pure Rust under py.allow_threads, so they
            # parallelise to wall-clock min(B, CPU_cores).
            list(
                self._htm_pool.map(
                    lambda b: self._process_region_numpy(b, sdr_np, out, learn),
                    range(B),
                )
            )
        return torch.from_numpy(out).to(sdr.device)

    def forward_async(self, sdr: torch.Tensor) -> dict:
        """Submit HTM work and return a handle awaitable via ``forward_await``.

        On the CAI zero-copy path (GPU tensor in, GPU region), the Rust
        CUDA kernels are launched on cudarc's internal stream and control
        returns **immediately** — no device synchronization. The caller's
        next GPU ops (embedding lookup, Mamba forward, etc.) are enqueued
        on PyTorch's default stream and can execute while HTM kernels run
        on the cudarc stream. ``forward_await`` performs the cross-stream
        sync (via ``device_sync``) and assembles the output tensor only
        when the result is actually consumed.

        For cooperative kernels (``step_many_fused_cuda``) the GPU can only
        run one cooperative launch at a time, so kernel-level overlap with
        default-stream work is limited. The win is **CPU-side launch
        overlap**: instead of the CPU blocking ~10 ms waiting for HTM
        before it can even enqueue wte/mamba, it enqueues everything up
        front and the GPU executes back-to-back without CPU stalls.

        On the legacy CPU/numpy path, the whole batch is submitted to the
        thread pool as a single task that walks the regions in order.

        The ``HYDRA_HTM_LEARN_EVERY`` gate applies here exactly as in
        ``forward``; both share one forward counter.

        Raises:
            ValueError: the last dimension of ``sdr`` is not ``input_bits``.
        """
        B, T, learn = self._prepare(sdr)

        if self._uses_cuda_fast_path(sdr):
            sdr_u8, cols_out, anom_out = self._launch_cuda(
                sdr, B, T, learn, allow_batched=True
            )
            # NO sync here — kernels are in-flight on cudarc's stream.
            # forward_await() will sync before the output is consumed.
            return {
                'cuda_deferred': True,
                'cols_out': cols_out,
                'anom_out': anom_out,
                'region0': self._regions[0] if self._regions else None,
                # Keeps the input buffer alive while kernels may still read
                # it; otherwise torch's allocator could hand it out again.
                'sdr_u8': sdr_u8,
            }

        sdr_np, out = self._numpy_buffers(sdr, B, T)

        def _run_all() -> None:
            for b in range(B):
                self._process_region_numpy(b, sdr_np, out, learn)

        fut = self._htm_pool.submit(_run_all)
        return {'fut': fut, 'out': out, 'device': sdr.device}

    def forward_await(self, handle) -> torch.Tensor:
        """Wait for a ``forward_async`` handle and return the output tensor."""
        if handle.get('cuda_deferred'):
            # Cross-stream sync: block until the HTM kernels finish so the
            # output tensors are safe to read on the default stream.
            self._sync_htm_stream(handle['region0'])
            # The input buffer is no longer needed once the kernels are done.
            handle.pop('sdr_u8', None)
            return self._assemble(handle['cols_out'], handle['anom_out'])
        if 'cuda_result' in handle:
            # Not produced by forward_async in this file; kept so handles
            # that already carry a finished tensor remain accepted.
            return handle['cuda_result']
        handle['fut'].result()
        return torch.from_numpy(handle['out']).to(handle['device'])
if __name__ == "__main__":
    # Smoke test, run only when this file is executed as a script.
    torch.manual_seed(0)

    # Shapes: (B=2, T=4, D=16384) random SDR at 2% density.
    B, T, D = 2, 4, 16384
    n_columns = 2048
    target_active_in = int(D * 0.02)  # 327 active input bits per step

    layer = HTMLayer(
        input_bits=D,
        n_columns=n_columns,
        cells_per_column=32,
        batch_size=B,
        seed=42,
        learn=True,
    )
    layer.train()

    # One rng.choice draw per (b, t), visited in row-major order.
    rng = np.random.default_rng(0)
    sdr = np.zeros((B, T, D), dtype=bool)
    for b, t in np.ndindex(B, T):
        sdr[b, t, rng.choice(D, size=target_active_in, replace=False)] = True
    sdr_t = torch.from_numpy(sdr)

    # --- first forward: shape, dtype, sparsity, t=0 anomaly ---------------
    started = time.perf_counter()
    out = layer(sdr_t)
    dt_first = time.perf_counter() - started

    assert out.shape == (B, T, n_columns + 1), f"shape {out.shape}"
    assert out.dtype == torch.float32, f"dtype {out.dtype}"

    anomaly = out[..., n_columns]
    # Mean number of active columns per (b, t) step.
    mean_active = out[..., :n_columns].sum(dim=-1).float().mean().item()
    expected = n_columns * 0.02  # ≈ 40.96
    assert 20 <= mean_active <= 60, (
        f"active columns per step out of 2% band: {mean_active:.1f} (expected ~{expected:.1f})"
    )
    # No TM prediction exists at t=0, so every batch slot must report 1.0.
    assert torch.allclose(anomaly[:, 0], torch.ones(B)), f"t=0 anomaly {anomaly[:, 0]}"

    # --- second forward on the same layer: only the shape is checked ------
    started = time.perf_counter()
    out2 = layer(sdr_t)
    dt_second = time.perf_counter() - started
    assert out2.shape == out.shape

    # --- repeating-sequence check: one region, the same pattern 16 times --
    rep_layer = HTMLayer(
        input_bits=D,
        n_columns=n_columns,
        batch_size=1,
        seed=7,
        learn=True,
    )
    rep_layer.train()

    base = torch.zeros(D, dtype=torch.bool)
    base[rng.choice(D, size=target_active_in, replace=False)] = True
    rep = base[None, None, :].expand(1, 16, D).clone()
    rep_anom = rep_layer(rep)[0, :, n_columns]

    assert rep_anom[0].item() > 0.5, f"anomaly at t=0 should be high, got {rep_anom[0]:.3f}"
    assert rep_anom[-1].item() < rep_anom[0].item(), (
        f"anomaly should decay on repeats: first={rep_anom[0]:.3f} last={rep_anom[-1]:.3f}"
    )

    # --- report ------------------------------------------------------------
    print("[OK] shape:", tuple(out.shape))
    print(f"[OK] mean active cols/step: {mean_active:.2f} (target ~{expected:.1f})")
    print("[OK] t=0 anomaly = 1.0 on all batch slots")
    print(f"[OK] repeating-sequence anomaly: first={rep_anom[0]:.3f} -> last={rep_anom[-1]:.3f}")
    print(
        f"[OK] forward wall-clock: first={dt_first*1000:.1f}ms second={dt_second*1000:.1f}ms "
        f"on (B={B}, T={T}, D={D})"
    )