Spaces:
Runtime error
Runtime error
| //! Numenta BAMI-spec Spatial Pooler. | |
| //! | |
| //! Implements: | |
| //! - 2048 (configurable) mini-columns with proximal dendrites | |
| //! - `potential_synapses` (default 40) synapses per column sampled from | |
| //! `potential_radius` (default 1024) random input bits | |
| //! - Permanence in [0.0, 1.0] (f32), connected_threshold = 0.5 | |
| //! - syn_perm_active_inc = 0.04 (added on active inputs), syn_perm_inactive_dec = 0.008 (subtracted on inactive inputs) | |
| //! - Global k-WTA inhibition (top `sparsity` fraction of columns) | |
| //! - Boost factor with exponential duty-cycle tracking (Numenta formula) | |
| //! | |
| //! Reference: BAMI "Spatial Pooling Algorithm Details" (Numenta, 2017). | |
| use rand::Rng; | |
| use rand::SeedableRng; | |
| use rand::seq::SliceRandom; | |
| use rand_xoshiro::Xoshiro256PlusPlus; | |
/// A single proximal dendrite: a sparse set of potential synapses onto
/// specific input bit indices, with per-synapse permanence values.
///
/// `inputs` and `perms` are parallel vectors: `perms[i]` is the permanence of
/// the synapse that listens to input bit `inputs[i]`.
#[derive(Debug, Clone, PartialEq)]
pub struct ProximalDendrite {
    /// Indices into the input SDR, one per potential synapse
    /// (length == `potential_synapses`).
    pub inputs: Vec<u32>,
    /// Permanence for each potential synapse (same length as `inputs`).
    pub perms: Vec<f32>,
}
/// Tunable parameters of the spatial pooler.
///
/// Use [`Default`] for the reference values and override individual fields
/// with struct-update syntax, e.g.
/// `SpatialPoolerConfig { sparsity: 0.05, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct SpatialPoolerConfig {
    /// Length of the input SDR. `compute` asserts that its input slice has
    /// exactly this many bits.
    pub input_bits: usize,
    /// Number of mini-columns.
    pub n_columns: usize,
    /// Size of the random input sample per column.
    /// Must be <= `input_bits`.
    pub potential_radius: usize,
    /// Number of potential synapses per column's proximal dendrite.
    /// Must be <= `potential_radius`.
    pub potential_synapses: usize,
    /// A synapse counts as connected when its permanence is >= this value.
    pub connected_threshold: f32,
    /// Amount added to a winning column's permanence for each synapse whose
    /// input bit is on (result capped at 1.0).
    pub syn_perm_active_inc: f32,
    /// Amount subtracted from a winning column's permanence for each synapse
    /// whose input bit is off (result floored at 0.0). Stored as a positive
    /// magnitude.
    pub syn_perm_inactive_dec: f32,
    /// Target fraction of columns active per step (e.g. 0.02 for 2%).
    pub sparsity: f32,
    /// Duty cycle EMA period; the smoothing factor is `1 / period`, with the
    /// period treated as at least 1.
    pub duty_cycle_period: f32,
    /// Boost strength. Set to 0.0 to disable boosting.
    pub boost_strength: f32,
    /// Initial permanence span around the connected threshold: each initial
    /// permanence is `connected_threshold + delta` with `delta` drawn
    /// uniformly from `[-span, span)`, then clamped to `[0.0, 1.0]`.
    pub init_perm_span: f32,
}

