| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| use rand::Rng; |
| use rand::SeedableRng; |
| use rand::seq::SliceRandom; |
| use rand_xoshiro::Xoshiro256PlusPlus; |
|
|
| |
| |
/// The proximal dendrite of one column: the set of synapses it has onto input bits.
///
/// `inputs[i]` and `perms[i]` describe the same synapse, so the two vectors are
/// expected to have equal length.
#[derive(Clone, Debug, PartialEq)]
pub struct ProximalDendrite {
    /// Input-bit indices this column is wired to (sorted ascending by `SpatialPooler::new`).
    pub inputs: Vec<u32>,
    /// Permanence of each synapse; `SpatialPooler` keeps these within `[0, 1]`.
    pub perms: Vec<f32>,
}
|
|
/// Tunable parameters for a [`SpatialPooler`].
#[derive(Debug, Clone, PartialEq)]
pub struct SpatialPoolerConfig {
    /// Length of every input vector passed to `SpatialPooler::compute`.
    pub input_bits: usize,
    /// Number of columns (output units).
    pub n_columns: usize,
    /// Size of the random candidate subset of input bits from which each column's
    /// synapses are picked; must not exceed `input_bits`.
    ///
    /// NOTE(review): despite the name, no topology is involved — the candidate
    /// subset is drawn from the whole input space.
    pub potential_radius: usize,
    /// Number of synapses per column; must not exceed `potential_radius`.
    pub potential_synapses: usize,
    /// Permanence at or above which a synapse counts as connected.
    pub connected_threshold: f32,
    /// Permanence increment applied, while learning, to a winning column's synapses
    /// whose input bit is on.
    pub syn_perm_active_inc: f32,
    /// Permanence decrement applied, while learning, to a winning column's synapses
    /// whose input bit is off.
    pub syn_perm_inactive_dec: f32,
    /// Target fraction of columns that become active on each step.
    pub sparsity: f32,
    /// Averaging period (in steps) of the duty-cycle moving averages; values below
    /// 1 are treated as 1.
    pub duty_cycle_period: f32,
    /// Boost exponent; boosting (and the weak-column permanence bump) is skipped
    /// when this is not strictly positive.
    pub boost_strength: f32,
    /// Half-width of the uniform distribution of initial permanences around
    /// `connected_threshold` (results are clamped to `[0, 1]`).
    pub init_perm_span: f32,
}
|
|
| impl Default for SpatialPoolerConfig { |
| fn default() -> Self { |
| Self { |
| input_bits: 16384, |
| n_columns: 2048, |
| potential_radius: 1024, |
| potential_synapses: 40, |
| connected_threshold: 0.5, |
| syn_perm_active_inc: 0.04, |
| syn_perm_inactive_dec: 0.008, |
| sparsity: 0.02, |
| duty_cycle_period: 1000.0, |
| boost_strength: 1.0, |
| init_perm_span: 0.1, |
| } |
| } |
| } |
|
|
| pub struct SpatialPooler { |
| pub cfg: SpatialPoolerConfig, |
| pub columns: Vec<ProximalDendrite>, |
| |
| pub active_duty_cycle: Vec<f32>, |
| |
| pub overlap_duty_cycle: Vec<f32>, |
| |
| pub boost: Vec<f32>, |
| rng: Xoshiro256PlusPlus, |
| iter_count: u64, |
| } |
|
|
| impl SpatialPooler { |
| pub fn new(cfg: SpatialPoolerConfig, seed: u64) -> Self { |
| assert!(cfg.input_bits >= cfg.potential_radius, |
| "input_bits ({}) must be >= potential_radius ({})", |
| cfg.input_bits, cfg.potential_radius); |
| assert!(cfg.potential_radius >= cfg.potential_synapses, |
| "potential_radius ({}) must be >= potential_synapses ({})", |
| cfg.potential_radius, cfg.potential_synapses); |
|
|
| let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); |
|
|
| let mut columns = Vec::with_capacity(cfg.n_columns); |
| for _ in 0..cfg.n_columns { |
| |
| |
| |
| let mut pool: Vec<u32> = (0..cfg.input_bits as u32).collect(); |
| |
| |
| for i in 0..cfg.potential_radius.min(pool.len()) { |
| let j = rng.gen_range(i..pool.len()); |
| pool.swap(i, j); |
| } |
| let window = &mut pool[..cfg.potential_radius]; |
| window.shuffle(&mut rng); |
| let mut inputs: Vec<u32> = window[..cfg.potential_synapses].to_vec(); |
| inputs.sort_unstable(); |
|
|
| let perms: Vec<f32> = (0..cfg.potential_synapses) |
| .map(|_| { |
| let delta: f32 = rng.gen_range(-cfg.init_perm_span..cfg.init_perm_span); |
| (cfg.connected_threshold + delta).clamp(0.0, 1.0) |
| }) |
| .collect(); |
|
|
| columns.push(ProximalDendrite { inputs, perms }); |
| } |
|
|
| let n = cfg.n_columns; |
| Self { |
| cfg, |
| columns, |
| active_duty_cycle: vec![0.0; n], |
| overlap_duty_cycle: vec![0.0; n], |
| boost: vec![1.0; n], |
| rng, |
| iter_count: 0, |
| } |
| } |
|
|
| |
| |
| pub fn compute(&mut self, input: &[bool], learn: bool) -> Vec<u32> { |
| assert_eq!(input.len(), self.cfg.input_bits); |
|
|
| |
| |
| let n = self.cfg.n_columns; |
| let mut overlaps: Vec<f32> = vec![0.0; n]; |
| let mut raw_overlaps: Vec<u32> = vec![0; n]; |
|
|
| for (ci, col) in self.columns.iter().enumerate() { |
| let mut s: u32 = 0; |
| for (syn_i, &inp) in col.inputs.iter().enumerate() { |
| if input[inp as usize] && col.perms[syn_i] >= self.cfg.connected_threshold { |
| s += 1; |
| } |
| } |
| raw_overlaps[ci] = s; |
| overlaps[ci] = (s as f32) * self.boost[ci]; |
| } |
|
|
| |
| let k = ((self.cfg.sparsity * n as f32).round() as usize).max(1); |
| let active: Vec<u32> = top_k(&overlaps, k); |
|
|
| |
| if learn { |
| for &ci in &active { |
| let col = &mut self.columns[ci as usize]; |
| for (syn_i, &inp) in col.inputs.iter().enumerate() { |
| if input[inp as usize] { |
| col.perms[syn_i] = |
| (col.perms[syn_i] + self.cfg.syn_perm_active_inc).min(1.0); |
| } else { |
| col.perms[syn_i] = |
| (col.perms[syn_i] - self.cfg.syn_perm_inactive_dec).max(0.0); |
| } |
| } |
| } |
| } |
|
|
| |
| let period = self.cfg.duty_cycle_period.max(1.0); |
| let alpha = 1.0 / period; |
| |
| |
| let stimulus_threshold = 1.0_f32; |
|
|
| |
| let mut active_mask = vec![false; n]; |
| for &ci in &active { |
| active_mask[ci as usize] = true; |
| } |
|
|
| for i in 0..n { |
| let active_sample = if active_mask[i] { 1.0 } else { 0.0 }; |
| let overlap_sample = if (raw_overlaps[i] as f32) >= stimulus_threshold { |
| 1.0 |
| } else { |
| 0.0 |
| }; |
| self.active_duty_cycle[i] = |
| (1.0 - alpha) * self.active_duty_cycle[i] + alpha * active_sample; |
| self.overlap_duty_cycle[i] = |
| (1.0 - alpha) * self.overlap_duty_cycle[i] + alpha * overlap_sample; |
| } |
|
|
| |
| |
| if learn && self.cfg.boost_strength > 0.0 { |
| let mean_duty: f32 = |
| self.active_duty_cycle.iter().sum::<f32>() / (n as f32); |
| for i in 0..n { |
| self.boost[i] = |
| (-self.cfg.boost_strength * (self.active_duty_cycle[i] - mean_duty)).exp(); |
| } |
|
|
| |
| |
| |
| |
| let max_overlap_duty = self |
| .overlap_duty_cycle |
| .iter() |
| .cloned() |
| .fold(0.0_f32, f32::max); |
| let min_pct_overlap_duty = 0.001_f32 * max_overlap_duty; |
| if max_overlap_duty > 0.0 { |
| for i in 0..n { |
| if self.overlap_duty_cycle[i] < min_pct_overlap_duty { |
| for p in &mut self.columns[i].perms { |
| *p = (*p + self.cfg.syn_perm_active_inc * 0.1).min(1.0); |
| } |
| } |
| } |
| } |
| } |
|
|
| self.iter_count = self.iter_count.wrapping_add(1); |
| let _ = &mut self.rng; |
| active |
| } |
| } |
|
|
| |
| |
/// Returns the indices of the `k` highest scores, sorted by ascending index.
///
/// Ties are broken in favour of the lower index, so the result is deterministic.
/// `k` is clamped to `scores.len()`, so an oversized `k` returns every index
/// instead of panicking. NaN scores are ranked as if they were `-inf`.
fn top_k(scores: &[f32], k: usize) -> Vec<u32> {
    let k = k.min(scores.len());
    if k == 0 {
        return Vec::new();
    }
    // Mapping NaN away keeps the comparator a strict total order. With NaN
    // compared as "equal", the index tie-break could form a cycle. The standard
    // library allows `select_nth_unstable_by` to panic or misorder in that case.
    let score = |i: u32| {
        let s = scores[i as usize];
        if s.is_nan() { f32::NEG_INFINITY } else { s }
    };
    let mut idx: Vec<u32> = (0..scores.len() as u32).collect();
    // Partition so that idx[..k] holds the best k entries: score descending, then index ascending.
    idx.select_nth_unstable_by(k - 1, |&a, &b| {
        score(b)
            .partial_cmp(&score(a))
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.cmp(&b))
    });
    idx.truncate(k);
    idx.sort_unstable();
    idx
}
|
|
| |
| |
| |
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;

    /// Feeds 100 random, roughly 2%-dense inputs through a default pooler.
    /// Every step must activate exactly `round(sparsity * n_columns)` distinct columns.
    #[test]
    fn sp_sparsity_exact_2pct() {
        let config = SpatialPoolerConfig::default();
        let expected_k = (config.sparsity * config.n_columns as f32).round() as usize;
        assert!(expected_k > 0);

        let width = config.input_bits;
        let mut pooler = SpatialPooler::new(config, 42);
        let mut input_rng = Xoshiro256PlusPlus::seed_from_u64(7);
        // Number of random draws per sample; repeated draws simply land on the same bit.
        let draws_per_sample = (0.02 * width as f32) as usize;

        for _step in 0..100 {
            let mut sample = vec![false; width];
            for _ in 0..draws_per_sample {
                let bit = input_rng.gen_range(0..width);
                sample[bit] = true;
            }

            let winners = pooler.compute(&sample, true);
            assert_eq!(
                winners.len(),
                expected_k,
                "SP must emit exactly {expected_k} active columns"
            );

            let mut distinct = winners;
            distinct.sort_unstable();
            distinct.dedup();
            assert_eq!(distinct.len(), expected_k);
        }
    }
}
|
|