//! GPU backend for HTM.
//!
//! Full-GPU pipeline (SP + TM). Per-step state lives entirely on device; the
//! batch API (`step_many_gpu`) uploads T steps of input once, runs T iterations
//! of the full HTM pipeline on GPU, and copies (T, n_cols) u8 + (T,) f32 back
//! to the host in one shot.
//!
//! TM parity with the CPU reference is approximate:
//! - Segment growth: the winner cell is always cell 0 of the bursting column
//!   (the CPU reference picks the least-used cell with an RNG tiebreak).
//!   This is a pragmatic simplification for GPU atomicity; the overall
//!   learning dynamics are preserved.
//! - Permanences are stored as i16, scaled to 0..32767. Rounding differs
//!   from the f32 reference by at most one quantization step
//!   (1/32767 ≈ 3e-5), well below any meaningful HTM learning increment.
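//!
//! A minimal sketch of that quantization, in Python for illustration only;
//! the helper names below are not the kernel's actual symbols:
//!
//! ```python
//! SCALE = 32767  # i16 permanence scale used on device
//!
//! def quantize(p: float) -> int:
//!     """Map a permanence in [0.0, 1.0] onto the i16 encoding."""
//!     return round(min(max(p, 0.0), 1.0) * SCALE)
//!
//! def dequantize(q: int) -> float:
//!     return q / SCALE
//!
//! # Worst-case round-trip error is half a quantization step:
//! # 0.5 / 32767 ~= 1.5e-5, far below a typical permanence increment (~0.05).
//! assert abs(dequantize(quantize(0.2134)) - 0.2134) <= 0.5 / SCALE
//! ```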
#![cfg(feature = "gpu")]
pub mod sp_gpu;
pub mod tm_gpu;
pub mod fused;
#[cfg(test)]
mod tests;
use std::mem::ManuallyDrop;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};
use crate::region::HTMRegionCore;
use crate::sp::SpatialPoolerConfig;
use sp_gpu::SpatialPoolerGpu;
use tm_gpu::TemporalMemoryGpu;
use fused::FusedState;
/// Extract (device_ptr, shape, typestr) from a `__cuda_array_interface__` dict.
/// Returns Err if the dict is malformed. Used by `step_many_cuda` to wrap
/// torch-owned CUDA allocations zero-copy.
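///
/// For reference, a contiguous torch CUDA tensor exports a dict shaped
/// roughly like the following (pointer value illustrative; version 3 of the
/// interface may also carry a 'stream' key, which this parser ignores):
///
/// ```python
/// import torch
///
/// t = torch.zeros(2048, 16384, dtype=torch.uint8, device="cuda")
/// t.__cuda_array_interface__
/// # {'data': (139637976727552, False), 'shape': (2048, 16384),
/// #  'typestr': '|u1', 'strides': None, 'version': 3}
/// ```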
fn cai_parse(cai: &Bound<'_, PyDict>) -> PyResult<(u64, Vec<usize>, String)> {
// `data` is a (ptr: int, readonly: bool) tuple.
let data_obj = cai.get_item("data")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'data'"))?;
let data_tup: Bound<'_, PyTuple> = data_obj.downcast_into()
.map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'data' must be a tuple"))?;
let ptr: u64 = data_tup.get_item(0)?.extract()?;
// `shape` is a tuple of ints.
let shape_obj = cai.get_item("shape")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'shape'"))?;
let shape_tup: Bound<'_, PyTuple> = shape_obj.downcast_into()
.map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'shape' must be a tuple"))?;
let shape: Vec<usize> = (0..shape_tup.len())
.map(|i| shape_tup.get_item(i).and_then(|v| v.extract::<usize>()))
.collect::<PyResult<Vec<_>>>()?;
// `typestr` (e.g. "|u1", "<f4").
let typestr_obj = cai.get_item("typestr")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'typestr'"))?;
let typestr: String = typestr_obj.extract()?;
// Reject non-contiguous tensors — we don't handle strides.
if let Some(strides) = cai.get_item("strides")? {
if !strides.is_none() {
return Err(pyo3::exceptions::PyValueError::new_err(
"CAI 'strides' must be None (tensor must be contiguous)",
));
}
}
Ok((ptr, shape, typestr))
}
/// Python-exposed GPU HTM region. Drop-in replacement for `HTMRegion`.
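///
/// A hedged construction example (sizes are illustrative; the module name
/// follows `#[pyclass(module = "htm_rust")]`):
///
/// ```python
/// from htm_rust import HTMRegionGpu
///
/// region = HTMRegionGpu(input_bits=16384, n_columns=2048,
///                       cells_per_column=16, seed=42)
/// ```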
#[pyclass(module = "htm_rust")]
pub struct HTMRegionGpu {
pub(super) sp_gpu: SpatialPoolerGpu,
pub(super) tm_gpu: TemporalMemoryGpu,
pub(super) fused_state: FusedState,
pub(super) n_columns: usize,
pub(super) input_bits: usize,
pub(super) cells_per_column: usize,
}
#[pymethods]
impl HTMRegionGpu {
#[new]
#[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
fn new(
input_bits: usize,
n_columns: usize,
cells_per_column: usize,
seed: u64,
) -> PyResult<Self> {
if input_bits == 0 || n_columns == 0 || cells_per_column == 0 {
return Err(pyo3::exceptions::PyValueError::new_err(
"input_bits, n_columns, cells_per_column must all be > 0",
));
}
// CPU reference for deterministic SP init.
let cpu_ref = HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed);
let sp_cfg: &SpatialPoolerConfig = &cpu_ref.sp.cfg;
let sp_gpu = SpatialPoolerGpu::from_cpu(&cpu_ref.sp).map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"GPU SP init failed: {e:?}. Config: input_bits={}, n_columns={}",
sp_cfg.input_bits, sp_cfg.n_columns,
))
})?;
let dev = sp_gpu.dev_ref().clone();
let tm_gpu = TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_column).map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"GPU TM init failed: {e:?}",
))
})?;
let initial_threshold = sp_gpu.initial_threshold_estimate();
let fused_state = FusedState::new(dev, n_columns, cells_per_column, initial_threshold)
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!(
"GPU fused state init failed: {e:?}",
)))?;
Ok(Self {
sp_gpu,
tm_gpu,
fused_state,
n_columns,
input_bits,
cells_per_column,
})
}
#[getter] fn input_bits(&self) -> usize { self.input_bits }
#[getter] fn n_columns(&self) -> usize { self.n_columns }
#[getter] fn cells_per_column(&self) -> usize { self.cells_per_column }
/// Process T timesteps in one call on GPU. Per-step state (SP + TM) stays
/// on device; only the final (T, n_cols) mask and (T,) anomaly are copied
/// to the host at the end.
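    ///
    /// Illustrative host-side call (sizes are examples, not requirements):
    ///
    /// ```python
    /// import numpy as np
    /// from htm_rust import HTMRegionGpu
    ///
    /// region = HTMRegionGpu(input_bits=1024, n_columns=2048, cells_per_column=16)
    /// inputs = np.zeros((512, region.input_bits), dtype=bool)  # (T, input_bits)
    /// cols, anomaly = region.step_many_gpu(inputs, learn=True)
    /// # cols: float32 (512, n_columns) active-column mask; anomaly: float32 (512,)
    /// ```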
#[pyo3(signature = (inputs, learn=true))]
fn step_many_gpu<'py>(
&mut self,
py: Python<'py>,
inputs: PyReadonlyArray2<'py, bool>,
learn: bool,
) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
let shape = inputs.shape();
if shape.len() != 2 {
return Err(pyo3::exceptions::PyValueError::new_err(
"inputs must be 2-D (T, input_bits)",
));
}
let t = shape[0];
let bits = shape[1];
if bits != self.input_bits {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"inputs last dim {bits} != expected input_bits {}",
self.input_bits,
)));
}
let slice = inputs.as_slice()?;
let n_cols = self.n_columns;
let input_vec: Vec<bool> = slice.to_vec();
let result = py.allow_threads(|| -> Result<(Vec<u8>, Vec<f32>), String> {
// 1. Upload T*input_bits bytes (32 MB at T=2048, bits=16384).
let sdr_u8_all: Vec<u8> = input_vec.iter().map(|&b| b as u8).collect();
let inputs_dev = self
.sp_gpu
.dev_ref()
.htod_sync_copy(&sdr_u8_all)
.map_err(|e| format!("H2D inputs: {e:?}"))?;
// 2. Allocate output buffers on device.
let mut cols_dev = self.sp_gpu.dev_ref()
.alloc_zeros::<u8>(t * n_cols)
.map_err(|e| format!("alloc cols: {e:?}"))?;
let mut anom_dev = self.sp_gpu.dev_ref()
.alloc_zeros::<f32>(t)
.map_err(|e| format!("alloc anom: {e:?}"))?;
// 3. Run T steps of SP + TM on GPU with NO per-step host sync.
self.sp_gpu.step_batch_with_tm(
&inputs_dev,
t,
self.input_bits,
learn,
&mut cols_dev,
&mut anom_dev,
&mut self.tm_gpu,
).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
// 4. ONE D2H for the whole run (T * n_cols bytes + T floats).
let cols_host: Vec<u8> = self.sp_gpu.dev_ref()
.dtoh_sync_copy(&cols_dev)
.map_err(|e| format!("D2H cols: {e:?}"))?;
let anom_host: Vec<f32> = self.sp_gpu.dev_ref()
.dtoh_sync_copy(&anom_dev)
.map_err(|e| format!("D2H anom: {e:?}"))?;
Ok((cols_host, anom_host))
});
let (cols_u8, anom) = result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
let cols_arr = numpy::PyArray1::from_vec_bound(py, cols_f32)
.reshape([t, n_cols])
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
Ok((cols_arr, anom_arr))
}
    /// Zero-copy CUDA path: accepts torch tensors via `__cuda_array_interface__`
    /// and writes outputs directly into caller-allocated torch tensors. This
    /// skips the host round-trip that `step_many_gpu` pays on every call (the
    /// caller's `sdr.cpu()`, the H2D upload, and the two D2H copies at the
    /// end). This is the hot path for `train.py`.
///
/// Contract:
/// sdr_cai.shape == (T, input_bits), dtype u8 (0/1 mask)
/// cols_cai.shape == (T, n_columns), dtype u8 (written)
/// anom_cai.shape == (T,), dtype f32 (written)
/// All three tensors must live on the SAME CUDA device as this region.
///
/// The torch tensors still own their memory — this method only wraps
/// them as borrowed CudaSlice views (via ManuallyDrop) so cudarc's Drop
/// impl can't free pytorch's allocator.
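    ///
    /// A sketch of the intended call pattern (shapes illustrative; see
    /// `device_sync` below for the synchronization side of the contract):
    ///
    /// ```python
    /// import torch
    /// from htm_rust import HTMRegionGpu
    ///
    /// region = HTMRegionGpu(input_bits=1024, n_columns=2048, cells_per_column=16)
    /// T = 512
    /// sdr  = torch.zeros(T, region.input_bits, dtype=torch.uint8, device="cuda")
    /// cols = torch.empty(T, region.n_columns, dtype=torch.uint8, device="cuda")
    /// anom = torch.empty(T, dtype=torch.float32, device="cuda")
    /// region.step_many_cuda(sdr.__cuda_array_interface__,
    ///                       cols.__cuda_array_interface__,
    ///                       anom.__cuda_array_interface__, learn=True)
    /// ```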
#[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
fn step_many_cuda(
&mut self,
py: Python<'_>,
sdr_cai: &Bound<'_, PyDict>,
cols_cai: &Bound<'_, PyDict>,
anom_cai: &Bound<'_, PyDict>,
learn: bool,
) -> PyResult<()> {
let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
// typestr sanity. numpy u1 is what torch.uint8 exports.
if sdr_type != "|u1" {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
)));
}
if cols_type != "|u1" {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"cols_cai typestr must be '|u1' (uint8), got {cols_type}",
)));
}
if anom_type != "<f4" && anom_type != "=f4" {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"anom_cai typestr must be '<f4' (float32), got {anom_type}",
)));
}
// Shape validation.
if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"sdr_cai shape {sdr_shape:?} != (T, {})",
self.input_bits,
)));
}
let t = sdr_shape[0];
if cols_shape != [t, self.n_columns] {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"cols_cai shape {cols_shape:?} != ({t}, {})",
self.n_columns,
)));
}
if anom_shape != [t] {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"anom_cai shape {anom_shape:?} != ({t},)",
)));
}
let dev = self.sp_gpu.dev_ref().clone();
let n_cols = self.n_columns;
let input_bits = self.input_bits;
let result = py.allow_threads(|| -> Result<(), String> {
// SAFETY:
// - ptrs came from torch CUDA tensors validated non-null by the
// __cuda_array_interface__ contract.
// - lens computed from validated shapes.
// - We wrap the returned CudaSlice in ManuallyDrop so cudarc's
// Drop (which calls cuMemFree) never runs against torch memory.
// The underlying allocation is owned+freed by torch.
// - The slices are used only for the duration of this call;
// torch guarantees the backing tensors are live across it
// (Python holds refs on the wrapping tensors).
let inputs_dev = ManuallyDrop::new(unsafe {
dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
});
let mut cols_dev = ManuallyDrop::new(unsafe {
dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
});
let mut anom_dev = ManuallyDrop::new(unsafe {
dev.upgrade_device_ptr::<f32>(anom_ptr, t)
});
self.sp_gpu.step_batch_with_tm(
&inputs_dev,
t,
input_bits,
learn,
&mut cols_dev,
&mut anom_dev,
&mut self.tm_gpu,
).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
            // No dev.synchronize() here: cudarc launches on its own stream,
            // so the kernel writes are not yet ordered with PyTorch's
            // default stream. The caller must explicitly sync via the
            // `device_sync()` method (or PyTorch auto-syncs when the output
            // tensor is next consumed). Removing the per-launch barrier lets
            // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
Ok(())
});
result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
Ok(())
}
/// Clear TM state on the GPU.
fn reset(&mut self) -> PyResult<()> {
self.tm_gpu.reset().map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!("GPU TM reset: {e:?}"))
})?;
self.fused_state.reset().map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!("GPU fused reset: {e:?}"))
})
}
/// FUSED MEGAKERNEL PATH: single CUDA launch for the entire T-step
/// forward (SP + TM all in one). Accepts torch CUDA tensors via
/// `__cuda_array_interface__` (zero-copy). Writes active-column mask +
/// anomaly directly into caller-allocated torch tensors.
///
/// Semantics diverge from `step_many_cuda` in one important way: column
/// activation uses per-column threshold inhibition instead of global
/// top-K. The threshold is EMA-adapted per column toward the sparsity
/// target. See `docs/GPU_HTM.md` §Fused Kernel.
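    ///
    /// A toy model of that adaptation, under assumptions: the real update and
    /// its constants live in the kernel (see `docs/GPU_HTM.md`); `eta`, `k`,
    /// and `target` below are illustrative, not the kernel's values:
    ///
    /// ```python
    /// import numpy as np
    ///
    /// eta, k, target = 0.01, 0.1, 0.02   # assumed constants
    /// duty = np.zeros(2048)              # per-column EMA of recent activity
    /// thresh = np.full(2048, 8.0)        # per-column activation threshold
    ///
    /// def adapt(active):
    ///     """active: 0/1 float mask of columns that fired this step."""
    ///     duty[:] = (1.0 - eta) * duty + eta * active  # EMA duty cycle
    ///     thresh[:] = thresh + k * (duty - target)     # raise if over-active
    /// ```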
#[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
fn step_many_fused_cuda(
&mut self,
py: Python<'_>,
sdr_cai: &Bound<'_, PyDict>,
cols_cai: &Bound<'_, PyDict>,
anom_cai: &Bound<'_, PyDict>,
learn: bool,
) -> PyResult<()> {
let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
if sdr_type != "|u1" {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
)));
}
if cols_type != "|u1" {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"cols_cai typestr must be '|u1' (uint8), got {cols_type}",
)));
}
if anom_type != "<f4" && anom_type != "=f4" {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"anom_cai typestr must be '<f4' (float32), got {anom_type}",
)));
}
if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"sdr_cai shape {sdr_shape:?} != (T, {})",
self.input_bits,
)));
}
let t = sdr_shape[0];
if cols_shape != [t, self.n_columns] {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"cols_cai shape {cols_shape:?} != ({t}, {})",
self.n_columns,
)));
}
if anom_shape != [t] {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"anom_cai shape {anom_shape:?} != ({t},)",
)));
}
let dev = self.sp_gpu.dev_ref().clone();
let n_cols = self.n_columns;
let input_bits = self.input_bits;
let result = py.allow_threads(|| -> Result<(), String> {
let inputs_dev = ManuallyDrop::new(unsafe {
dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
});
let mut cols_dev = ManuallyDrop::new(unsafe {
dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
});
let mut anom_dev = ManuallyDrop::new(unsafe {
dev.upgrade_device_ptr::<f32>(anom_ptr, t)
});
fused::launch_fused(
&mut self.sp_gpu,
&mut self.tm_gpu,
&mut self.fused_state,
&inputs_dev,
&mut cols_dev,
&mut anom_dev,
t,
input_bits,
learn,
).map_err(|e| format!("launch_fused: {e:?}"))?;
// No dev.synchronize() here: caller must explicitly sync via the
// `device_sync()` method (or PyTorch auto-syncs when the output
// tensor is next consumed). Removing the per-launch barrier lets
// subsequent GPU work (mamba3 fwd, etc.) overlap in time.
Ok(())
});
result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
Ok(())
}
/// Explicit device synchronization — the caller must invoke this after
/// all batched `step_many_*_cuda` calls complete, before reading the
/// output tensors from a different CUDA stream. Equivalent to the old
/// per-call `dev.synchronize()` that was removed for overlap.
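    ///
    /// Typical pattern: queue several batched calls, then pay one barrier.
    ///
    /// ```python
    /// # `region` and `batches` (triples of prepared CUDA tensors) are assumed
    /// # set up as in the step_many_cuda example above.
    /// for sdr, cols, anom in batches:
    ///     region.step_many_cuda(sdr.__cuda_array_interface__,
    ///                           cols.__cuda_array_interface__,
    ///                           anom.__cuda_array_interface__, learn=True)
    /// region.device_sync()  # kernel writes now visible to the host
    /// ```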
fn device_sync(&self) -> PyResult<()> {
let dev = self.sp_gpu.dev_ref();
dev.synchronize()
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("sync: {e:?}")))?;
Ok(())
}
}
/// Batch B regions into ONE cooperative kernel launch. Breaks through the
/// CUDA cooperative-kernel device-level serialization: a single cooperative
/// launch with grid.y=B processes all regions concurrently — ~B× speedup
/// over B sequential launches.
///
/// All regions must have the same config (input_bits, n_columns,
/// cells_per_column). Each region keeps its independent GPU state.
/// Does NOT sync; caller must invoke `device_sync()` on any region
/// afterwards (or rely on a downstream torch op to auto-sync).
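///
/// Hedged usage sketch (assumes this module is exposed as `htm_rust`, and
/// that the per-region tensors were allocated as in `step_many_cuda`):
///
/// ```python
/// from htm_rust import step_batch_fused_cuda
///
/// # regions: list of HTMRegionGpu with identical configs; sdrs/colss/anoms:
/// # per-region CUDA tensors, assumed prepared by the caller.
/// step_batch_fused_cuda(regions,
///                       [s.__cuda_array_interface__ for s in sdrs],
///                       [c.__cuda_array_interface__ for c in colss],
///                       [a.__cuda_array_interface__ for a in anoms],
///                       learn=True)
/// regions[0].device_sync()  # one sync for the whole batch
/// ```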
#[pyfunction]
#[pyo3(signature = (regions, sdr_cais, cols_cais, anom_cais, learn=true))]
fn step_batch_fused_cuda(
py: Python<'_>,
regions: Vec<Py<HTMRegionGpu>>,
sdr_cais: Vec<Bound<'_, PyDict>>,
cols_cais: Vec<Bound<'_, PyDict>>,
anom_cais: Vec<Bound<'_, PyDict>>,
learn: bool,
) -> PyResult<()> {
let b = regions.len();
if b == 0 {
return Err(pyo3::exceptions::PyValueError::new_err("regions is empty"));
}
if sdr_cais.len() != b || cols_cais.len() != b || anom_cais.len() != b {
return Err(pyo3::exceptions::PyValueError::new_err(
"sdr_cais / cols_cais / anom_cais length must match regions",
));
}
// Parse all CAI dicts; collect device pointers. Validate shapes/dtypes.
let mut sdr_ptrs = Vec::with_capacity(b);
let mut cols_ptrs = Vec::with_capacity(b);
let mut anom_ptrs = Vec::with_capacity(b);
let (input_bits, n_columns, t) = {
let r0 = regions[0].bind(py).borrow();
(r0.input_bits, r0.n_columns, {
let (_p, sh, _ty) = cai_parse(&sdr_cais[0])?;
if sh.len() != 2 {
return Err(pyo3::exceptions::PyValueError::new_err(
format!("sdr_cai must be 2-D (T, input_bits), got {sh:?}"),
));
}
sh[0]
})
};
for i in 0..b {
let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_cais[i])?;
let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_cais[i])?;
let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_cais[i])?;
if sdr_type != "|u1" || cols_type != "|u1" {
return Err(pyo3::exceptions::PyValueError::new_err(
"sdr/cols typestr must be '|u1' (uint8)",
));
}
if anom_type != "<f4" && anom_type != "=f4" {
return Err(pyo3::exceptions::PyValueError::new_err(
"anom typestr must be '<f4' (float32)",
));
}
if sdr_shape != [t, input_bits] {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"sdr[{i}] shape {sdr_shape:?} != ({t}, {input_bits})"
)));
}
if cols_shape != [t, n_columns] {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"cols[{i}] shape {cols_shape:?} != ({t}, {n_columns})"
)));
}
if anom_shape != [t] {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"anom[{i}] shape {anom_shape:?} != ({t},)"
)));
}
sdr_ptrs.push(sdr_ptr);
cols_ptrs.push(cols_ptr);
anom_ptrs.push(anom_ptr);
}
// Exclusively borrow each region. PyRefMut guarantees uniqueness.
let mut region_refs: Vec<pyo3::PyRefMut<HTMRegionGpu>> =
regions.iter().map(|p| p.bind(py).borrow_mut()).collect();
    // Collect raw mutable pointers: each PyRefMut exclusively borrows its
    // region for the lifetime of this call, so the pointers stay valid and
    // unique. launch_fused_batched_raw dereferences only one region at a
    // time and never materializes an aliasing slice of `&mut` references.
let raw_ptrs: Vec<*mut HTMRegionGpu> = region_refs
.iter_mut()
.map(|r| &mut **r as *mut HTMRegionGpu)
.collect();
// No allow_threads: raw pointers aren't Send. The launch is GPU-queued
// and sync'd downstream; holding the GIL for the duration is cheap.
fused::launch_fused_batched_raw(
&raw_ptrs, &sdr_ptrs, &cols_ptrs, &anom_ptrs,
t, input_bits, learn,
)
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("launch_fused_batched: {e:?}")))?;
Ok(())
}
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<HTMRegionGpu>()?;
m.add_function(pyo3::wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
Ok(())
}