//! Numenta BAMI-spec Spatial Pooler. //! //! Implements: //! - 2048 (configurable) mini-columns with proximal dendrites //! - `potential_synapses` (default 40) synapses per column sampled from //! `potential_radius` (default 1024) random input bits //! - Permanence in [0.0, 1.0] (f32), connected_threshold = 0.5 //! - syn_perm_active_inc = +0.04, syn_perm_inactive_dec = -0.008 //! - Global k-WTA inhibition (top `sparsity` fraction of columns) //! - Boost factor with exponential duty-cycle tracking (Numenta formula) //! //! Reference: BAMI "Spatial Pooling Algorithm Details" (Numenta, 2017). use rand::Rng; use rand::SeedableRng; use rand::seq::SliceRandom; use rand_xoshiro::Xoshiro256PlusPlus; /// A single proximal dendrite: a sparse set of potential synapses onto /// specific input bit indices, with per-synapse permanence values. #[derive(Clone)] pub struct ProximalDendrite { /// Indices into the input SDR. Length == potential_synapses. pub inputs: Vec, /// Permanence for each potential synapse (same length as `inputs`). pub perms: Vec, } pub struct SpatialPoolerConfig { pub input_bits: usize, pub n_columns: usize, /// Size of the random input sample per column. pub potential_radius: usize, /// Number of potential synapses per column's proximal dendrite. pub potential_synapses: usize, pub connected_threshold: f32, pub syn_perm_active_inc: f32, pub syn_perm_inactive_dec: f32, /// Target fraction of columns active per step (e.g. 0.02 for 2%). pub sparsity: f32, /// Duty cycle EMA period. pub duty_cycle_period: f32, /// Boost strength. Set to 0.0 to disable boosting. pub boost_strength: f32, /// Initial permanence span around the connected threshold. pub init_perm_span: f32, } impl Default for SpatialPoolerConfig { fn default() -> Self { Self { input_bits: 16384, n_columns: 2048, potential_radius: 1024, potential_synapses: 40, connected_threshold: 0.5, syn_perm_active_inc: 0.04, syn_perm_inactive_dec: 0.008, sparsity: 0.02, duty_cycle_period: 1000.0, boost_strength: 1.0, init_perm_span: 0.1, } } } pub struct SpatialPooler { pub cfg: SpatialPoolerConfig, pub columns: Vec, /// Exponential moving average of "column was active" per step. pub active_duty_cycle: Vec, /// Exponential moving average of "overlap exceeded threshold" per step. pub overlap_duty_cycle: Vec, /// Boost factor per column. pub boost: Vec, rng: Xoshiro256PlusPlus, iter_count: u64, } impl SpatialPooler { pub fn new(cfg: SpatialPoolerConfig, seed: u64) -> Self { assert!(cfg.input_bits >= cfg.potential_radius, "input_bits ({}) must be >= potential_radius ({})", cfg.input_bits, cfg.potential_radius); assert!(cfg.potential_radius >= cfg.potential_synapses, "potential_radius ({}) must be >= potential_synapses ({})", cfg.potential_radius, cfg.potential_synapses); let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); let mut columns = Vec::with_capacity(cfg.n_columns); for _ in 0..cfg.n_columns { // Sample `potential_radius` distinct input indices, then from those // pick `potential_synapses` as the actual proximal synapses. // Using partial Fisher-Yates via shuffle on a pool index range. let mut pool: Vec = (0..cfg.input_bits as u32).collect(); // Efficient partial shuffle: swap the first `potential_radius` // items with random items from the rest (Durstenfeld step). for i in 0..cfg.potential_radius.min(pool.len()) { let j = rng.gen_range(i..pool.len()); pool.swap(i, j); } let window = &mut pool[..cfg.potential_radius]; window.shuffle(&mut rng); let mut inputs: Vec = window[..cfg.potential_synapses].to_vec(); inputs.sort_unstable(); let perms: Vec = (0..cfg.potential_synapses) .map(|_| { let delta: f32 = rng.gen_range(-cfg.init_perm_span..cfg.init_perm_span); (cfg.connected_threshold + delta).clamp(0.0, 1.0) }) .collect(); columns.push(ProximalDendrite { inputs, perms }); } let n = cfg.n_columns; Self { cfg, columns, active_duty_cycle: vec![0.0; n], overlap_duty_cycle: vec![0.0; n], boost: vec![1.0; n], rng, iter_count: 0, } } /// Process one step: compute overlaps, inhibit, learn (if `learn`), update /// duty cycles and boosts. Returns the set of active column indices. pub fn compute(&mut self, input: &[bool], learn: bool) -> Vec { assert_eq!(input.len(), self.cfg.input_bits); // 1) Overlap score per column (sum of CONNECTED synapses onto active inputs). // Also track raw overlap for the overlap-duty-cycle. let n = self.cfg.n_columns; let mut overlaps: Vec = vec![0.0; n]; let mut raw_overlaps: Vec = vec![0; n]; for (ci, col) in self.columns.iter().enumerate() { let mut s: u32 = 0; for (syn_i, &inp) in col.inputs.iter().enumerate() { if input[inp as usize] && col.perms[syn_i] >= self.cfg.connected_threshold { s += 1; } } raw_overlaps[ci] = s; overlaps[ci] = (s as f32) * self.boost[ci]; } // 2) Global k-WTA inhibition. Select top-k columns by boosted overlap. let k = ((self.cfg.sparsity * n as f32).round() as usize).max(1); let active: Vec = top_k(&overlaps, k); // 3) Hebbian learning on active columns. if learn { for &ci in &active { let col = &mut self.columns[ci as usize]; for (syn_i, &inp) in col.inputs.iter().enumerate() { if input[inp as usize] { col.perms[syn_i] = (col.perms[syn_i] + self.cfg.syn_perm_active_inc).min(1.0); } else { col.perms[syn_i] = (col.perms[syn_i] - self.cfg.syn_perm_inactive_dec).max(0.0); } } } } // 4) Update duty cycles (EMA with period T -> alpha = 1/T). let period = self.cfg.duty_cycle_period.max(1.0); let alpha = 1.0 / period; // Column is "overlapping enough" if raw overlap >= stimulus_threshold. // Numenta uses min_overlap; we use 1 as a conservative floor. let stimulus_threshold = 1.0_f32; // Mark active columns. let mut active_mask = vec![false; n]; for &ci in &active { active_mask[ci as usize] = true; } for i in 0..n { let active_sample = if active_mask[i] { 1.0 } else { 0.0 }; let overlap_sample = if (raw_overlaps[i] as f32) >= stimulus_threshold { 1.0 } else { 0.0 }; self.active_duty_cycle[i] = (1.0 - alpha) * self.active_duty_cycle[i] + alpha * active_sample; self.overlap_duty_cycle[i] = (1.0 - alpha) * self.overlap_duty_cycle[i] + alpha * overlap_sample; } // 5) Boost factor: b_i = exp(-boost_strength * (duty_i - mean_duty)). // Under-used columns (duty < mean) get boost > 1. if learn && self.cfg.boost_strength > 0.0 { let mean_duty: f32 = self.active_duty_cycle.iter().sum::() / (n as f32); for i in 0..n { self.boost[i] = (-self.cfg.boost_strength * (self.active_duty_cycle[i] - mean_duty)).exp(); } // 6) Permanence bump for chronically under-stimulated columns. // If overlap_duty_cycle[i] < min_pct_overlap * max_duty_in_neighborhood, // bump all permanences by syn_perm_active_inc * 0.1. // With global inhibition, "neighborhood" = all columns. let max_overlap_duty = self .overlap_duty_cycle .iter() .cloned() .fold(0.0_f32, f32::max); let min_pct_overlap_duty = 0.001_f32 * max_overlap_duty; if max_overlap_duty > 0.0 { for i in 0..n { if self.overlap_duty_cycle[i] < min_pct_overlap_duty { for p in &mut self.columns[i].perms { *p = (*p + self.cfg.syn_perm_active_inc * 0.1).min(1.0); } } } } } self.iter_count = self.iter_count.wrapping_add(1); let _ = &mut self.rng; // suppress unused-mut when learn=false active } } /// Return the indices of the top-k values in `scores`. /// Ties broken by index order. Output is sorted ascending. fn top_k(scores: &[f32], k: usize) -> Vec { if k == 0 { return Vec::new(); } let mut idx: Vec = (0..scores.len() as u32).collect(); // Partial sort: put top-k at the front by descending score. // Use select_nth_unstable_by on (desc score, asc index). idx.select_nth_unstable_by(k - 1, |&a, &b| { let sa = scores[a as usize]; let sb = scores[b as usize]; // Reverse for descending. match sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal) { std::cmp::Ordering::Equal => a.cmp(&b), ord => ord, } }); let mut winners: Vec = idx[..k].to_vec(); winners.sort_unstable(); winners } // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; use rand::Rng; use rand::SeedableRng; use rand_xoshiro::Xoshiro256PlusPlus; #[test] fn sp_sparsity_exact_2pct() { // BAMI says "top ~2%"; with 2048 columns that's round(0.02*2048) = 41. // The SP must produce *exactly* that count, no more, no less, and with // no duplicate indices. let cfg = SpatialPoolerConfig::default(); let expected_k = (cfg.sparsity * cfg.n_columns as f32).round() as usize; assert!(expected_k > 0); let input_bits = cfg.input_bits; let mut sp = SpatialPooler::new(cfg, 42); let mut rng = Xoshiro256PlusPlus::seed_from_u64(7); for _ in 0..100 { // 2% sparse random input SDR. let on_bits = (0.02 * input_bits as f32) as usize; let mut sdr = vec![false; input_bits]; for _ in 0..on_bits { let i = rng.gen_range(0..input_bits); sdr[i] = true; } let active = sp.compute(&sdr, true); assert_eq!( active.len(), expected_k, "SP must emit exactly {expected_k} active columns" ); let mut a = active.clone(); a.sort_unstable(); a.dedup(); assert_eq!(a.len(), expected_k); } } }