impl Default for SpatialPoolerConfig {
    /// Reference configuration: 16384 input bits, 2048 columns, 40 potential
    /// synapses per column, 2% sparsity.
    fn default() -> Self {
        Self {
            input_bits: 16384,
            n_columns: 2048,
            potential_radius: 1024,
            potential_synapses: 40,
            connected_threshold: 0.5,
            syn_perm_active_inc: 0.04,
            syn_perm_inactive_dec: 0.008,
            sparsity: 0.02,
            duty_cycle_period: 1000.0,
            boost_strength: 1.0,
            init_perm_span: 0.1,
        }
    }
}
| pub struct SpatialPooler { | |
| pub cfg: SpatialPoolerConfig, | |
| pub columns: Vec<ProximalDendrite>, | |
| /// Exponential moving average of "column was active" per step. | |
| pub active_duty_cycle: Vec<f32>, | |
| /// Exponential moving average of "overlap exceeded threshold" per step. | |
| pub overlap_duty_cycle: Vec<f32>, | |
| /// Boost factor per column. | |
| pub boost: Vec<f32>, | |
| rng: Xoshiro256PlusPlus, | |
| iter_count: u64, | |
| } | |
| impl SpatialPooler { | |
| pub fn new(cfg: SpatialPoolerConfig, seed: u64) -> Self { | |
| assert!(cfg.input_bits >= cfg.potential_radius, | |
| "input_bits ({}) must be >= potential_radius ({})", | |
| cfg.input_bits, cfg.potential_radius); | |
| assert!(cfg.potential_radius >= cfg.potential_synapses, | |
| "potential_radius ({}) must be >= potential_synapses ({})", | |
| cfg.potential_radius, cfg.potential_synapses); | |
| let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); | |
| let mut columns = Vec::with_capacity(cfg.n_columns); | |
| for _ in 0..cfg.n_columns { | |
| // Sample `potential_radius` distinct input indices, then from those | |
| // pick `potential_synapses` as the actual proximal synapses. | |
| // Using partial Fisher-Yates via shuffle on a pool index range. | |
| let mut pool: Vec<u32> = (0..cfg.input_bits as u32).collect(); | |
| // Efficient partial shuffle: swap the first `potential_radius` | |
| // items with random items from the rest (Durstenfeld step). | |
| for i in 0..cfg.potential_radius.min(pool.len()) { | |
| let j = rng.gen_range(i..pool.len()); | |
| pool.swap(i, j); | |
| } | |
| let window = &mut pool[..cfg.potential_radius]; | |
| window.shuffle(&mut rng); | |
| let mut inputs: Vec<u32> = window[..cfg.potential_synapses].to_vec(); | |
| inputs.sort_unstable(); | |
| let perms: Vec<f32> = (0..cfg.potential_synapses) | |
| .map(|_| { | |
| let delta: f32 = rng.gen_range(-cfg.init_perm_span..cfg.init_perm_span); | |
| (cfg.connected_threshold + delta).clamp(0.0, 1.0) | |
| }) | |
| .collect(); | |
| columns.push(ProximalDendrite { inputs, perms }); | |
| } | |
| let n = cfg.n_columns; | |
| Self { | |
| cfg, | |
| columns, | |
| active_duty_cycle: vec![0.0; n], | |
| overlap_duty_cycle: vec![0.0; n], | |
| boost: vec![1.0; n], | |
| rng, | |
| iter_count: 0, | |
| } | |
| } | |
| /// Process one step: compute overlaps, inhibit, learn (if `learn`), update | |
| /// duty cycles and boosts. Returns the set of active column indices. | |
| pub fn compute(&mut self, input: &[bool], learn: bool) -> Vec<u32> { | |
| assert_eq!(input.len(), self.cfg.input_bits); | |
| // 1) Overlap score per column (sum of CONNECTED synapses onto active inputs). | |
| // Also track raw overlap for the overlap-duty-cycle. | |
| let n = self.cfg.n_columns; | |
| let mut overlaps: Vec<f32> = vec![0.0; n]; | |
| let mut raw_overlaps: Vec<u32> = vec![0; n]; | |
| for (ci, col) in self.columns.iter().enumerate() { | |
| let mut s: u32 = 0; | |
| for (syn_i, &inp) in col.inputs.iter().enumerate() { | |
| if input[inp as usize] && col.perms[syn_i] >= self.cfg.connected_threshold { | |
| s += 1; | |
| } | |
| } | |
| raw_overlaps[ci] = s; | |
| overlaps[ci] = (s as f32) * self.boost[ci]; | |
| } | |
| // 2) Global k-WTA inhibition. Select top-k columns by boosted overlap. | |
| let k = ((self.cfg.sparsity * n as f32).round() as usize).max(1); | |
| let active: Vec<u32> = top_k(&overlaps, k); | |
| // 3) Hebbian learning on active columns. | |
| if learn { | |
| for &ci in &active { | |
| let col = &mut self.columns[ci as usize]; | |
| for (syn_i, &inp) in col.inputs.iter().enumerate() { | |
| if input[inp as usize] { | |
| col.perms[syn_i] = | |
| (col.perms[syn_i] + self.cfg.syn_perm_active_inc).min(1.0); | |
| } else { | |
| col.perms[syn_i] = | |
| (col.perms[syn_i] - self.cfg.syn_perm_inactive_dec).max(0.0); | |
| } | |
| } | |
| } | |
| } | |
| // 4) Update duty cycles (EMA with period T -> alpha = 1/T). | |
| let period = self.cfg.duty_cycle_period.max(1.0); | |
| let alpha = 1.0 / period; | |
| // Column is "overlapping enough" if raw overlap >= stimulus_threshold. | |
| // Numenta uses min_overlap; we use 1 as a conservative floor. | |
| let stimulus_threshold = 1.0_f32; | |
| // Mark active columns. | |
| let mut active_mask = vec![false; n]; | |
| for &ci in &active { | |
| active_mask[ci as usize] = true; | |
| } | |
| for i in 0..n { | |
| let active_sample = if active_mask[i] { 1.0 } else { 0.0 }; | |
| let overlap_sample = if (raw_overlaps[i] as f32) >= stimulus_threshold { | |
| 1.0 | |
| } else { | |
| 0.0 | |
| }; | |
| self.active_duty_cycle[i] = | |
| (1.0 - alpha) * self.active_duty_cycle[i] + alpha * active_sample; | |
| self.overlap_duty_cycle[i] = | |
| (1.0 - alpha) * self.overlap_duty_cycle[i] + alpha * overlap_sample; | |
| } | |
| // 5) Boost factor: b_i = exp(-boost_strength * (duty_i - mean_duty)). | |
| // Under-used columns (duty < mean) get boost > 1. | |
| if learn && self.cfg.boost_strength > 0.0 { | |
| let mean_duty: f32 = | |
| self.active_duty_cycle.iter().sum::<f32>() / (n as f32); | |
| for i in 0..n { | |
| self.boost[i] = | |
| (-self.cfg.boost_strength * (self.active_duty_cycle[i] - mean_duty)).exp(); | |
| } | |
| // 6) Permanence bump for chronically under-stimulated columns. | |
| // If overlap_duty_cycle[i] < min_pct_overlap * max_duty_in_neighborhood, | |
| // bump all permanences by syn_perm_active_inc * 0.1. | |
| // With global inhibition, "neighborhood" = all columns. | |
| let max_overlap_duty = self | |
| .overlap_duty_cycle | |
| .iter() | |
| .cloned() | |
| .fold(0.0_f32, f32::max); | |
| let min_pct_overlap_duty = 0.001_f32 * max_overlap_duty; | |
| if max_overlap_duty > 0.0 { | |
| for i in 0..n { | |
| if self.overlap_duty_cycle[i] < min_pct_overlap_duty { | |
| for p in &mut self.columns[i].perms { | |
| *p = (*p + self.cfg.syn_perm_active_inc * 0.1).min(1.0); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| self.iter_count = self.iter_count.wrapping_add(1); | |
| let _ = &mut self.rng; // suppress unused-mut when learn=false | |
| active | |
| } | |
| } | |
/// Return the indices of the top-k values in `scores`.
///
/// Ties are broken by index order (lower index wins). The output is sorted
/// ascending by index, not by score.
///
/// * If `k` exceeds `scores.len()`, every index is returned.
/// * If `k` is zero or `scores` is empty, the result is empty.
/// * NaN scores rank as negative infinity, so a NaN never displaces a real
///   score and the comparison stays a total order.
///
/// # Panics
///
/// Panics if `scores` has more than `u32::MAX` elements, because indices are
/// returned as `u32`.
fn top_k(scores: &[f32], k: usize) -> Vec<u32> {
    // Asking for more winners than candidates simply selects everyone.
    let k = k.min(scores.len());
    if k == 0 {
        return Vec::new();
    }
    assert!(
        scores.len() <= u32::MAX as usize,
        "top_k: {} scores do not fit in u32 indices",
        scores.len()
    );

    // Comparison key: NaN is unordered against everything, which would break
    // the total order the selection algorithm relies on, so it is ranked last.
    let key = |i: u32| -> f32 {
        let s = scores[i as usize];
        if s.is_nan() {
            f32::NEG_INFINITY
        } else {
            s
        }
    };

    let mut idx: Vec<u32> = (0..scores.len() as u32).collect();
    // Partial selection on (descending score, ascending index): afterwards
    // the first `k` slots hold exactly the winners, in arbitrary order.
    idx.select_nth_unstable_by(k - 1, |&a, &b| {
        key(b)
            .partial_cmp(&key(a))
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.cmp(&b))
    });
    idx.truncate(k);
    idx.sort_unstable();
    idx
}
| // --------------------------------------------------------------------------- | |
| // Tests | |
| // --------------------------------------------------------------------------- | |
// Compiled only for `cargo test`, so the test-only imports and helpers stay
// out of normal builds.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;

    /// With default settings the pooler must activate exactly
    /// `round(sparsity * n_columns)` distinct columns on every step.
    #[test]
    fn sp_sparsity_exact_2pct() {
        // BAMI says "top ~2%"; with 2048 columns that's round(0.02*2048) = 41.
        // The SP must produce *exactly* that count, no more, no less, and with
        // no duplicate indices.
        let cfg = SpatialPoolerConfig::default();
        let expected_k = (cfg.sparsity * cfg.n_columns as f32).round() as usize;
        assert!(expected_k > 0);
        let input_bits = cfg.input_bits;
        let mut sp = SpatialPooler::new(cfg, 42);
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
        for _ in 0..100 {
            // Random input SDR with up to 2% of its bits set (index
            // collisions between draws may leave slightly fewer).
            let on_bits = (0.02 * input_bits as f32) as usize;
            let mut sdr = vec![false; input_bits];
            for _ in 0..on_bits {
                let i = rng.gen_range(0..input_bits);
                sdr[i] = true;
            }
            let active = sp.compute(&sdr, true);
            assert_eq!(
                active.len(),
                expected_k,
                "SP must emit exactly {expected_k} active columns"
            );
            // Deduplicating must not shrink the set: every winner is unique.
            let mut a = active.clone();
            a.sort_unstable();
            a.dedup();
            assert_eq!(a.len(), expected_k);
        }
    }
}