// icarus112's picture
// Upload folder using huggingface_hub
// 1c59946 verified
//! Numenta BAMI-spec Spatial Pooler.
//!
//! Implements:
//! - 2048 (configurable) mini-columns with proximal dendrites
//! - `potential_synapses` (default 40) synapses per column sampled from
//! `potential_radius` (default 1024) random input bits
//! - Permanence in [0.0, 1.0] (f32), connected_threshold = 0.5
//! - syn_perm_active_inc = 0.04 (added), syn_perm_inactive_dec = 0.008 (subtracted)
//! - Global k-WTA inhibition (top `sparsity` fraction of columns)
//! - Boost factor with exponential duty-cycle tracking (Numenta formula)
//!
//! Reference: BAMI "Spatial Pooling Algorithm Details" (Numenta, 2017).
use rand::Rng;
use rand::SeedableRng;
use rand::seq::SliceRandom;
use rand_xoshiro::Xoshiro256PlusPlus;
/// A single proximal dendrite: a sparse set of potential synapses onto
/// specific input bit indices, with per-synapse permanence values.
///
/// `inputs[i]` and `perms[i]` describe the same synapse, so both vectors
/// must always have the same length.
#[derive(Debug, Clone, PartialEq)]
pub struct ProximalDendrite {
    /// Indices into the input SDR, one per potential synapse.
    pub inputs: Vec<u32>,
    /// Permanence of each potential synapse, parallel to `inputs`.
    pub perms: Vec<f32>,
}
/// Tunable parameters for a spatial pooler.
///
/// `Default` yields 16384 input bits, 2048 columns, 40 potential synapses per
/// column drawn from a 1024-bit sample, a 0.5 connection threshold and a 2%
/// target sparsity.
#[derive(Debug, Clone, PartialEq)]
pub struct SpatialPoolerConfig {
    /// Width of the input SDR, in bits.
    pub input_bits: usize,
    /// Number of mini-columns.
    pub n_columns: usize,
    /// Size of the random input sample per column; must be `<= input_bits`.
    pub potential_radius: usize,
    /// Number of potential synapses per column's proximal dendrite; must be
    /// `<= potential_radius`.
    pub potential_synapses: usize,
    /// A synapse is connected when its permanence is `>=` this value.
    pub connected_threshold: f32,
    /// Amount added to the permanence of a winning column's synapse whose
    /// input bit is on.
    pub syn_perm_active_inc: f32,
    /// Amount subtracted (stored as a positive magnitude) from the permanence
    /// of a winning column's synapse whose input bit is off.
    pub syn_perm_inactive_dec: f32,
    /// Target fraction of columns active per step (e.g. 0.02 for 2%).
    pub sparsity: f32,
    /// Duty cycle EMA period `T` (smoothing factor `1/T`; values below 1 are
    /// treated as 1).
    pub duty_cycle_period: f32,
    /// Boost strength `β` in `exp(-β * (duty - mean_duty))`. Set to 0.0 to
    /// disable boosting.
    pub boost_strength: f32,
    /// Initial permanences are drawn uniformly from
    /// `connected_threshold ± init_perm_span` and clamped to `[0, 1]`.
    pub init_perm_span: f32,
}

impl Default for SpatialPoolerConfig {
    fn default() -> Self {
        Self {
            input_bits: 16384,
            n_columns: 2048,
            potential_radius: 1024,
            potential_synapses: 40,
            connected_threshold: 0.5,
            syn_perm_active_inc: 0.04,
            syn_perm_inactive_dec: 0.008,
            sparsity: 0.02,
            duty_cycle_period: 1000.0,
            boost_strength: 1.0,
            init_perm_span: 0.1,
        }
    }
}
/// Spatial pooler state: one proximal dendrite per mini-column together with
/// the per-column activity statistics that drive boosting.
pub struct SpatialPooler {
    /// Parameters the pooler was constructed with.
    pub cfg: SpatialPoolerConfig,
    /// Proximal dendrite of every column, indexed by column id.
    pub columns: Vec<ProximalDendrite>,
    /// Per-column exponential moving average of "this column won inhibition".
    pub active_duty_cycle: Vec<f32>,
    /// Per-column exponential moving average of "this column's raw overlap
    /// reached the stimulus threshold".
    pub overlap_duty_cycle: Vec<f32>,
    /// Multiplicative boost applied to each column's overlap score.
    pub boost: Vec<f32>,
    // Seeded generator used to draw the initial dendrites.
    rng: Xoshiro256PlusPlus,
    // Number of `compute` calls so far (wrapping).
    iter_count: u64,
}
impl SpatialPooler {
    /// Raw overlap a column must reach to count as "stimulated" for the
    /// overlap duty cycle. Numenta uses `min_overlap`; 1 is a conservative floor.
    const STIMULUS_THRESHOLD: u32 = 1;
    /// A column whose overlap duty cycle falls below this fraction of the
    /// maximum overlap duty cycle gets its permanences bumped.
    const MIN_PCT_OVERLAP_DUTY: f32 = 0.001;

