icarus112's picture
Upload folder using huggingface_hub
fa198f8 verified
//! GPU backend for HTM.
//!
//! Full-GPU pipeline (SP + TM). Per-step state lives entirely on device; the
//! batch API (`step_many_gpu`) uploads T steps of input once, runs T iterations
//! of the full HTM pipeline on GPU, and copies (T, n_cols) u8 + (T,) f32 back
//! to the host in one shot.
//!
//! TM parity with the CPU reference is approximate:
//! - Segment growth: winner = cell 0 of bursting column (CPU picks
//! least-used-cell with RNG tiebreak). This is a pragmatic simplification
//! for GPU atomicity; learning dynamics are preserved.
//! - Permanences stored as i16 (scaled 0..32767). Rounding differs from
//! f32 by <= 1 ULP of the scale factor (≈ 3e-5) — inside any meaningful
//! HTM learning quantum.
#![cfg(feature = "gpu")]
pub mod sp_gpu;
pub mod tm_gpu;
pub mod fused;
#[cfg(test)]
mod tests;
use std::mem::ManuallyDrop;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};
use crate::region::HTMRegionCore;
use crate::sp::SpatialPoolerConfig;
use sp_gpu::SpatialPoolerGpu;
use tm_gpu::TemporalMemoryGpu;
use fused::FusedState;
/// Extract `(device_ptr, shape, typestr)` from a `__cuda_array_interface__` dict.
///
/// Used by the CUDA entry points to wrap caller-owned device allocations
/// zero-copy. Only C-contiguous buffers are accepted: a `strides` entry that is
/// present and not `None` is rejected because strides are not handled.
///
/// # Errors
///
/// Returns `PyValueError` when `data`, `shape` or `typestr` is missing, when
/// `data` / `shape` is not a tuple, when `strides` is not `None`, or when the
/// device pointer is null although the shape describes a non-empty buffer.
/// Failures while extracting individual elements (e.g. a non-integer
/// dimension) are propagated unchanged.
fn cai_parse(cai: &Bound<'_, PyDict>) -> PyResult<(u64, Vec<usize>, String)> {
    use pyo3::exceptions::PyValueError;

    // `data` is a (ptr: int, readonly: bool) tuple; only the pointer is read.
    let data_obj = cai
        .get_item("data")?
        .ok_or_else(|| PyValueError::new_err("CAI missing 'data'"))?;
    let data_tup: Bound<'_, PyTuple> = data_obj
        .downcast_into()
        .map_err(|_| PyValueError::new_err("CAI 'data' must be a tuple"))?;
    let ptr: u64 = data_tup.get_item(0)?.extract()?;

    // `shape` is a tuple of ints.
    let shape_obj = cai
        .get_item("shape")?
        .ok_or_else(|| PyValueError::new_err("CAI missing 'shape'"))?;
    let shape_tup: Bound<'_, PyTuple> = shape_obj
        .downcast_into()
        .map_err(|_| PyValueError::new_err("CAI 'shape' must be a tuple"))?;
    let shape: Vec<usize> = shape_tup
        .iter()
        .map(|dim| dim.extract::<usize>())
        .collect::<PyResult<Vec<_>>>()?;

    // `typestr` (e.g. "|u1", "<f4").
    let typestr: String = cai
        .get_item("typestr")?
        .ok_or_else(|| PyValueError::new_err("CAI missing 'typestr'"))?
        .extract()?;

    // Reject non-contiguous tensors — we don't handle strides.
    if let Some(strides) = cai.get_item("strides")? {
        if !strides.is_none() {
            return Err(PyValueError::new_err(
                "CAI 'strides' must be None (tensor must be contiguous)",
            ));
        }
    }

    // A null pointer is only meaningful for a buffer with zero elements. The
    // callers hand this pointer to `unsafe` code, so refuse it for any shape
    // that has at least one element (a dimension of 0 means no elements).
    if ptr == 0 && !shape.contains(&0) {
        return Err(PyValueError::new_err(
            "CAI 'data' pointer is null but 'shape' is non-empty",
        ));
    }

    Ok((ptr, shape, typestr))
}
/// Python-exposed GPU HTM region. Drop-in replacement for `HTMRegion`.
///
/// Bundles the device-side spatial pooler, temporal memory and fused-kernel
/// state together with the dimensions the region was constructed with.
#[pyclass(module = "htm_rust")]
pub struct HTMRegionGpu {
// GPU spatial pooler; built from a CPU reference in `new`. Its `dev_ref()`
// is also where the methods obtain the device handle.
sp_gpu: SpatialPoolerGpu,
// GPU temporal memory; passed alongside the pooler to the batch step calls
// and cleared by `reset`.
tm_gpu: TemporalMemoryGpu,
// Extra state consumed by `fused::launch_fused`; also cleared by `reset`.
fused_state: FusedState,
// Number of columns; second dimension of the column-mask outputs.
n_columns: usize,
// Expected size of the last dimension of every input batch.
input_bits: usize,
// Cells per column, as passed to the TM and fused-state constructors.
cells_per_column: usize,
}
#[pymethods]
impl HTMRegionGpu {
/// Build a GPU region with the given dimensions.
///
/// A CPU `HTMRegionCore` is created first (seeded with `seed`) and the GPU
/// spatial pooler is derived from it; the temporal memory and the fused
/// state are then allocated on the same device handle.
///
/// Raises `ValueError` if any dimension is zero and `RuntimeError` if any of
/// the GPU components fails to initialise.
#[new]
#[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
fn new(
    input_bits: usize,
    n_columns: usize,
    cells_per_column: usize,
    seed: u64,
) -> PyResult<Self> {
    use pyo3::exceptions::{PyRuntimeError, PyValueError};

    // Every dimension must be strictly positive.
    if [input_bits, n_columns, cells_per_column].contains(&0) {
        return Err(PyValueError::new_err(
            "input_bits, n_columns, cells_per_column must all be > 0",
        ));
    }

    // CPU reference for deterministic SP init.
    let cpu_region = HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed);
    let sp_gpu = match SpatialPoolerGpu::from_cpu(&cpu_region.sp) {
        Ok(pooler) => pooler,
        Err(e) => {
            let cfg: &SpatialPoolerConfig = &cpu_region.sp.cfg;
            return Err(PyRuntimeError::new_err(format!(
                "GPU SP init failed: {e:?}. Config: input_bits={}, n_columns={}",
                cfg.input_bits, cfg.n_columns,
            )));
        }
    };

