//! GPU Temporal Memory. //! //! Flat device storage. Pre-allocated segment slab: //! n_cells = n_columns * cells_per_column //! n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL //! n_synapses_max = n_segments_max * MAX_SYN_PER_SEGMENT //! //! Defaults (CPU parity targets relaxed on GPU to keep memory tractable): //! MAX_SEGMENTS_PER_CELL = 16 //! MAX_SYN_PER_SEGMENT = 32 //! //! At n_cells = 65536: //! n_segments_max = 1_048_576 (~1M) //! n_synapses_max = 33_554_432 (~33M) //! Storage: //! syn_presyn : u32 × 33M = 128 MB //! syn_perm : i16 × 33M = 64 MB //! seg_cell : u32 × 1M = 4 MB //! seg_syn_n : u32 × 1M = 4 MB //! misc bitsets etc ~ <1 MB //! ------------------------------- //! Total per region ~200 MB //! //! Permanences are stored as i16 scaled by 32767 (→ [0, 32767] represents //! [0.0, 1.0]). inc/dec are provided pre-scaled. use std::sync::Arc; use cudarc::driver::{CudaDevice, CudaSlice, DriverError, DeviceRepr, LaunchAsync, LaunchConfig}; use cudarc::nvrtc::Ptx; /// Packed config struct passed by value to TM kernels to stay under /// cudarc's 12-tuple launch limit. Layout must match the C-side /// `TmConfig` struct declared in each kernel. #[repr(C)] #[derive(Clone, Copy)] pub struct TmConfig { pub activation_threshold: u32, pub learning_threshold: u32, pub cells_per_column: u32, pub synapses_per_segment: u32, pub n_segments: u32, pub n_cells: u32, pub max_segments_per_cell: u32, pub max_new_synapses: u32, pub conn_thr_i16: i32, // i16 widened to i32 for alignment pub perm_inc_i16: i32, pub perm_dec_i16: i32, pub predicted_seg_dec_i16: i32, pub initial_perm_i16: i32, pub iter_seed: u32, pub n_cols: u32, pub bits_words: u32, } unsafe impl DeviceRepr for TmConfig {} // Embedded PTX. 
// Kernels are pre-compiled to PTX at build time; HTM_GPU_PTX_DIR is
// injected by the build script and resolved at compile time.
const PTX_TM_PREDICT: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_predict.ptx"));
const PTX_TM_ACTIVATE: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_activate.ptx"));
const PTX_TM_LEARN: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_learn.ptx"));
const PTX_TM_PUNISH: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_punish.ptx"));
const PTX_TM_GROW: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_grow.ptx"));
const PTX_TM_ANOMALY: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_anomaly.ptx"));
const PTX_TM_RESET: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_reset.ptx"));