    /// Build a spatial pooler with randomly initialised proximal dendrites.
    ///
    /// Each column samples `potential_radius` distinct input bits and keeps
    /// `potential_synapses` of them, stored sorted ascending. Each kept synapse
    /// gets a permanence drawn uniformly from
    /// `connected_threshold ± init_perm_span`, clamped to `[0, 1]`. A
    /// non-positive (or NaN) `init_perm_span` starts every synapse exactly at
    /// `connected_threshold`. The same `cfg` and `seed` always yield the same
    /// pooler.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `input_bits < potential_radius`;
    /// - `potential_radius < potential_synapses`;
    /// - `input_bits` does not fit in a `u32` (input indices are stored as `u32`).
    pub fn new(cfg: SpatialPoolerConfig, seed: u64) -> Self {
        assert!(cfg.input_bits >= cfg.potential_radius,
            "input_bits ({}) must be >= potential_radius ({})",
            cfg.input_bits, cfg.potential_radius);
        assert!(cfg.potential_radius >= cfg.potential_synapses,
            "potential_radius ({}) must be >= potential_synapses ({})",
            cfg.potential_radius, cfg.potential_synapses);
        assert!(cfg.input_bits <= u32::MAX as usize,
            "input_bits ({}) must fit in u32",
            cfg.input_bits);
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
        let mut columns = Vec::with_capacity(cfg.n_columns);
        // Scratch buffer shared by all columns. It is reset to the identity
        // permutation before every draw, so the sampling is exactly as if a
        // fresh `0..input_bits` vector were built per column.
        let mut pool: Vec<u32> = Vec::with_capacity(cfg.input_bits);
        let span = cfg.init_perm_span;
        for _ in 0..cfg.n_columns {
            pool.clear();
            pool.extend(0..cfg.input_bits as u32);
            // Partial Fisher-Yates (Durstenfeld). Afterwards the first
            // `potential_radius` slots hold a uniform sample without
            // replacement. The asserts above guarantee
            // `potential_radius <= pool.len()`.
            for i in 0..cfg.potential_radius {
                let j = rng.gen_range(i..pool.len());
                pool.swap(i, j);
            }
            let window = &mut pool[..cfg.potential_radius];
            window.shuffle(&mut rng);
            let mut inputs: Vec<u32> = window[..cfg.potential_synapses].to_vec();
            inputs.sort_unstable();
            let perms: Vec<f32> = (0..cfg.potential_synapses)
                .map(|_| {
                    // `gen_range` panics on an empty range such as -0.0..0.0,
                    // so a non-positive span means "no jitter".
                    let delta: f32 = if span > 0.0 { rng.gen_range(-span..span) } else { 0.0 };
                    (cfg.connected_threshold + delta).clamp(0.0, 1.0)
                })
                .collect();
            columns.push(ProximalDendrite { inputs, perms });
        }
        let n = cfg.n_columns;
        Self {
            cfg,
            columns,
            active_duty_cycle: vec![0.0; n],
            overlap_duty_cycle: vec![0.0; n],
            boost: vec![1.0; n],
            rng,
            iter_count: 0,
        }
    }