    // The remaining components are created on the pooler's device handle.
    let device = sp_gpu.dev_ref().clone();
    let tm_gpu = match TemporalMemoryGpu::new(device.clone(), n_columns, cells_per_column) {
        Ok(tm) => tm,
        Err(e) => {
            return Err(PyRuntimeError::new_err(format!(
                "GPU TM init failed: {e:?}",
            )));
        }
    };

    let threshold_estimate = sp_gpu.initial_threshold_estimate();
    let fused_state =
        match FusedState::new(device, n_columns, cells_per_column, threshold_estimate) {
            Ok(state) => state,
            Err(e) => {
                return Err(PyRuntimeError::new_err(format!(
                    "GPU fused state init failed: {e:?}",
                )));
            }
        };

    Ok(Self {
        sp_gpu,
        tm_gpu,
        fused_state,
        n_columns,
        input_bits,
        cells_per_column,
    })
}
/// Expected size of the last dimension of every input batch.
#[getter]
fn input_bits(&self) -> usize {
    self.input_bits
}

/// Number of columns in the region.
#[getter]
fn n_columns(&self) -> usize {
    self.n_columns
}

/// Number of cells in each column.
#[getter]
fn cells_per_column(&self) -> usize {
    self.cells_per_column
}
/// Process T timesteps in one call on GPU. Per-step state (SP + TM) stays
/// on device; only the final (T, n_cols) mask and (T,) anomaly are copied
/// to the host at the end.
#[pyo3(signature = (inputs, learn=true))]
fn step_many_gpu<'py>(
&mut self,
py: Python<'py>,
inputs: PyReadonlyArray2<'py, bool>,
learn: bool,
) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
let shape = inputs.shape();
if shape.len() != 2 {
return Err(pyo3::exceptions::PyValueError::new_err(
"inputs must be 2-D (T, input_bits)",
));
}
let t = shape[0];
let bits = shape[1];
if bits != self.input_bits {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"inputs last dim {bits} != expected input_bits {}",
self.input_bits,
)));
}
let slice = inputs.as_slice()?;
let n_cols = self.n_columns;
let input_vec: Vec<bool> = slice.to_vec();
let result = py.allow_threads(|| -> Result<(Vec<u8>, Vec<f32>), String> {
// 1. Upload T*input_bits bytes (32 MB at T=2048, bits=16384).
let sdr_u8_all: Vec<u8> = input_vec.iter().map(|&b| b as u8).collect();
let inputs_dev = self
.sp_gpu
.dev_ref()
.htod_sync_copy(&sdr_u8_all)
.map_err(|e| format!("H2D inputs: {e:?}"))?;
// 2. Allocate output buffers on device.
let mut cols_dev = self.sp_gpu.dev_ref()
.alloc_zeros::<u8>(t * n_cols)
.map_err(|e| format!("alloc cols: {e:?}"))?;
let mut anom_dev = self.sp_gpu.dev_ref()
.alloc_zeros::<f32>(t)
.map_err(|e| format!("alloc anom: {e:?}"))?;
// 3. Run T steps of SP + TM on GPU with NO per-step host sync.
self.sp_gpu.step_batch_with_tm(
&inputs_dev,
t,
self.input_bits,
learn,
&mut cols_dev,
&mut anom_dev,
&mut self.tm_gpu,
).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
// 4. ONE D2H for the whole run (T * n_cols bytes + T floats).
let cols_host: Vec<u8> = self.sp_gpu.dev_ref()
.dtoh_sync_copy(&cols_dev)
.map_err(|e| format!("D2H cols: {e:?}"))?;
let anom_host: Vec<f32> = self.sp_gpu.dev_ref()
.dtoh_sync_copy(&anom_dev)
.map_err(|e| format!("D2H anom: {e:?}"))?;
Ok((cols_host, anom_host))
});
let (cols_u8, anom) = result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
let cols_arr = numpy::PyArray1::from_vec_bound(py, cols_f32)
.reshape([t, n_cols])
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
Ok((cols_arr, anom_arr))
}
/// Zero-copy CUDA path: accept torch tensors via __cuda_array_interface__,
/// write outputs directly into caller-allocated torch tensors. Skips the
/// host round-trip that `step_many_gpu` pays on every call (sdr.cpu() +
/// two D2H copies at the end). This is the hot path for `train.py`.
///
/// Contract:
///   sdr_cai.shape  == (T, input_bits), dtype u8 (0/1 mask)
///   cols_cai.shape == (T, n_columns),  dtype u8 (written)
///   anom_cai.shape == (T,),            dtype f32 (written)
/// All three tensors must live on the SAME CUDA device as this region and
/// must stay alive for the duration of the call; neither can be verified
/// from the interface dicts, so both are the caller's responsibility.
///
/// The torch tensors still own their memory — this method only wraps
/// them as borrowed CudaSlice views (via ManuallyDrop) so cudarc's Drop
/// impl can't free pytorch's allocator.
///
/// An empty batch (T == 0) is a no-op. Raises `ValueError` for a bad
/// typestr, shape or null pointer and `RuntimeError` when the device step
/// or the final synchronisation fails.
#[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
fn step_many_cuda(
    &mut self,
    py: Python<'_>,
    sdr_cai: &Bound<'_, PyDict>,
    cols_cai: &Bound<'_, PyDict>,
    anom_cai: &Bound<'_, PyDict>,
    learn: bool,
) -> PyResult<()> {
    use pyo3::exceptions::{PyRuntimeError, PyValueError};

    /// Element count of a (t, width) buffer, refusing products that wrap.
    fn checked_len(t: usize, width: usize) -> PyResult<usize> {
        t.checked_mul(width)
            .ok_or_else(|| PyValueError::new_err("batch element count overflows usize"))
    }

    let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
    let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
    let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;

    // typestr sanity. numpy u1 is what torch.uint8 exports.
    if sdr_type != "|u1" {
        return Err(PyValueError::new_err(format!(
            "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
        )));
    }
    if cols_type != "|u1" {
        return Err(PyValueError::new_err(format!(
            "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
        )));
    }
    if anom_type != "<f4" && anom_type != "=f4" {
        return Err(PyValueError::new_err(format!(
            "anom_cai typestr must be '<f4' (float32), got {anom_type}",
        )));
    }

    // Shape validation.
    if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
        return Err(PyValueError::new_err(format!(
            "sdr_cai shape {sdr_shape:?} != (T, {})",
            self.input_bits,
        )));
    }
    let t = sdr_shape[0];
    if cols_shape != [t, self.n_columns] {
        return Err(PyValueError::new_err(format!(
            "cols_cai shape {cols_shape:?} != ({t}, {})",
            self.n_columns,
        )));
    }
    if anom_shape != [t] {
        return Err(PyValueError::new_err(format!(
            "anom_cai shape {anom_shape:?} != ({t},)",
        )));
    }

    // Empty batch: there is nothing to read or write, and empty tensors
    // may legitimately carry a null pointer, so never wrap them.
    if t == 0 {
        return Ok(());
    }

    // From here on every buffer has at least one element (the constructor
    // rejects zero dimensions), so a null pointer can only be a caller bug.
    for (name, ptr) in [
        ("sdr_cai", sdr_ptr),
        ("cols_cai", cols_ptr),
        ("anom_cai", anom_ptr),
    ] {
        if ptr == 0 {
            return Err(PyValueError::new_err(format!(
                "{name} device pointer is null for a non-empty tensor",
            )));
        }
    }

    let n_cols = self.n_columns;
    let input_bits = self.input_bits;
    let sdr_len = checked_len(t, input_bits)?;
    let cols_len = checked_len(t, n_cols)?;

    let dev = self.sp_gpu.dev_ref().clone();
    let result = py.allow_threads(|| -> Result<(), String> {
        // SAFETY:
        // - All three pointers were checked non-null above.
        // - Lengths are derived from the validated shapes with overflow
        //   checking, so they match what the caller declared.
        // - Each view is wrapped in ManuallyDrop so its destructor never
        //   runs: the allocations remain owned (and freed) by the caller.
        // - The views are used only within this call; the caller is
        //   required to keep the tensors alive and on this device.
        let inputs_dev = ManuallyDrop::new(unsafe {
            dev.upgrade_device_ptr::<u8>(sdr_ptr, sdr_len)
        });
        let mut cols_dev = ManuallyDrop::new(unsafe {
            dev.upgrade_device_ptr::<u8>(cols_ptr, cols_len)
        });
        let mut anom_dev = ManuallyDrop::new(unsafe {
            dev.upgrade_device_ptr::<f32>(anom_ptr, t)
        });

        let run = self
            .sp_gpu
            .step_batch_with_tm(
                &inputs_dev,
                t,
                input_bits,
                learn,
                &mut cols_dev,
                &mut anom_dev,
                &mut self.tm_gpu,
            )
            .map_err(|e| format!("step_batch_with_tm: {e:?}"));

        // Synchronize: kernel writes must be visible to the next torch
        // op that reads cols/anom. Pytorch's default stream is stream 0,
        // and cudarc launches on its own stream — a full device sync
        // is the simplest correct barrier. It runs even when the step
        // reported an error, so that no already-queued work can touch the
        // caller's buffers after this method has returned.
        let synced = dev.synchronize().map_err(|e| format!("sync: {e:?}"));

        // Report the step error first; otherwise the sync outcome.
        run.and(synced)
    });
    result.map_err(PyRuntimeError::new_err)
}
/// Reset the GPU temporal memory and then the fused-kernel state.
///
/// Raises `RuntimeError` if either reset fails; when the TM reset fails the
/// fused state is not touched.
fn reset(&mut self) -> PyResult<()> {
    use pyo3::exceptions::PyRuntimeError;

    if let Err(e) = self.tm_gpu.reset() {
        return Err(PyRuntimeError::new_err(format!("GPU TM reset: {e:?}")));
    }
    match self.fused_state.reset() {
        Ok(done) => Ok(done),
        Err(e) => Err(PyRuntimeError::new_err(format!("GPU fused reset: {e:?}"))),
    }
}
/// FUSED MEGAKERNEL PATH: single CUDA launch for the entire T-step
/// forward (SP + TM all in one). Accepts torch CUDA tensors via
/// `__cuda_array_interface__` (zero-copy). Writes active-column mask +
/// anomaly directly into caller-allocated torch tensors.
///
/// Semantics diverge from `step_many_cuda` in one important way: column
/// activation uses per-column threshold inhibition instead of global
/// top-K. The threshold is EMA-adapted per column toward the sparsity
/// target. See `docs/GPU_HTM.md` §Fused Kernel.
///
/// Contract:
///   sdr_cai.shape  == (T, input_bits), dtype u8
///   cols_cai.shape == (T, n_columns),  dtype u8 (written)
///   anom_cai.shape == (T,),            dtype f32 (written)
/// The tensors must stay alive for the duration of the call and belong to
/// this region's device; neither can be verified from the interface dicts.
///
/// An empty batch (T == 0) is a no-op. Raises `ValueError` for a bad
/// typestr, shape or null pointer and `RuntimeError` when the launch or the
/// final synchronisation fails.
#[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
fn step_many_fused_cuda(
    &mut self,
    py: Python<'_>,
    sdr_cai: &Bound<'_, PyDict>,
    cols_cai: &Bound<'_, PyDict>,
    anom_cai: &Bound<'_, PyDict>,
    learn: bool,
) -> PyResult<()> {
    use pyo3::exceptions::{PyRuntimeError, PyValueError};

    /// Element count of a (t, width) buffer, refusing products that wrap.
    fn checked_len(t: usize, width: usize) -> PyResult<usize> {
        t.checked_mul(width)
            .ok_or_else(|| PyValueError::new_err("batch element count overflows usize"))
    }

    let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
    let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
    let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;

    // typestr sanity.
    if sdr_type != "|u1" {
        return Err(PyValueError::new_err(format!(
            "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
        )));
    }
    if cols_type != "|u1" {
        return Err(PyValueError::new_err(format!(
            "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
        )));
    }
    if anom_type != "<f4" && anom_type != "=f4" {
        return Err(PyValueError::new_err(format!(
            "anom_cai typestr must be '<f4' (float32), got {anom_type}",
        )));
    }

    // Shape validation.
    if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
        return Err(PyValueError::new_err(format!(
            "sdr_cai shape {sdr_shape:?} != (T, {})",
            self.input_bits,
        )));
    }
    let t = sdr_shape[0];
    if cols_shape != [t, self.n_columns] {
        return Err(PyValueError::new_err(format!(
            "cols_cai shape {cols_shape:?} != ({t}, {})",
            self.n_columns,
        )));
    }
    if anom_shape != [t] {
        return Err(PyValueError::new_err(format!(
            "anom_cai shape {anom_shape:?} != ({t},)",
        )));
    }

    // Empty batch: there is nothing to read or write, and empty tensors
    // may legitimately carry a null pointer, so never wrap them.
    if t == 0 {
        return Ok(());
    }

    // From here on every buffer has at least one element (the constructor
    // rejects zero dimensions), so a null pointer can only be a caller bug.
    for (name, ptr) in [
        ("sdr_cai", sdr_ptr),
        ("cols_cai", cols_ptr),
        ("anom_cai", anom_ptr),
    ] {
        if ptr == 0 {
            return Err(PyValueError::new_err(format!(
                "{name} device pointer is null for a non-empty tensor",
            )));
        }
    }

    let n_cols = self.n_columns;
    let input_bits = self.input_bits;
    let sdr_len = checked_len(t, input_bits)?;
    let cols_len = checked_len(t, n_cols)?;

    let dev = self.sp_gpu.dev_ref().clone();
    let result = py.allow_threads(|| -> Result<(), String> {
        // SAFETY:
        // - All three pointers were checked non-null above.
        // - Lengths are derived from the validated shapes with overflow
        //   checking, so they match what the caller declared.
        // - Each view is wrapped in ManuallyDrop so its destructor never
        //   runs: the allocations remain owned (and freed) by the caller.
        // - The views are used only within this call; the caller is
        //   required to keep the tensors alive and on this device.
        let inputs_dev = ManuallyDrop::new(unsafe {
            dev.upgrade_device_ptr::<u8>(sdr_ptr, sdr_len)
        });
        let mut cols_dev = ManuallyDrop::new(unsafe {
            dev.upgrade_device_ptr::<u8>(cols_ptr, cols_len)
        });
        let mut anom_dev = ManuallyDrop::new(unsafe {
            dev.upgrade_device_ptr::<f32>(anom_ptr, t)
        });

        let run = fused::launch_fused(
            &mut self.sp_gpu,
            &mut self.tm_gpu,
            &mut self.fused_state,
            &inputs_dev,
            &mut cols_dev,
            &mut anom_dev,
            t,
            input_bits,
            learn,
        )
        .map_err(|e| format!("launch_fused: {e:?}"));

        // Full device sync so the outputs are complete before the caller
        // reads them. It runs even when the launch reported an error, so
        // that no already-queued work can touch the caller's buffers after
        // this method has returned.
        let synced = dev.synchronize().map_err(|e| format!("sync: {e:?}"));

        // Report the launch error first; otherwise the sync outcome.
        run.and(synced)
    });
    result.map_err(PyRuntimeError::new_err)
}
}
/// Add the `HTMRegionGpu` class to the given Python module.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // `add_class` already yields `PyResult<()>`, so it is the return value.
    m.add_class::<HTMRegionGpu>()
}