//! GPU backend for HTM.
//!
//! Full-GPU pipeline (SP + TM). Per-step state lives entirely on device; the
//! batch API (`step_many_gpu`) uploads T steps of input once, runs T iterations
//! of the full HTM pipeline on GPU, and copies (T, n_cols) u8 + (T,) f32 back
//! to the host in one shot.
//!
//! TM parity with the CPU reference is approximate:
//! - Segment growth: winner = cell 0 of the bursting column (the CPU reference
//!   picks the least-used cell with an RNG tiebreak). This is a pragmatic
//!   simplification for GPU atomicity; learning dynamics are preserved.
//! - Permanences are stored as i16 (scaled 0..32767). Rounding differs from
//!   f32 by <= 1 ULP of the scale factor (≈ 3e-5) — well inside any meaningful
//!   HTM learning quantum.

#![cfg(feature = "gpu")]

pub mod sp_gpu;
pub mod tm_gpu;
pub mod fused;

#[cfg(test)]
mod tests;

use std::mem::ManuallyDrop;

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};

use crate::region::HTMRegionCore;
use crate::sp::SpatialPoolerConfig;
use sp_gpu::SpatialPoolerGpu;
use tm_gpu::TemporalMemoryGpu;
use fused::FusedState;

/// Extract (device_ptr, shape, typestr) from a `__cuda_array_interface__` dict.
/// Returns Err if the dict is malformed. Used by `step_many_cuda` to wrap
/// torch-owned CUDA allocations zero-copy.
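///
/// For orientation, the relevant keys of the dict that a torch CUDA tensor
/// exports look roughly like this (a sketch with illustrative values; the
/// tensor and the numbers are made up, not part of this API):
///
/// ```python
/// import torch
/// t = torch.zeros((4, 16), dtype=torch.uint8, device="cuda")
/// t.__cuda_array_interface__
/// # {'data': (140234881761280, False),   # (device pointer, readonly flag)
/// #  'shape': (4, 16),
/// #  'typestr': '|u1',                   # uint8
/// #  'strides': None,
/// #  'version': 2}
/// ```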
fn cai_parse(cai: &Bound<'_, PyDict>) -> PyResult<(u64, Vec<usize>, String)> {
    // `data` is a (ptr: int, readonly: bool) tuple.
    let data_obj = cai.get_item("data")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'data'"))?;
    let data_tup: Bound<'_, PyTuple> = data_obj.downcast_into()
        .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'data' must be a tuple"))?;
    let ptr: u64 = data_tup.get_item(0)?.extract()?;

    // `shape` is a tuple of ints.
    let shape_obj = cai.get_item("shape")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'shape'"))?;
    let shape_tup: Bound<'_, PyTuple> = shape_obj.downcast_into()
        .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'shape' must be a tuple"))?;
    let shape: Vec<usize> = (0..shape_tup.len())
        .map(|i| shape_tup.get_item(i).and_then(|v| v.extract::<usize>()))
        .collect::<PyResult<Vec<usize>>>()?;

    // `typestr` (e.g. "|u1", "<f4") encodes byte order + element dtype.
    let typestr: String = cai.get_item("typestr")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'typestr'"))?
        .extract()?;

    Ok((ptr, shape, typestr))
}

/// GPU-resident HTM region: SP + TM state living entirely on one CUDA device.
#[pyclass]
pub struct HTMRegionGpu {
    sp_gpu: SpatialPoolerGpu,
    tm_gpu: TemporalMemoryGpu,
    fused_state: FusedState,
    n_columns: usize,
    input_bits: usize,
    cells_per_column: usize,
}

#[pymethods]
impl HTMRegionGpu {
    #[new]
    fn new(
        input_bits: usize,
        n_columns: usize,
        cells_per_column: usize,
        seed: u64,
    ) -> PyResult<Self> {
        if input_bits == 0 || n_columns == 0 || cells_per_column == 0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "input_bits, n_columns, cells_per_column must all be > 0",
            ));
        }

        // CPU reference for deterministic SP init.
        let cpu_ref = HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed);
        let sp_cfg: &SpatialPoolerConfig = &cpu_ref.sp.cfg;

        let sp_gpu = SpatialPoolerGpu::from_cpu(&cpu_ref.sp).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!(
                "GPU SP init failed: {e:?}. Config: input_bits={}, n_columns={}",
                sp_cfg.input_bits, sp_cfg.n_columns,
            ))
        })?;
        let dev = sp_gpu.dev_ref().clone();

        let tm_gpu = TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_column)
            .map_err(|e| {
                pyo3::exceptions::PyRuntimeError::new_err(format!("GPU TM init failed: {e:?}"))
            })?;

        let initial_threshold = sp_gpu.initial_threshold_estimate();
        let fused_state = FusedState::new(dev, n_columns, cells_per_column, initial_threshold)
            .map_err(|e| {
                pyo3::exceptions::PyRuntimeError::new_err(format!(
                    "GPU fused state init failed: {e:?}"
                ))
            })?;

        Ok(Self {
            sp_gpu,
            tm_gpu,
            fused_state,
            n_columns,
            input_bits,
            cells_per_column,
        })
    }

    #[getter]
    fn input_bits(&self) -> usize { self.input_bits }

    #[getter]
    fn n_columns(&self) -> usize { self.n_columns }

    #[getter]
    fn cells_per_column(&self) -> usize { self.cells_per_column }

    /// Process T timesteps in one call on GPU. Per-step state (SP + TM) stays
    /// on device; only the final (T, n_cols) mask and (T,) anomaly are copied
    /// to the host at the end.
    #[pyo3(signature = (inputs, learn=true))]
    fn step_many_gpu<'py>(
        &mut self,
        py: Python<'py>,
        inputs: PyReadonlyArray2<'py, bool>,
        learn: bool,
    ) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
        let shape = inputs.shape();
        if shape.len() != 2 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "inputs must be 2-D (T, input_bits)",
            ));
        }
        let t = shape[0];
        let bits = shape[1];
        if bits != self.input_bits {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "inputs last dim {bits} != expected input_bits {}",
                self.input_bits,
            )));
        }

        let slice = inputs.as_slice()?;
        let n_cols = self.n_columns;
        let input_vec: Vec<bool> = slice.to_vec();

        let result = py.allow_threads(|| -> Result<(Vec<u8>, Vec<f32>), String> {
            // 1. Upload T*input_bits bytes (32 MB at T=2048, bits=16384).
            let sdr_u8_all: Vec<u8> = input_vec.iter().map(|&b| b as u8).collect();
            let inputs_dev = self
                .sp_gpu
                .dev_ref()
                .htod_sync_copy(&sdr_u8_all)
                .map_err(|e| format!("H2D inputs: {e:?}"))?;

            // 2. Allocate output buffers on device.
            let mut cols_dev = self.sp_gpu.dev_ref()
                .alloc_zeros::<u8>(t * n_cols)
                .map_err(|e| format!("alloc cols: {e:?}"))?;
            let mut anom_dev = self.sp_gpu.dev_ref()
                .alloc_zeros::<f32>(t)
                .map_err(|e| format!("alloc anom: {e:?}"))?;

            // 3. Run T steps of SP + TM on GPU with NO per-step host sync.
            self.sp_gpu.step_batch_with_tm(
                &inputs_dev,
                t,
                self.input_bits,
                learn,
                &mut cols_dev,
                &mut anom_dev,
                &mut self.tm_gpu,
            ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;

            // 4. ONE D2H for the whole run (T * n_cols bytes + T floats).
            let cols_host: Vec<u8> = self.sp_gpu.dev_ref()
                .dtoh_sync_copy(&cols_dev)
                .map_err(|e| format!("D2H cols: {e:?}"))?;
            let anom_host: Vec<f32> = self.sp_gpu.dev_ref()
                .dtoh_sync_copy(&anom_dev)
                .map_err(|e| format!("D2H anom: {e:?}"))?;

            Ok((cols_host, anom_host))
        });

        let (cols_u8, anom) = result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;

        let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
        let cols_arr = numpy::PyArray1::from_vec_bound(py, cols_f32)
            .reshape([t, n_cols])
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
        let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);

        Ok((cols_arr, anom_arr))
    }

    /// Zero-copy CUDA path: accept torch tensors via `__cuda_array_interface__`,
    /// write outputs directly into caller-allocated torch tensors. Skips the
    /// host round-trip that `step_many_gpu` pays on every call (sdr.cpu() +
    /// two D2H copies at the end). This is the hot path for `train.py`.
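    ///
    /// A minimal call sketch from Python (illustrative, not part of this API:
    /// `region` is assumed to be an already-constructed `HTMRegionGpu`, and the
    /// tensors must be contiguous CUDA tensors with the dtypes listed in the
    /// contract below):
    ///
    /// ```python
    /// import torch
    /// T = 2048
    /// sdr  = (torch.rand(T, region.input_bits, device="cuda") < 0.02).to(torch.uint8)
    /// cols = torch.empty((T, region.n_columns), dtype=torch.uint8, device="cuda")
    /// anom = torch.empty((T,), dtype=torch.float32, device="cuda")
    /// region.step_many_cuda(
    ///     sdr.__cuda_array_interface__,
    ///     cols.__cuda_array_interface__,
    ///     anom.__cuda_array_interface__,
    ///     learn=True,
    /// )
    /// region.device_sync()  # launches are async; sync before reading cols/anom
    /// ```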
    ///
    /// Contract:
    ///   sdr_cai.shape  == (T, input_bits), dtype u8 (0/1 mask)
    ///   cols_cai.shape == (T, n_columns),  dtype u8 (written)
    ///   anom_cai.shape == (T,),            dtype f32 (written)
    /// All three tensors must live on the SAME CUDA device as this region.
    ///
    /// The torch tensors still own their memory — this method only wraps
    /// them as borrowed CudaSlice views (via ManuallyDrop) so cudarc's Drop
    /// impl can't free pytorch's allocations.
    #[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
    fn step_many_cuda(
        &mut self,
        py: Python<'_>,
        sdr_cai: &Bound<'_, PyDict>,
        cols_cai: &Bound<'_, PyDict>,
        anom_cai: &Bound<'_, PyDict>,
        learn: bool,
    ) -> PyResult<()> {
        let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
        let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
        let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;

        // typestr sanity. numpy "|u1" is what torch.uint8 exports.
        if sdr_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
            )));
        }
        if cols_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
            )));
        }
        if anom_type != "<f4" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai typestr must be '<f4' (float32), got {anom_type}",
            )));
        }

        // Shape sanity against this region's config.
        if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai shape must be (T, {}), got {sdr_shape:?}",
                self.input_bits,
            )));
        }
        let t = sdr_shape[0];
        if cols_shape != [t, self.n_columns] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai shape must be ({t}, {}), got {cols_shape:?}",
                self.n_columns,
            )));
        }
        if anom_shape != [t] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai shape must be ({t},), got {anom_shape:?}",
            )));
        }

        let input_bits = self.input_bits;
        let n_cols = self.n_columns;
        let dev = self.sp_gpu.dev_ref().clone();

        let result = py.allow_threads(|| -> Result<(), String> {
            // SAFETY:
            // - ptrs came from torch CUDA tensors validated non-null by the
            //   __cuda_array_interface__ contract.
            // - lens computed from validated shapes.
            // - We wrap the returned CudaSlice in ManuallyDrop so cudarc's
            //   Drop (which calls cuMemFree) never runs against torch memory.
            //   The underlying allocation is owned and freed by torch.
            // - The slices are used only for the duration of this call;
            //   torch guarantees the backing tensors are live across it
            //   (Python holds refs on the wrapping tensors).
            let inputs_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
            });
            let mut cols_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
            });
            let mut anom_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<f32>(anom_ptr, t)
            });

            self.sp_gpu.step_batch_with_tm(
                &inputs_dev,
                t,
                input_bits,
                learn,
                &mut cols_dev,
                &mut anom_dev,
                &mut self.tm_gpu,
            ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;

            // No dev.synchronize() here: PyTorch's default stream (stream 0)
            // and cudarc's own stream are distinct, so the caller must
            // explicitly sync via the `device_sync()` method (or PyTorch
            // auto-syncs when the output tensor is next consumed). Removing
            // the per-launch barrier lets subsequent GPU work (mamba3 fwd,
            // etc.) overlap in time.
            Ok(())
        });
        result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
        Ok(())
    }

    /// Clear TM state on the GPU.
    fn reset(&mut self) -> PyResult<()> {
        self.tm_gpu.reset().map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("GPU TM reset: {e:?}"))
        })?;
        self.fused_state.reset().map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("GPU fused reset: {e:?}"))
        })
    }

    /// FUSED MEGAKERNEL PATH: single CUDA launch for the entire T-step
    /// forward (SP + TM all in one). Accepts torch CUDA tensors via
    /// `__cuda_array_interface__` (zero-copy). Writes active-column mask +
    /// anomaly directly into caller-allocated torch tensors.
    ///
    /// Semantics diverge from `step_many_cuda` in one important way: column
    /// activation uses per-column threshold inhibition instead of global
    /// top-K.
    /// The threshold is EMA-adapted per column toward the sparsity target.
    /// See `docs/GPU_HTM.md` §Fused Kernel.
    #[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
    fn step_many_fused_cuda(
        &mut self,
        py: Python<'_>,
        sdr_cai: &Bound<'_, PyDict>,
        cols_cai: &Bound<'_, PyDict>,
        anom_cai: &Bound<'_, PyDict>,
        learn: bool,
    ) -> PyResult<()> {
        let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
        let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
        let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;

        if sdr_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
            )));
        }
        if cols_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
            )));
        }
        if anom_type != "<f4" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai typestr must be '<f4' (float32), got {anom_type}",
            )));
        }

        // Shape sanity against this region's config.
        if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai shape must be (T, {}), got {sdr_shape:?}",
                self.input_bits,
            )));
        }
        let t = sdr_shape[0];
        if cols_shape != [t, self.n_columns] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai shape must be ({t}, {}), got {cols_shape:?}",
                self.n_columns,
            )));
        }
        if anom_shape != [t] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai shape must be ({t},), got {anom_shape:?}",
            )));
        }

        let input_bits = self.input_bits;
        let n_cols = self.n_columns;
        let dev = self.sp_gpu.dev_ref().clone();

        let result = py.allow_threads(|| -> Result<(), String> {
            // SAFETY: same argument as in `step_many_cuda`: torch owns the
            // allocations; ManuallyDrop keeps cudarc's Drop from freeing them.
            let inputs_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
            });
            let mut cols_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
            });
            let mut anom_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<f32>(anom_ptr, t)
            });

            fused::launch_fused(
                &mut self.sp_gpu,
                &mut self.tm_gpu,
                &mut self.fused_state,
                &inputs_dev,
                &mut cols_dev,
                &mut anom_dev,
                t,
                input_bits,
                learn,
            ).map_err(|e| format!("launch_fused: {e:?}"))?;

            // No dev.synchronize() here: caller must explicitly sync via the
            // `device_sync()` method (or PyTorch auto-syncs when the output
            // tensor is next consumed). Removing the per-launch barrier lets
            // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
            Ok(())
        });
        result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
        Ok(())
    }

    /// Explicit device synchronization — the caller must invoke this after
    /// all batched `step_many_*_cuda` calls complete, before reading the
    /// output tensors from a different CUDA stream. Equivalent to the old
    /// per-call `dev.synchronize()` that was removed for overlap.
    fn device_sync(&self) -> PyResult<()> {
        let dev = self.sp_gpu.dev_ref();
        dev.synchronize()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("sync: {e:?}")))?;
        Ok(())
    }
}

/// Batch B regions into ONE cooperative kernel launch. Breaks through the
/// CUDA cooperative-kernel device-level serialization: a single cooperative
/// launch with grid.y = B processes all regions concurrently — ~B× speedup
/// over B sequential launches.
///
/// All regions must have the same config (input_bits, n_columns,
/// cells_per_column). Each region keeps its independent GPU state.
/// Does NOT sync; the caller must invoke `device_sync()` on any region
/// afterwards (or rely on a downstream torch op to auto-sync).
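///
/// A minimal call sketch from Python (illustrative, not part of this API:
/// `regions` is assumed to be a list of already-constructed `HTMRegionGpu`
/// objects with identical configs, and `sdr_tensors` / `cols_tensors` /
/// `anom_tensors` are per-region torch CUDA tensors shaped as in
/// `step_many_cuda`):
///
/// ```python
/// step_batch_fused_cuda(
///     regions,
///     [s.__cuda_array_interface__ for s in sdr_tensors],
///     [c.__cuda_array_interface__ for c in cols_tensors],
///     [a.__cuda_array_interface__ for a in anom_tensors],
///     learn=True,
/// )
/// regions[0].device_sync()  # one explicit sync covers the whole batched launch
/// ```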
#[pyfunction]
#[pyo3(signature = (regions, sdr_cais, cols_cais, anom_cais, learn=true))]
fn step_batch_fused_cuda(
    py: Python<'_>,
    regions: Vec<Py<HTMRegionGpu>>,
    sdr_cais: Vec<Bound<'_, PyDict>>,
    cols_cais: Vec<Bound<'_, PyDict>>,
    anom_cais: Vec<Bound<'_, PyDict>>,
    learn: bool,
) -> PyResult<()> {
    let b = regions.len();
    if b == 0 {
        return Err(pyo3::exceptions::PyValueError::new_err("regions is empty"));
    }
    if sdr_cais.len() != b || cols_cais.len() != b || anom_cais.len() != b {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "sdr_cais / cols_cais / anom_cais length must match regions",
        ));
    }

    // Parse all CAI dicts; collect device pointers. Validate shapes/dtypes.
    let mut sdr_ptrs = Vec::with_capacity(b);
    let mut cols_ptrs = Vec::with_capacity(b);
    let mut anom_ptrs = Vec::with_capacity(b);

    let (input_bits, n_columns, t) = {
        let r0 = regions[0].bind(py).borrow();
        (r0.input_bits, r0.n_columns, {
            let (_p, sh, _ty) = cai_parse(&sdr_cais[0])?;
            if sh.len() != 2 {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    format!("sdr_cai must be 2-D (T, input_bits), got {sh:?}"),
                ));
            }
            sh[0]
        })
    };

    for i in 0..b {
        let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_cais[i])?;
        let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_cais[i])?;
        let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_cais[i])?;

        if sdr_type != "|u1" || cols_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "sdr/cols typestr must be '|u1' (uint8)",
            ));
        }
        if anom_type != "<f4" {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "anom typestr must be '<f4' (float32)",
            ));
        }
        if sdr_shape != [t, input_bits]
            || cols_shape != [t, n_columns]
            || anom_shape != [t]
        {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "region {i}: expected shapes ({t}, {input_bits}) / ({t}, {n_columns}) / ({t},), \
                 got {sdr_shape:?} / {cols_shape:?} / {anom_shape:?}",
            )));
        }

        sdr_ptrs.push(sdr_ptr);
        cols_ptrs.push(cols_ptr);
        anom_ptrs.push(anom_ptr);
    }

    // Exclusively borrow every region for the duration of the launch.
    let mut region_refs: Vec<PyRefMut<'_, HTMRegionGpu>> =
        regions.iter().map(|p| p.bind(py).borrow_mut()).collect();

    // Collect raw mutable pointers — each PyRefMut exclusively borrows its
    // region for the lifetime of this call, so the pointers stay valid and
    // unique. launch_fused_batched_raw only dereferences one region at a
    // time and never constructs an aliasing &mut slice.
    let raw_ptrs: Vec<*mut HTMRegionGpu> = region_refs
        .iter_mut()
        .map(|r| &mut **r as *mut HTMRegionGpu)
        .collect();

    // No allow_threads: raw pointers aren't Send. The launch is GPU-queued
    // and sync'd downstream; holding the GIL for the duration is cheap.
    fused::launch_fused_batched_raw(
        &raw_ptrs,
        &sdr_ptrs,
        &cols_ptrs,
        &anom_ptrs,
        t,
        input_bits,
        learn,
    )
    .map_err(|e| {
        pyo3::exceptions::PyRuntimeError::new_err(format!("launch_fused_batched: {e:?}"))
    })?;

    Ok(())
}

pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<HTMRegionGpu>()?;
    m.add_function(pyo3::wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
    Ok(())
}