Spaces:
Runtime error
Runtime error
| //! pyo3 bindings for HTMRegion (Numenta BAMI-spec HTM). | |
| //! | |
| //! Exposed class: | |
| //! HTMRegion(input_bits, n_columns, cells_per_column, seed) -> HTMRegion | |
| //! .step(input_sdr: np.ndarray[bool; input_bits], learn: bool = True) | |
| //! -> (active_columns: np.ndarray[bool; n_columns], | |
| //! active_cells: np.ndarray[bool; n_columns*cells_per_column], | |
| //! predicted_cells:np.ndarray[bool; n_columns*cells_per_column], | |
| //! anomaly: float) | |
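//!   .reset()
//!   .step_many(inputs: np.ndarray[bool; (T, input_bits)], learn: bool)
//!        -> (cols: np.ndarray[float32; (T, n_columns)],
//!            anom: np.ndarray[float32; T])
//!   .n_columns        -> int
//!   .cells_per_column -> int
//!   .input_bits       -> int
//!
//! The GIL is released during the heavy compute via `py.allow_threads(...)`, so other
//! Python threads can run while a step is in progress. `allow_threads` in turn requires
//! the captured region state to be `Send`.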
| // pyo3 0.22 `#[pymethods]` expansion inserts an implicit `.into()` on the | |
| // returned `Result` to normalise the error type, which clippy reports as | |
| // `useless_conversion` when our methods already return `PyErr`. The emitted | |
| // code sits outside the user-written impl, so item-level allows don't reach | |
| // it; the module-wide allow is the documented workaround. | |
| mod region; | |
| mod sp; | |
| mod tm; | |
| mod gpu; | |
| use numpy::{ | |
| IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2, | |
| PyUntypedArrayMethods, | |
| }; | |
| use pyo3::prelude::*; | |
| use crate::region::HTMRegionCore; | |
| /// Result of one HTM step: (active_columns, active_cells, predicted_cells, anomaly). | |
| type StepOutput<'py> = ( | |
| Bound<'py, PyArray1<bool>>, | |
| Bound<'py, PyArray1<bool>>, | |
| Bound<'py, PyArray1<bool>>, | |
| f32, | |
| ); | |
| pub struct HTMRegion { | |
| core: HTMRegionCore, | |
| } | |
| impl HTMRegion { | |
| /// Create a new HTM region. | |
| /// | |
| /// Args: | |
| /// input_bits: length of binary input SDR | |
| /// n_columns: number of mini-columns in the SP (e.g. 2048) | |
| /// cells_per_column: cells per column in the TM (e.g. 32) | |
| /// seed: RNG seed for reproducibility | |
| fn new( | |
| input_bits: usize, | |
| n_columns: usize, | |
| cells_per_column: usize, | |
| seed: u64, | |
| ) -> PyResult<Self> { | |
| if input_bits == 0 { | |
| return Err(pyo3::exceptions::PyValueError::new_err( | |
| "input_bits must be > 0", | |
| )); | |
| } | |
| if n_columns == 0 { | |
| return Err(pyo3::exceptions::PyValueError::new_err( | |
| "n_columns must be > 0", | |
| )); | |
| } | |
| if cells_per_column == 0 { | |
| return Err(pyo3::exceptions::PyValueError::new_err( | |
| "cells_per_column must be > 0", | |
| )); | |
| } | |
| Ok(Self { | |
| core: HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed), | |
| }) | |
| } | |
| fn input_bits(&self) -> usize { self.core.sp.cfg.input_bits } | |
| fn n_columns(&self) -> usize { self.core.sp.cfg.n_columns } | |
| fn cells_per_column(&self) -> usize { self.core.tm.cfg.cells_per_column } | |
| /// Process one timestep. | |
| /// | |
| /// Args: | |
| /// input_sdr: 1-D numpy boolean array of length `input_bits`. | |
| /// learn: if True, update SP permanences and TM synapses. | |
| /// | |
| /// Returns: | |
| /// (active_columns, active_cells, predicted_cells, anomaly) | |
| fn step<'py>( | |
| &mut self, | |
| py: Python<'py>, | |
| input_sdr: PyReadonlyArray1<'py, bool>, | |
| learn: bool, | |
| ) -> PyResult<StepOutput<'py>> { | |
| let expected = self.core.sp.cfg.input_bits; | |
| let slice = input_sdr.as_slice()?; | |
| let got = slice.len(); | |
| if got != expected { | |
| return Err(pyo3::exceptions::PyValueError::new_err(format!( | |
| "input_sdr length {got} != expected input_bits {expected}", | |
| ))); | |
| } | |
| // Copy input to an owned Vec so we can drop the GIL. | |
| let input_vec: Vec<bool> = slice.to_vec(); | |
| let (active_cols, active_cells, predicted_cells, anomaly) = | |
| py.allow_threads(|| self.core.step(&input_vec, learn)); | |
| let a: Bound<'py, PyArray1<bool>> = active_cols.into_pyarray_bound(py); | |
| let c: Bound<'py, PyArray1<bool>> = active_cells.into_pyarray_bound(py); | |
| let p: Bound<'py, PyArray1<bool>> = predicted_cells.into_pyarray_bound(py); | |
| Ok((a, c, p, anomaly)) | |
| } | |
| /// Clear TM predictive state. Does NOT unlearn synapses. | |
| fn reset(&mut self) { self.core.reset(); } | |
| /// Process T timesteps from a `(T, input_bits)` bool ndarray. | |
| /// | |
| /// Returns: | |
| /// cols: (T, n_columns) float32 0/1 active-column mask | |
| /// anom: (T,) float32 anomaly scores | |
| /// | |
| /// Single GIL release for the whole pass, avoiding T × Python-call overhead. | |
| fn step_many<'py>( | |
| &mut self, | |
| py: Python<'py>, | |
| inputs: PyReadonlyArray2<'py, bool>, | |
| learn: bool, | |
| ) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> { | |
| let shape = inputs.shape(); | |
| if shape.len() != 2 { | |
| return Err(pyo3::exceptions::PyValueError::new_err( | |
| "inputs must be 2-D (T, input_bits)", | |
| )); | |
| } | |
| let t = shape[0]; | |
| let bits = shape[1]; | |
| let expected = self.core.sp.cfg.input_bits; | |
| if bits != expected { | |
| return Err(pyo3::exceptions::PyValueError::new_err(format!( | |
| "inputs last dim {bits} != expected input_bits {expected}", | |
| ))); | |
| } | |
| let slice = inputs.as_slice()?; | |
| let n_cols = self.core.sp.cfg.n_columns; | |
| // Own the input buffer so we can drop the GIL. | |
| let input_vec: Vec<bool> = slice.to_vec(); | |
| let (cols_u8, anom) = | |
| py.allow_threads(|| self.core.step_many(&input_vec, bits, t, learn)); | |
| // Convert u8 mask to f32 for direct numpy consumption. | |
| let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect(); | |
| // Build (T, n_cols) and (T,) arrays. | |
| let cols_arr = | |
| numpy::PyArray1::from_vec_bound(py, cols_f32) | |
| .reshape([t, n_cols]) | |
| .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?; | |
| let anom_arr = numpy::PyArray1::from_vec_bound(py, anom); | |
| Ok((cols_arr, anom_arr)) | |
| } | |
| } | |
| /// Python module entry point. | |
| fn htm_rust(m: &Bound<'_, PyModule>) -> PyResult<()> { | |
| m.add_class::<HTMRegion>()?; | |
| { | |
| gpu::register(m)?; | |
| } | |
| m.add("__version__", env!("CARGO_PKG_VERSION"))?; | |
| Ok(()) | |
| } | |