/// Capacity trade-offs for 6 GB VRAM (RTX 3060) shared with the model:
///   n_cells        = 2048 × 32 = 65_536
///   n_segments_max = n_cells × MAX_SEGMENTS_PER_CELL
///   n_synapses_max = n_segments_max × MAX_SYN_PER_SEGMENT
///
/// At 4/20 these are 262_144 segments and ~5.2M synapses (~50 MB per region).
/// The training loop runs with `reset_each_forward=True`, so segment counts
/// per window stay well below 32K (typical: ~n_cols new segs per step until
/// the first matching segment is reused; in a 2048-step window that plateaus
/// around ~5K total live segments). The 262K ceiling is generous headroom.
pub const MAX_SEGMENTS_PER_CELL: usize = 4; pub const MAX_SYN_PER_SEGMENT: usize = 20; const PERM_SCALE: f32 = 32767.0; fn perm_f32_to_i16(x: f32) -> i16 { let clamped = x.clamp(0.0, 1.0); (clamped * PERM_SCALE).round() as i16 } pub struct TemporalMemoryGpu { dev: Arc, // Config mirror pub n_columns: usize, pub cells_per_column: usize, pub activation_threshold: u32, pub learning_threshold: u32, pub initial_perm_i16: i16, pub conn_thr_i16: i16, pub perm_inc_i16: i16, pub perm_dec_i16: i16, pub predicted_seg_dec_i16: i16, pub max_new_synapse_count: u32, // Sizes pub n_cells: usize, pub n_segments_max: usize, pub bits_words: usize, // n_cells / 32 // Persistent device buffers seg_cell_id: CudaSlice, seg_syn_count: CudaSlice, syn_presyn: CudaSlice, syn_perm: CudaSlice, cell_seg_count: CudaSlice, cell_active_bits: CudaSlice, cell_winner_bits: CudaSlice, cell_predictive_bits: CudaSlice, prev_active_bits: CudaSlice, prev_winner_bits: CudaSlice, col_predicted: CudaSlice, seg_num_active_conn: CudaSlice, seg_num_active_pot: CudaSlice, unpredicted_count: CudaSlice, burst_cols_flat: CudaSlice, burst_cols_count: CudaSlice, col_best_match: CudaSlice, iter_counter: u32, } impl TemporalMemoryGpu { pub fn new( dev: Arc, n_columns: usize, cells_per_column: usize, ) -> Result { let n_cells = n_columns * cells_per_column; assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets"); let n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL; let bits_words = n_cells / 32; // Numenta defaults. let activation_threshold = 15u32; let learning_threshold = 13u32; let initial_perm_i16 = perm_f32_to_i16(0.21); let conn_thr_i16 = perm_f32_to_i16(0.50); let perm_inc_i16 = perm_f32_to_i16(0.10); let perm_dec_i16 = perm_f32_to_i16(0.10); let predicted_seg_dec_i16 = perm_f32_to_i16(0.10); let max_new_synapse_count = 20u32; // Allocate buffers. 
let seg_cell_id_host: Vec = vec![u32::MAX; n_segments_max]; let seg_cell_id = dev.htod_sync_copy(&seg_cell_id_host)?; let seg_syn_count = dev.alloc_zeros::(n_segments_max)?; let syn_presyn = dev.alloc_zeros::(n_segments_max * MAX_SYN_PER_SEGMENT)?; let syn_perm = dev.alloc_zeros::(n_segments_max * MAX_SYN_PER_SEGMENT)?; let cell_seg_count = dev.alloc_zeros::(n_cells)?; let cell_active_bits = dev.alloc_zeros::(bits_words)?; let cell_winner_bits = dev.alloc_zeros::(bits_words)?; let cell_predictive_bits = dev.alloc_zeros::(bits_words)?; let prev_active_bits = dev.alloc_zeros::(bits_words)?; let prev_winner_bits = dev.alloc_zeros::(bits_words)?; let col_predicted = dev.alloc_zeros::(n_columns)?; let seg_num_active_conn = dev.alloc_zeros::(n_segments_max)?; let seg_num_active_pot = dev.alloc_zeros::(n_segments_max)?; let unpredicted_count = dev.alloc_zeros::(1)?; // Bursting columns for one step bounded by n_columns. let burst_cols_flat = dev.alloc_zeros::(n_columns)?; let burst_cols_count = dev.alloc_zeros::(1)?; let col_best_match = dev.alloc_zeros::(n_columns)?; // Load PTX modules. 
let modules = [ ("htm_tm_predict", PTX_TM_PREDICT, "tm_predict"), ("htm_tm_activate", PTX_TM_ACTIVATE, "tm_activate"), ("htm_tm_learn", PTX_TM_LEARN, "tm_learn_reinforce"), ("htm_tm_punish", PTX_TM_PUNISH, "tm_punish"), ("htm_tm_grow", PTX_TM_GROW, "tm_grow"), ("htm_tm_anomaly", PTX_TM_ANOMALY, "tm_anomaly"), ("htm_tm_reset", PTX_TM_RESET, "tm_reset_step"), ]; for (modname, ptx, fnname) in modules { if dev.get_func(modname, fnname).is_none() { dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?; } } Ok(Self { dev, n_columns, cells_per_column, activation_threshold, learning_threshold, initial_perm_i16, conn_thr_i16, perm_inc_i16, perm_dec_i16, predicted_seg_dec_i16, max_new_synapse_count, n_cells, n_segments_max, bits_words, seg_cell_id, seg_syn_count, syn_presyn, syn_perm, cell_seg_count, cell_active_bits, cell_winner_bits, cell_predictive_bits, prev_active_bits, prev_winner_bits, col_predicted, seg_num_active_conn, seg_num_active_pot, unpredicted_count, burst_cols_flat, burst_cols_count, col_best_match, iter_counter: 0, }) } // --- Fused-path accessors --- pub fn seg_cell_id_accessor(&self) -> &CudaSlice { &self.seg_cell_id } pub fn seg_syn_count_accessor(&self) -> &CudaSlice { &self.seg_syn_count } pub fn syn_presyn_accessor(&self) -> &CudaSlice { &self.syn_presyn } pub fn syn_perm_accessor(&self) -> &CudaSlice { &self.syn_perm } pub fn cell_seg_count_accessor(&self) -> &CudaSlice { &self.cell_seg_count } /// Hard reset — clear everything (predictive + active + segments). pub fn reset(&mut self) -> Result<(), DriverError> { // Restore "unused" sentinel in seg_cell_id. 
let unused_host: Vec = vec![u32::MAX; self.n_segments_max]; self.dev.htod_sync_copy_into(&unused_host, &mut self.seg_cell_id)?; self.dev.memset_zeros(&mut self.seg_syn_count)?; self.dev.memset_zeros(&mut self.cell_seg_count)?; self.dev.memset_zeros(&mut self.cell_active_bits)?; self.dev.memset_zeros(&mut self.cell_winner_bits)?; self.dev.memset_zeros(&mut self.cell_predictive_bits)?; self.dev.memset_zeros(&mut self.prev_active_bits)?; self.dev.memset_zeros(&mut self.prev_winner_bits)?; self.dev.memset_zeros(&mut self.col_best_match)?; self.iter_counter = 0; Ok(()) } fn build_cfg(&self) -> TmConfig { TmConfig { activation_threshold: self.activation_threshold, learning_threshold: self.learning_threshold, cells_per_column: self.cells_per_column as u32, synapses_per_segment: MAX_SYN_PER_SEGMENT as u32, n_segments: self.n_segments_max as u32, n_cells: self.n_cells as u32, max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32, max_new_synapses: self.max_new_synapse_count, conn_thr_i16: self.conn_thr_i16 as i32, perm_inc_i16: self.perm_inc_i16 as i32, perm_dec_i16: self.perm_dec_i16 as i32, predicted_seg_dec_i16: self.predicted_seg_dec_i16 as i32, initial_perm_i16: self.initial_perm_i16 as i32, iter_seed: self.iter_counter, n_cols: self.n_columns as u32, bits_words: self.bits_words as u32, } } /// Run one TM step on the GPU. Takes the SP active-column mask (u8, already /// on device) and writes `anomaly_out[t_slot]`. 
pub fn step( &mut self, sp_active_mask: &CudaSlice, anomaly_out: &mut CudaSlice, t_slot: u32, learn: bool, ) -> Result<(), DriverError> { let n_cells = self.n_cells; let n_cols = self.n_columns; let predict_fn = self.dev.get_func("htm_tm_predict", "tm_predict").unwrap(); let activate_fn = self.dev.get_func("htm_tm_activate", "tm_activate").unwrap(); let learn_fn = self.dev.get_func("htm_tm_learn", "tm_learn_reinforce").unwrap(); let punish_fn = self.dev.get_func("htm_tm_punish", "tm_punish").unwrap(); let grow_fn = self.dev.get_func("htm_tm_grow", "tm_grow").unwrap(); let anom_fn = self.dev.get_func("htm_tm_anomaly", "tm_anomaly").unwrap(); let reset_fn = self.dev.get_func("htm_tm_reset", "tm_reset_step").unwrap(); self.iter_counter = self.iter_counter.wrapping_add(1); let cfg_val = self.build_cfg(); // 0. Per-step reset. let reset_words = self.bits_words.max(n_cols); let reset_cfg = LaunchConfig { grid_dim: (((reset_words + 255) / 256) as u32, 1, 1), block_dim: (256, 1, 1), shared_mem_bytes: 0, }; unsafe { reset_fn.clone().launch( reset_cfg, ( &mut self.cell_active_bits, &mut self.cell_winner_bits, &mut self.cell_predictive_bits, &mut self.prev_active_bits, &mut self.prev_winner_bits, &mut self.col_predicted, &mut self.unpredicted_count, &mut self.burst_cols_count, &mut self.col_best_match, self.bits_words as u32, n_cols as u32, ), )?; } // 1. Predict (grid = n_cells; each block iterates its cell's segments). let predict_cfg = LaunchConfig { grid_dim: (n_cells as u32, 1, 1), block_dim: (32, 1, 1), shared_mem_bytes: 0, }; unsafe { predict_fn.clone().launch( predict_cfg, ( &self.seg_cell_id, &self.seg_syn_count, &self.syn_presyn, &self.syn_perm, &self.prev_active_bits, &mut self.cell_predictive_bits, &mut self.col_predicted, &mut self.seg_num_active_conn, &mut self.seg_num_active_pot, &mut self.col_best_match, &self.cell_seg_count, cfg_val, ), )?; } // 2. Activate. 
let activate_cfg = LaunchConfig { grid_dim: (((n_cols + 255) / 256) as u32, 1, 1), block_dim: (256, 1, 1), shared_mem_bytes: 0, }; unsafe { activate_fn.clone().launch( activate_cfg, ( sp_active_mask, &self.col_predicted, &self.cell_predictive_bits, &mut self.cell_active_bits, &mut self.cell_winner_bits, &mut self.unpredicted_count, &mut self.burst_cols_flat, &mut self.burst_cols_count, cfg_val, ), )?; } // 3. Anomaly. let anom_cfg = LaunchConfig { grid_dim: (1, 1, 1), block_dim: (256, 1, 1), shared_mem_bytes: 0, }; unsafe { anom_fn.clone().launch( anom_cfg, ( sp_active_mask, &self.unpredicted_count, anomaly_out, t_slot, n_cols as u32, ), )?; } if learn { // 4. Reinforce (grid = n_cells). let learn_cfg = LaunchConfig { grid_dim: (n_cells as u32, 1, 1), block_dim: (32, 1, 1), shared_mem_bytes: 0, }; unsafe { learn_fn.clone().launch( learn_cfg, ( &self.seg_cell_id, &self.seg_syn_count, &self.syn_presyn, &mut self.syn_perm, &self.seg_num_active_conn, &self.prev_active_bits, sp_active_mask, &self.col_predicted, &self.cell_seg_count, cfg_val, ), )?; } // 5. Punish. unsafe { punish_fn.clone().launch( learn_cfg, ( &self.seg_cell_id, &self.seg_syn_count, &self.syn_presyn, &mut self.syn_perm, &self.seg_num_active_pot, &self.prev_active_bits, sp_active_mask, &self.cell_seg_count, cfg_val, ), )?; } // 6. Grow. let grow_cfg = LaunchConfig { grid_dim: (n_cols as u32, 1, 1), block_dim: (32, 1, 1), shared_mem_bytes: 0, }; unsafe { grow_fn.clone().launch( grow_cfg, ( &mut self.seg_cell_id, &mut self.seg_syn_count, &mut self.syn_presyn, &mut self.syn_perm, &mut self.cell_seg_count, &self.burst_cols_flat, &self.burst_cols_count, &self.prev_winner_bits, &self.prev_active_bits, &self.col_best_match, cfg_val, ), )?; } } Ok(()) } }