    /// Run one spatial-pooling step and return the winning column indices.
    ///
    /// 1. Overlap: each column counts its connected synapses
    ///    (`perm >= connected_threshold`) whose input bit is on. The count is
    ///    multiplied by the column's boost factor.
    /// 2. Inhibition: global k-winners-take-all keeps the
    ///    `round(sparsity * n_columns)` best columns, clamped to
    ///    `1..=n_columns`. Ties go to the lower column index.
    /// 3. Learning, only when `learn` is true (BAMI "Phase 3"):
    ///    - Hebbian permanence update on the winners;
    ///    - duty-cycle EMA update;
    ///    - boost recomputation (skipped when `boost_strength <= 0`);
    ///    - a small permanence bump for chronically under-stimulated columns.
    ///
    /// With `learn == false`, only the step counter changes. No permanences,
    /// duty cycles or boosts are touched.
    ///
    /// Returns the active column indices sorted ascending, without duplicates.
    ///
    /// # Panics
    ///
    /// Panics if `input.len() != cfg.input_bits`.
    pub fn compute(&mut self, input: &[bool], learn: bool) -> Vec<u32> {
        assert_eq!(input.len(), self.cfg.input_bits);
        let n = self.cfg.n_columns;

        // 1) Overlap scores: raw counts feed the overlap duty cycle, boosted
        //    scores feed inhibition.
        let raw = self.raw_overlaps(input);
        let overlaps: Vec<f32> = raw
            .iter()
            .zip(&self.boost)
            .map(|(&s, &b)| s as f32 * b)
            .collect();

        // 2) Global k-WTA. Clamp k so that sparsity > 1 or n == 0 cannot ask
        //    for more winners than there are columns.
        let k = ((self.cfg.sparsity * n as f32).round() as usize).max(1).min(n);
        let active = top_k(&overlaps, k);

        // 3) Learning phase.
        if learn {
            self.adapt_synapses(input, &active);
            self.update_duty_cycles(&raw, &active);
            if self.cfg.boost_strength > 0.0 {
                self.update_boost_factors();
            }
            self.bump_up_weak_columns();
        }

        self.iter_count = self.iter_count.wrapping_add(1);
        // `rng` is not read after construction; this no-op borrow marks the
        // field as used so it is not reported as dead code.
        let _ = &mut self.rng;
        active
    }

    /// Number of connected synapses on active input bits, per column.
    fn raw_overlaps(&self, input: &[bool]) -> Vec<u32> {
        let threshold = self.cfg.connected_threshold;
        self.columns
            .iter()
            .map(|col| {
                col.inputs
                    .iter()
                    .zip(&col.perms)
                    .filter(|&(&inp, &perm)| input[inp as usize] && perm >= threshold)
                    .count() as u32
            })
            .collect()
    }

    /// Hebbian update on the winning columns. Synapses on active bits gain
    /// `syn_perm_active_inc`, the rest lose `syn_perm_inactive_dec`, and the
    /// result is clamped to `[0, 1]`.
    fn adapt_synapses(&mut self, input: &[bool], active: &[u32]) {
        let inc = self.cfg.syn_perm_active_inc;
        let dec = self.cfg.syn_perm_inactive_dec;
        for &ci in active {
            let col = &mut self.columns[ci as usize];
            for (&inp, perm) in col.inputs.iter().zip(col.perms.iter_mut()) {
                *perm = if input[inp as usize] {
                    (*perm + inc).min(1.0)
                } else {
                    (*perm - dec).max(0.0)
                };
            }
        }
    }

    /// EMA update (period `T`, alpha = 1/T) of the active and overlap duty
    /// cycles of every column.
    fn update_duty_cycles(&mut self, raw_overlaps: &[u32], active: &[u32]) {
        let n = self.cfg.n_columns;
        let period = self.cfg.duty_cycle_period.max(1.0);
        let alpha = 1.0 / period;
        let mut active_mask = vec![false; n];
        for &ci in active {
            active_mask[ci as usize] = true;
        }
        for i in 0..n {
            let active_sample = if active_mask[i] { 1.0 } else { 0.0 };
            let overlap_sample = if raw_overlaps[i] >= Self::STIMULUS_THRESHOLD {
                1.0
            } else {
                0.0
            };
            self.active_duty_cycle[i] =
                (1.0 - alpha) * self.active_duty_cycle[i] + alpha * active_sample;
            self.overlap_duty_cycle[i] =
                (1.0 - alpha) * self.overlap_duty_cycle[i] + alpha * overlap_sample;
        }
    }

    /// Boost factor `b_i = exp(-boost_strength * (duty_i - mean_duty))`.
    /// Under-used columns (duty below the mean) get a boost above 1.
    fn update_boost_factors(&mut self) {
        let n = self.cfg.n_columns;
        let mean_duty: f32 = self.active_duty_cycle.iter().sum::<f32>() / (n as f32);
        let strength = self.cfg.boost_strength;
        for (b, &duty) in self.boost.iter_mut().zip(&self.active_duty_cycle) {
            *b = (-strength * (duty - mean_duty)).exp();
        }
    }

