//! GPU backend for HTM.
//!
//! Full-GPU pipeline (SP + TM). Per-step state lives entirely on device; the
//! batch API (`step_many_gpu`) uploads T steps of input once, runs T iterations
//! of the full HTM pipeline on GPU, and copies (T, n_cols) u8 + (T,) f32 back
//! to the host in one shot.
//!
//! TM parity with the CPU reference is approximate:
//! - Segment growth: winner = cell 0 of the bursting column (the CPU picks the
//!   least-used cell with an RNG tiebreak). This is a pragmatic simplification
//!   for GPU atomicity; learning dynamics are preserved.
//! - Permanences are stored as i16 (scaled 0..32767). Rounding differs from
//!   f32 by <= 1 ULP of the scale factor (≈ 3e-5), well inside any meaningful
//!   HTM learning quantum.
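//!
//! A minimal usage sketch from Python. The module and class names here are
//! assumptions (whatever `register` is mounted under in the built extension);
//! the argument order follows `HTMRegionGpu::new` below:
//!
//! ```python
//! import numpy as np
//! import htm_gpu  # assumed extension-module name
//!
//! region = htm_gpu.HTMRegionGpu(16384, 2048, 16, 42)  # input_bits, n_columns, cells_per_column, seed
//! inputs = np.zeros((512, 16384), dtype=bool)         # (T, input_bits)
//! cols, anomaly = region.step_many_gpu(inputs, True)  # learn=True
//! # cols: (512, 2048) float32 mask, anomaly: (512,) float32
//! ```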
pub mod sp_gpu;
pub mod tm_gpu;
pub mod fused;

mod tests;

use std::mem::ManuallyDrop;

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};

use crate::region::HTMRegionCore;
use crate::sp::SpatialPoolerConfig;
use sp_gpu::SpatialPoolerGpu;
use tm_gpu::TemporalMemoryGpu;
use fused::FusedState;
/// Extract (device_ptr, shape, typestr) from a `__cuda_array_interface__` dict.
/// Returns Err if the dict is malformed. Used by `step_many_cuda` to wrap
/// torch-owned CUDA allocations zero-copy.
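///
/// For reference, a contiguous CUDA torch tensor exports roughly this dict
/// (values illustrative):
///
/// ```python
/// import torch
///
/// t = torch.zeros(64, 2048, dtype=torch.uint8, device="cuda")
/// cai = t.__cuda_array_interface__
/// # {'typestr': '|u1', 'shape': (64, 2048), 'data': (140234567680, False),
/// #  'version': 2, 'strides': None, ...}
/// ```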
fn cai_parse(cai: &Bound<'_, PyDict>) -> PyResult<(u64, Vec<usize>, String)> {
    // `data` is a (ptr: int, readonly: bool) tuple.
    let data_obj = cai.get_item("data")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'data'"))?;
    let data_tup: Bound<'_, PyTuple> = data_obj.downcast_into()
        .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'data' must be a tuple"))?;
    let ptr: u64 = data_tup.get_item(0)?.extract()?;

    // `shape` is a tuple of ints.
    let shape_obj = cai.get_item("shape")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'shape'"))?;
    let shape_tup: Bound<'_, PyTuple> = shape_obj.downcast_into()
        .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'shape' must be a tuple"))?;
    let shape: Vec<usize> = (0..shape_tup.len())
        .map(|i| shape_tup.get_item(i).and_then(|v| v.extract::<usize>()))
        .collect::<PyResult<Vec<_>>>()?;

    // `typestr` (e.g. "|u1", "<f4").
    let typestr_obj = cai.get_item("typestr")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'typestr'"))?;
    let typestr: String = typestr_obj.extract()?;

    // Reject non-contiguous tensors — we don't handle strides.
    if let Some(strides) = cai.get_item("strides")? {
        if !strides.is_none() {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "CAI 'strides' must be None (tensor must be contiguous)",
            ));
        }
    }

    Ok((ptr, shape, typestr))
}
/// Python-exposed GPU HTM region. Drop-in replacement for `HTMRegion`.
#[pyclass]
pub struct HTMRegionGpu {
    pub(super) sp_gpu: SpatialPoolerGpu,
    pub(super) tm_gpu: TemporalMemoryGpu,
    pub(super) fused_state: FusedState,
    pub(super) n_columns: usize,
    pub(super) input_bits: usize,
    pub(super) cells_per_column: usize,
}
#[pymethods]
impl HTMRegionGpu {
    #[new]
    fn new(
        input_bits: usize,
        n_columns: usize,
        cells_per_column: usize,
        seed: u64,
    ) -> PyResult<Self> {
        if input_bits == 0 || n_columns == 0 || cells_per_column == 0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "input_bits, n_columns, cells_per_column must all be > 0",
            ));
        }
        // CPU reference for deterministic SP init.
        let cpu_ref = HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed);
        let sp_cfg: &SpatialPoolerConfig = &cpu_ref.sp.cfg;
        let sp_gpu = SpatialPoolerGpu::from_cpu(&cpu_ref.sp).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!(
                "GPU SP init failed: {e:?}. Config: input_bits={}, n_columns={}",
                sp_cfg.input_bits, sp_cfg.n_columns,
            ))
        })?;
        let dev = sp_gpu.dev_ref().clone();
        let tm_gpu = TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_column).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("GPU TM init failed: {e:?}"))
        })?;
        let initial_threshold = sp_gpu.initial_threshold_estimate();
        let fused_state = FusedState::new(dev, n_columns, cells_per_column, initial_threshold)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!(
                "GPU fused state init failed: {e:?}",
            )))?;
        Ok(Self {
            sp_gpu,
            tm_gpu,
            fused_state,
            n_columns,
            input_bits,
            cells_per_column,
        })
    }

    fn input_bits(&self) -> usize { self.input_bits }
    fn n_columns(&self) -> usize { self.n_columns }
    fn cells_per_column(&self) -> usize { self.cells_per_column }
    /// Process T timesteps in one call on GPU. Per-step state (SP + TM) stays
    /// on device; only the final (T, n_cols) mask and (T,) anomaly are copied
    /// to the host at the end.
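    ///
    /// Call sketch (class exposure as assumed in the module docs):
    ///
    /// ```python
    /// inputs = np.random.rand(256, region.input_bits()) < 0.02  # (T, input_bits) bool
    /// cols, anomaly = region.step_many_gpu(inputs, True)
    /// assert cols.shape == (256, region.n_columns())
    /// assert anomaly.shape == (256,)
    /// ```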
    fn step_many_gpu<'py>(
        &mut self,
        py: Python<'py>,
        inputs: PyReadonlyArray2<'py, bool>,
        learn: bool,
    ) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
        let shape = inputs.shape();
        if shape.len() != 2 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "inputs must be 2-D (T, input_bits)",
            ));
        }
        let t = shape[0];
        let bits = shape[1];
        if bits != self.input_bits {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "inputs last dim {bits} != expected input_bits {}",
                self.input_bits,
            )));
        }
        let slice = inputs.as_slice()?;
        let n_cols = self.n_columns;
        let input_vec: Vec<bool> = slice.to_vec();
        let result = py.allow_threads(|| -> Result<(Vec<u8>, Vec<f32>), String> {
            // 1. Upload T*input_bits bytes (32 MB at T=2048, bits=16384).
            let sdr_u8_all: Vec<u8> = input_vec.iter().map(|&b| b as u8).collect();
            let inputs_dev = self
                .sp_gpu
                .dev_ref()
                .htod_sync_copy(&sdr_u8_all)
                .map_err(|e| format!("H2D inputs: {e:?}"))?;

            // 2. Allocate output buffers on device.
            let mut cols_dev = self.sp_gpu.dev_ref()
                .alloc_zeros::<u8>(t * n_cols)
                .map_err(|e| format!("alloc cols: {e:?}"))?;
            let mut anom_dev = self.sp_gpu.dev_ref()
                .alloc_zeros::<f32>(t)
                .map_err(|e| format!("alloc anom: {e:?}"))?;

            // 3. Run T steps of SP + TM on GPU with NO per-step host sync.
            self.sp_gpu.step_batch_with_tm(
                &inputs_dev,
                t,
                self.input_bits,
                learn,
                &mut cols_dev,
                &mut anom_dev,
                &mut self.tm_gpu,
            ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;

            // 4. ONE D2H for the whole run (T * n_cols bytes + T floats).
            let cols_host: Vec<u8> = self.sp_gpu.dev_ref()
                .dtoh_sync_copy(&cols_dev)
                .map_err(|e| format!("D2H cols: {e:?}"))?;
            let anom_host: Vec<f32> = self.sp_gpu.dev_ref()
                .dtoh_sync_copy(&anom_dev)
                .map_err(|e| format!("D2H anom: {e:?}"))?;
            Ok((cols_host, anom_host))
        });
        let (cols_u8, anom) = result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
        let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
        let cols_arr = numpy::PyArray1::from_vec_bound(py, cols_f32)
            .reshape([t, n_cols])
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
        let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
        Ok((cols_arr, anom_arr))
    }
    /// Zero-copy CUDA path: accept torch tensors via `__cuda_array_interface__`,
    /// write outputs directly into caller-allocated torch tensors. Skips the
    /// host round-trip that `step_many_gpu` pays on every call (sdr.cpu() +
    /// two D2H copies at the end). This is the hot path for `train.py`.
    ///
    /// Contract:
    ///   sdr_cai.shape  == (T, input_bits), dtype u8 (0/1 mask)
    ///   cols_cai.shape == (T, n_columns),  dtype u8 (written)
    ///   anom_cai.shape == (T,),            dtype f32 (written)
    /// All three tensors must live on the SAME CUDA device as this region.
    ///
    /// The torch tensors still own their memory — this method only wraps
    /// them as borrowed CudaSlice views (via ManuallyDrop) so cudarc's Drop
    /// impl never frees memory owned by pytorch's allocator.
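    ///
    /// Call sketch with torch-owned buffers (module exposure as assumed in the
    /// module docs; sizes illustrative):
    ///
    /// ```python
    /// T = 512
    /// sdr  = (torch.rand(T, 16384, device="cuda") < 0.02).to(torch.uint8)
    /// cols = torch.empty(T, 2048, dtype=torch.uint8, device="cuda")
    /// anom = torch.empty(T, dtype=torch.float32, device="cuda")
    /// region.step_many_cuda(sdr.__cuda_array_interface__,
    ///                       cols.__cuda_array_interface__,
    ///                       anom.__cuda_array_interface__, True)
    /// region.device_sync()  # barrier before cols/anom are read on another stream
    /// ```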
    fn step_many_cuda(
        &mut self,
        py: Python<'_>,
        sdr_cai: &Bound<'_, PyDict>,
        cols_cai: &Bound<'_, PyDict>,
        anom_cai: &Bound<'_, PyDict>,
        learn: bool,
    ) -> PyResult<()> {
        let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
        let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
        let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;

        // typestr sanity. numpy u1 is what torch.uint8 exports.
        if sdr_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
            )));
        }
        if cols_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
            )));
        }
        if anom_type != "<f4" && anom_type != "=f4" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai typestr must be '<f4' (float32), got {anom_type}",
            )));
        }

        // Shape validation.
        if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai shape {sdr_shape:?} != (T, {})",
                self.input_bits,
            )));
        }
        let t = sdr_shape[0];
        if cols_shape != [t, self.n_columns] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai shape {cols_shape:?} != ({t}, {})",
                self.n_columns,
            )));
        }
        if anom_shape != [t] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai shape {anom_shape:?} != ({t},)",
            )));
        }

        let dev = self.sp_gpu.dev_ref().clone();
        let n_cols = self.n_columns;
        let input_bits = self.input_bits;
        let result = py.allow_threads(|| -> Result<(), String> {
            // SAFETY:
            // - ptrs came from torch CUDA tensors validated non-null by the
            //   __cuda_array_interface__ contract.
            // - lens computed from validated shapes.
            // - We wrap the returned CudaSlice in ManuallyDrop so cudarc's
            //   Drop (which calls cuMemFree) never runs against torch memory.
            //   The underlying allocation is owned+freed by torch.
            // - The slices are used only for the duration of this call;
            //   torch guarantees the backing tensors are live across it
            //   (Python holds refs on the wrapping tensors).
            let inputs_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
            });
            let mut cols_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
            });
            let mut anom_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<f32>(anom_ptr, t)
            });
            self.sp_gpu.step_batch_with_tm(
                &inputs_dev,
                t,
                input_bits,
                learn,
                &mut cols_dev,
                &mut anom_dev,
                &mut self.tm_gpu,
            ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
            // No dev.synchronize() here: caller must explicitly sync via the
            // `device_sync()` method (or PyTorch auto-syncs when the output
            // tensor is next consumed). Removing the per-launch barrier lets
            // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
            Ok(())
        });
        result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
        Ok(())
    }
    /// Clear TM state on the GPU.
    fn reset(&mut self) -> PyResult<()> {
        self.tm_gpu.reset().map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("GPU TM reset: {e:?}"))
        })?;
        self.fused_state.reset().map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("GPU fused reset: {e:?}"))
        })
    }
    /// FUSED MEGAKERNEL PATH: single CUDA launch for the entire T-step
    /// forward (SP + TM all in one). Accepts torch CUDA tensors via
    /// `__cuda_array_interface__` (zero-copy). Writes active-column mask +
    /// anomaly directly into caller-allocated torch tensors.
    ///
    /// Semantics diverge from `step_many_cuda` in one important way: column
    /// activation uses per-column threshold inhibition instead of global
    /// top-K. The threshold is EMA-adapted per column toward the sparsity
    /// target. See `docs/GPU_HTM.md` §Fused Kernel.
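    ///
    /// Roughly, the per-column inhibition behaves like the sketch below. The
    /// names `eta` and `target_sparsity` are hypothetical; the real update
    /// lives in the fused kernel and is specified in `docs/GPU_HTM.md`:
    ///
    /// ```python
    /// # Per-column compare instead of a global top-K...
    /// active = overlap >= threshold
    /// # ...with an EMA-style nudge of each column's threshold toward the
    /// # sparsity target (columns firing too often raise their threshold).
    /// threshold += eta * (active.astype("float32") - target_sparsity)
    /// ```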
    fn step_many_fused_cuda(
        &mut self,
        py: Python<'_>,
        sdr_cai: &Bound<'_, PyDict>,
        cols_cai: &Bound<'_, PyDict>,
        anom_cai: &Bound<'_, PyDict>,
        learn: bool,
    ) -> PyResult<()> {
        let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
        let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
        let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
        if sdr_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
            )));
        }
        if cols_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
            )));
        }
        if anom_type != "<f4" && anom_type != "=f4" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai typestr must be '<f4' (float32), got {anom_type}",
            )));
        }
        if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai shape {sdr_shape:?} != (T, {})",
                self.input_bits,
            )));
        }
        let t = sdr_shape[0];
        if cols_shape != [t, self.n_columns] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai shape {cols_shape:?} != ({t}, {})",
                self.n_columns,
            )));
        }
        if anom_shape != [t] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai shape {anom_shape:?} != ({t},)",
            )));
        }
        let dev = self.sp_gpu.dev_ref().clone();
        let n_cols = self.n_columns;
        let input_bits = self.input_bits;
        let result = py.allow_threads(|| -> Result<(), String> {
            let inputs_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
            });
            let mut cols_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
            });
            let mut anom_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<f32>(anom_ptr, t)
            });
            fused::launch_fused(
                &mut self.sp_gpu,
                &mut self.tm_gpu,
                &mut self.fused_state,
                &inputs_dev,
                &mut cols_dev,
                &mut anom_dev,
                t,
                input_bits,
                learn,
            ).map_err(|e| format!("launch_fused: {e:?}"))?;
            // No dev.synchronize() here: caller must explicitly sync via the
            // `device_sync()` method (or PyTorch auto-syncs when the output
            // tensor is next consumed). Removing the per-launch barrier lets
            // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
            Ok(())
        });
        result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
        Ok(())
    }
    /// Explicit device synchronization — the caller must invoke this after
    /// all batched `step_many_*_cuda` calls complete, before reading the
    /// output tensors from a different CUDA stream. Equivalent to the old
    /// per-call `dev.synchronize()` that was removed for overlap.
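    ///
    /// Typical ordering (sketch):
    ///
    /// ```python
    /// region.step_many_fused_cuda(sdr_cai, cols_cai, anom_cai, True)
    /// # ... queue unrelated GPU work that does not read cols/anom ...
    /// region.device_sync()       # make the HTM writes visible
    /// cols_host = cols.cpu().numpy()
    /// ```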
    fn device_sync(&self) -> PyResult<()> {
        let dev = self.sp_gpu.dev_ref();
        dev.synchronize()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("sync: {e:?}")))?;
        Ok(())
    }
}
/// Batch B regions into ONE cooperative kernel launch. Breaks through the
/// CUDA cooperative-kernel device-level serialization: a single cooperative
/// launch with grid.y = B processes all regions concurrently — ~B× speedup
/// over B sequential launches.
///
/// All regions must have the same config (input_bits, n_columns,
/// cells_per_column). Each region keeps its independent GPU state.
/// Does NOT sync; caller must invoke `device_sync()` on any region
/// afterwards (or rely on a downstream torch op to auto-sync).
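///
/// Call sketch (module exposure as assumed in the module docs; `regions`,
/// `sdrs`, `cols`, `anoms` are same-length lists built by the caller):
///
/// ```python
/// htm_gpu.step_batch_fused_cuda(
///     regions,
///     [s.__cuda_array_interface__ for s in sdrs],
///     [c.__cuda_array_interface__ for c in cols],
///     [a.__cuda_array_interface__ for a in anoms],
///     True,
/// )
/// regions[0].device_sync()
/// ```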
#[pyfunction]
fn step_batch_fused_cuda(
    py: Python<'_>,
    regions: Vec<Py<HTMRegionGpu>>,
    sdr_cais: Vec<Bound<'_, PyDict>>,
    cols_cais: Vec<Bound<'_, PyDict>>,
    anom_cais: Vec<Bound<'_, PyDict>>,
    learn: bool,
) -> PyResult<()> {
    let b = regions.len();
    if b == 0 {
        return Err(pyo3::exceptions::PyValueError::new_err("regions is empty"));
    }
    if sdr_cais.len() != b || cols_cais.len() != b || anom_cais.len() != b {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "sdr_cais / cols_cais / anom_cais length must match regions",
        ));
    }

    // Parse all CAI dicts; collect device pointers. Validate shapes/dtypes.
    let mut sdr_ptrs = Vec::with_capacity(b);
    let mut cols_ptrs = Vec::with_capacity(b);
    let mut anom_ptrs = Vec::with_capacity(b);
    let (input_bits, n_columns, t) = {
        let r0 = regions[0].bind(py).borrow();
        (r0.input_bits, r0.n_columns, {
            let (_p, sh, _ty) = cai_parse(&sdr_cais[0])?;
            if sh.len() != 2 {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    format!("sdr_cai must be 2-D (T, input_bits), got {sh:?}"),
                ));
            }
            sh[0]
        })
    };
    for i in 0..b {
        let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_cais[i])?;
        let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_cais[i])?;
        let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_cais[i])?;
        if sdr_type != "|u1" || cols_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "sdr/cols typestr must be '|u1' (uint8)",
            ));
        }
        if anom_type != "<f4" && anom_type != "=f4" {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "anom typestr must be '<f4' (float32)",
            ));
        }
        if sdr_shape != [t, input_bits] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr[{i}] shape {sdr_shape:?} != ({t}, {input_bits})"
            )));
        }
        if cols_shape != [t, n_columns] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols[{i}] shape {cols_shape:?} != ({t}, {n_columns})"
            )));
        }
        if anom_shape != [t] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom[{i}] shape {anom_shape:?} != ({t},)"
            )));
        }
        sdr_ptrs.push(sdr_ptr);
        cols_ptrs.push(cols_ptr);
        anom_ptrs.push(anom_ptr);
    }
    // Exclusively borrow each region. PyRefMut guarantees uniqueness.
    let mut region_refs: Vec<pyo3::PyRefMut<HTMRegionGpu>> =
        regions.iter().map(|p| p.bind(py).borrow_mut()).collect();
    // Collect raw mutable pointers — each PyRefMut exclusively borrows its
    // region for the lifetime of this call, so the pointers stay valid and
    // unique. launch_fused_batched_raw dereferences one region at a time and
    // never constructs an aliasing slice.
    let raw_ptrs: Vec<*mut HTMRegionGpu> = region_refs
        .iter_mut()
        .map(|r| &mut **r as *mut HTMRegionGpu)
        .collect();
    // No allow_threads: raw pointers aren't Send. The launch is GPU-queued
    // and synced downstream; holding the GIL for the duration is cheap.
    fused::launch_fused_batched_raw(
        &raw_ptrs, &sdr_ptrs, &cols_ptrs, &anom_ptrs,
        t, input_bits, learn,
    )
    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("launch_fused_batched: {e:?}")))?;
    Ok(())
}
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<HTMRegionGpu>()?;
    m.add_function(pyo3::wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
    Ok(())
}