    /// Permanence bump for chronically under-stimulated columns.
    ///
    /// A column qualifies when its overlap duty cycle is below
    /// `MIN_PCT_OVERLAP_DUTY * max_overlap_duty`, where the "neighborhood" is
    /// every column because inhibition is global. Each of its permanences then
    /// grows by `syn_perm_active_inc * 0.1`, capped at 1.
    fn bump_up_weak_columns(&mut self) {
        let max_overlap_duty = self
            .overlap_duty_cycle
            .iter()
            .copied()
            .fold(0.0_f32, f32::max);
        if max_overlap_duty > 0.0 {
            let min_overlap_duty = Self::MIN_PCT_OVERLAP_DUTY * max_overlap_duty;
            let bump = self.cfg.syn_perm_active_inc * 0.1;
            for (col, &duty) in self.columns.iter_mut().zip(&self.overlap_duty_cycle) {
                if duty < min_overlap_duty {
                    for p in &mut col.perms {
                        *p = (*p + bump).min(1.0);
                    }
                }
            }
        }
    }
}
/// Return the indices of the `k` largest values in `scores`.
///
/// Scores are ranked in descending order. Equal scores (including `-0.0`
/// versus `0.0`) are ordered by ascending index, so lower indices win ties.
/// NaN scores rank below every real value and therefore never displace one.
/// If `k` exceeds `scores.len()`, every index is returned. `k == 0` or an
/// empty `scores` yields an empty vector.
///
/// The returned indices are sorted ascending by index, not by score.
fn top_k(scores: &[f32], k: usize) -> Vec<u32> {
    // Clamp so `select_nth_unstable_by(k - 1)` is always in bounds.
    let k = k.min(scores.len());
    if k == 0 {
        return Vec::new();
    }
    // A NaN would make the comparator inconsistent (not a total order).
    // Demote it to -inf so every comparison below is well defined.
    let key = |i: u32| {
        let s = scores[i as usize];
        if s.is_nan() { f32::NEG_INFINITY } else { s }
    };
    let mut idx: Vec<u32> = (0..scores.len() as u32).collect();
    // Quickselect on (descending score, ascending index). Afterwards
    // idx[..k] holds the k winners in unspecified order.
    idx.select_nth_unstable_by(k - 1, |&a, &b| {
        key(b)
            .partial_cmp(&key(a))
            // Unreachable after NaN demotion; kept panic-free.
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a.cmp(&b))
    });
    let mut winners: Vec<u32> = idx[..k].to_vec();
    winners.sort_unstable();
    winners
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;
    use std::collections::HashSet;

    /// Build an input SDR by setting `on_bits` randomly drawn positions.
    /// Repeated draws land on the same bit, so the realised density can be
    /// slightly below `on_bits / input_bits`.
    fn random_sdr(rng: &mut Xoshiro256PlusPlus, input_bits: usize, on_bits: usize) -> Vec<bool> {
        let mut sdr = vec![false; input_bits];
        for _ in 0..on_bits {
            sdr[rng.gen_range(0..input_bits)] = true;
        }
        sdr
    }

    #[test]
    fn sp_sparsity_exact_2pct() {
        // BAMI says "top ~2%"; with 2048 columns that's round(0.02*2048) = 41.
        // The SP must produce *exactly* that count, no more, no less, and with
        // no duplicate indices.
        let cfg = SpatialPoolerConfig::default();
        let expected_k = (cfg.sparsity * cfg.n_columns as f32).round() as usize;
        assert!(expected_k > 0);
        let input_bits = cfg.input_bits;
        // 2% sparse random input SDR.
        let on_bits = (0.02 * input_bits as f32) as usize;
        let mut sp = SpatialPooler::new(cfg, 42);
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(7);
        for _ in 0..100 {
            let sdr = random_sdr(&mut rng, input_bits, on_bits);
            let active = sp.compute(&sdr, true);
            assert_eq!(
                active.len(),
                expected_k,
                "SP must emit exactly {expected_k} active columns"
            );
            let unique: HashSet<u32> = active.iter().copied().collect();
            assert_eq!(unique.len(), expected_k);
        }
    }
}