//! Numenta BAMI-spec Spatial Pooler.
//!
//! Implements:
//! - 2048 (configurable) mini-columns with proximal dendrites
//! - `potential_synapses` (default 40) synapses per column sampled from
//! `potential_radius` (default 1024) random input bits
//! - Permanence in [0.0, 1.0] (f32), connected_threshold = 0.5
//! - syn_perm_active_inc = +0.04, syn_perm_inactive_dec = 0.008 (subtracted)
//! - Global k-WTA inhibition (top `sparsity` fraction of columns)
//! - Boost factor with exponential duty-cycle tracking (Numenta formula)
//!
//! Reference: BAMI "Spatial Pooling Algorithm Details" (Numenta, 2017).
use rand::Rng;
use rand::SeedableRng;
use rand::seq::SliceRandom;
use rand_xoshiro::Xoshiro256PlusPlus;
/// A single proximal dendrite: a sparse set of potential synapses onto
/// specific input bit indices, with one permanence value per synapse.
///
/// `inputs[i]` and `perms[i]` describe the same synapse, so the two vectors
/// are expected to have the same length.
#[derive(Clone, Debug, PartialEq)]
pub struct ProximalDendrite {
    /// Indices into the input SDR (ascending when built by
    /// `SpatialPooler::new`). Length == `potential_synapses`.
    pub inputs: Vec<u32>,
    /// Permanence of each synapse, kept within `[0.0, 1.0]`; same length as
    /// `inputs`.
    pub perms: Vec<f32>,
}
/// Parameters of a [`SpatialPooler`].
#[derive(Debug, Clone, PartialEq)]
pub struct SpatialPoolerConfig {
    /// Length of the dense boolean input passed to `SpatialPooler::compute`.
    pub input_bits: usize,
    /// Number of mini-columns.
    pub n_columns: usize,
    /// Size of the random input sample per column (must be <= `input_bits`).
    pub potential_radius: usize,
    /// Number of potential synapses per column's proximal dendrite, drawn
    /// from the `potential_radius` sample (must be <= `potential_radius`).
    pub potential_synapses: usize,
    /// A synapse is connected (counts toward overlap) when its permanence is
    /// `>=` this value.
    pub connected_threshold: f32,
    /// Permanence increment for synapses on active inputs of winning columns.
    pub syn_perm_active_inc: f32,
    /// Permanence decrement (a positive magnitude that is subtracted) for
    /// synapses on inactive inputs of winning columns.
    pub syn_perm_inactive_dec: f32,
    /// Target fraction of columns active per step (e.g. 0.02 for 2%). The
    /// number of winners is `round(sparsity * n_columns)`, at least 1.
    pub sparsity: f32,
    /// Duty cycle EMA period `T`; the EMA weight is `1 / max(T, 1)`.
    pub duty_cycle_period: f32,
    /// Boost strength. Set to 0.0 to disable boosting (and the weak-column
    /// permanence bump that accompanies it).
    pub boost_strength: f32,
    /// Half-width of the uniform range around `connected_threshold` from
    /// which initial permanences are drawn.
    pub init_perm_span: f32,
}
impl Default for SpatialPoolerConfig {
fn default() -> Self {
Self {
input_bits: 16384,
n_columns: 2048,
potential_radius: 1024,
potential_synapses: 40,
connected_threshold: 0.5,
syn_perm_active_inc: 0.04,
syn_perm_inactive_dec: 0.008,
sparsity: 0.02,
duty_cycle_period: 1000.0,
boost_strength: 1.0,
init_perm_span: 0.1,
}
}
}
/// Spatial pooler state: one proximal dendrite per column plus the
/// per-column statistics that drive boosting.
pub struct SpatialPooler {
    /// Configuration the pooler was constructed with.
    pub cfg: SpatialPoolerConfig,
    /// Proximal dendrite of each column, indexed by column id.
    pub columns: Vec<ProximalDendrite>,
    /// Per-column exponential moving average of "this column won the
    /// inhibition round" (1.0) versus "it did not" (0.0).
    pub active_duty_cycle: Vec<f32>,
    /// Per-column exponential moving average of "raw overlap reached the
    /// stimulus threshold".
    pub overlap_duty_cycle: Vec<f32>,
    /// Multiplicative factor applied to each column's overlap score.
    pub boost: Vec<f32>,
    /// Seeded generator used to sample the dendrites at construction.
    rng: Xoshiro256PlusPlus,
    /// Count of `compute` calls so far (wraps on overflow).
    iter_count: u64,
}
impl SpatialPooler {
    /// Builds a pooler whose proximal dendrites are sampled from `seed`.
    ///
    /// For each column, `potential_radius` distinct input bits are drawn
    /// uniformly at random, and `potential_synapses` of them become the
    /// column's synapses (stored in ascending input order). Each synapse
    /// starts with permanence `connected_threshold + delta`, clamped to
    /// `[0.0, 1.0]`. `delta` is uniform in `[-init_perm_span, init_perm_span)`,
    /// or exactly 0 when the span is 0. Duty cycles start at 0 and boosts at 1.
    /// The same config and seed always yield the same pooler.
    ///
    /// # Panics
    ///
    /// Panics if `input_bits < potential_radius`,
    /// `potential_radius < potential_synapses`, `input_bits` does not fit in a
    /// `u32`, or `init_perm_span` is negative or NaN.
    pub fn new(cfg: SpatialPoolerConfig, seed: u64) -> Self {
        assert!(cfg.input_bits >= cfg.potential_radius,
            "input_bits ({}) must be >= potential_radius ({})",
            cfg.input_bits, cfg.potential_radius);
        assert!(cfg.potential_radius >= cfg.potential_synapses,
            "potential_radius ({}) must be >= potential_synapses ({})",
            cfg.potential_radius, cfg.potential_synapses);
        // Input indices are stored as u32; refuse sizes that would truncate.
        assert!(
            cfg.input_bits <= u32::MAX as usize,
            "input_bits ({}) must fit in a u32",
            cfg.input_bits
        );
        let span = cfg.init_perm_span;
        assert!(span >= 0.0, "init_perm_span ({}) must be >= 0.0", span);

        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
        let mut columns = Vec::with_capacity(cfg.n_columns);
        // Scratch buffer of all input indices, allocated once. It is reset to
        // the identity permutation before every column, so each column sees
        // exactly what a freshly collected `0..input_bits` would give.
        let mut pool: Vec<u32> = (0..cfg.input_bits as u32).collect();
        for _ in 0..cfg.n_columns {
            for (slot, value) in pool.iter_mut().zip(0u32..) {
                *slot = value;
            }
            // Partial Fisher-Yates (Durstenfeld): afterwards the first
            // `potential_radius` slots hold a uniform random sample of inputs.
            for i in 0..cfg.potential_radius {
                let j = rng.gen_range(i..pool.len());
                pool.swap(i, j);
            }
            let window = &mut pool[..cfg.potential_radius];
            window.shuffle(&mut rng);
            let mut inputs: Vec<u32> = window[..cfg.potential_synapses].to_vec();
            inputs.sort_unstable();
            let perms: Vec<f32> = (0..cfg.potential_synapses)
                .map(|_| {
                    // `gen_range` panics on the empty range `-0.0..0.0`, so a
                    // zero span means "start exactly at the threshold".
                    let delta: f32 = if span > 0.0 { rng.gen_range(-span..span) } else { 0.0 };
                    (cfg.connected_threshold + delta).clamp(0.0, 1.0)
                })
                .collect();
            columns.push(ProximalDendrite { inputs, perms });
        }
        let n = cfg.n_columns;
        Self {
            cfg,
            columns,
            active_duty_cycle: vec![0.0; n],
            overlap_duty_cycle: vec![0.0; n],
            boost: vec![1.0; n],
            rng,
            iter_count: 0,
        }
    }

    /// Process one step and return the active column indices (ascending).
    ///
    /// 1. Overlap: count the connected synapses (`perm >= connected_threshold`)
    ///    on active inputs, then multiply by the column's boost.
    /// 2. Global k-WTA: keep the `round(sparsity * n_columns)` best columns,
    ///    at least 1 and at most `n_columns`.
    /// 3. If `learn`, apply the Hebbian permanence update to the winners.
    /// 4. Update both duty cycles. This happens on every call, learning or not.
    ///    NOTE(review): the NuPIC reference only updates duty cycles while
    ///    learning; confirm that inference-time updates are intended.
    /// 5. If `learn` and `boost_strength > 0`, recompute the boosts and bump
    ///    the weak columns.
    ///
    /// # Panics
    ///
    /// Panics if `input.len() != cfg.input_bits`.
    pub fn compute(&mut self, input: &[bool], learn: bool) -> Vec<u32> {
        assert_eq!(input.len(), self.cfg.input_bits);
        let n = self.cfg.n_columns;
        let (overlaps, raw_overlaps) = self.compute_overlaps(input);
        // Clamp to `n`: a sparsity that rounds above the column count, or an
        // empty pooler, would otherwise ask `top_k` for more winners than
        // there are columns.
        let k = ((self.cfg.sparsity * n as f32).round() as usize).max(1).min(n);
        let active: Vec<u32> = top_k(&overlaps, k);
        if learn {
            self.adapt_synapses(input, &active);
        }
        self.update_duty_cycles(&raw_overlaps, &active);
        if learn && self.cfg.boost_strength > 0.0 {
            self.update_boosts();
            self.bump_weak_columns();
        }
        self.iter_count = self.iter_count.wrapping_add(1);
        // `rng` is not consumed after construction. This keeps the field
        // referenced so it is not reported as never read.
        let _ = &mut self.rng;
        active
    }

    /// Returns `(boosted, raw)` overlap scores per column.
    ///
    /// The raw score counts the synapses that are connected and sit on an
    /// active input bit. The boosted score is that count times the column's
    /// boost.
    fn compute_overlaps(&self, input: &[bool]) -> (Vec<f32>, Vec<u32>) {
        let threshold = self.cfg.connected_threshold;
        let raw: Vec<u32> = self
            .columns
            .iter()
            .map(|col| {
                col.inputs
                    .iter()
                    .zip(&col.perms)
                    .filter(|&(&inp, &perm)| input[inp as usize] && perm >= threshold)
                    .count() as u32
            })
            .collect();
        let boosted: Vec<f32> = raw
            .iter()
            .zip(&self.boost)
            .map(|(&s, &b)| (s as f32) * b)
            .collect();
        (boosted, raw)
    }

    /// Hebbian update on the winning columns.
    ///
    /// Synapses on active inputs gain `syn_perm_active_inc` (capped at 1.0).
    /// All other synapses lose `syn_perm_inactive_dec` (floored at 0.0).
    fn adapt_synapses(&mut self, input: &[bool], active: &[u32]) {
        let inc = self.cfg.syn_perm_active_inc;
        let dec = self.cfg.syn_perm_inactive_dec;
        for &ci in active {
            let col = &mut self.columns[ci as usize];
            for (&inp, perm) in col.inputs.iter().zip(col.perms.iter_mut()) {
                *perm = if input[inp as usize] {
                    (*perm + inc).min(1.0)
                } else {
                    (*perm - dec).max(0.0)
                };
            }
        }
    }

    /// Updates both duty-cycle EMAs with weight `alpha = 1 / max(period, 1)`.
    fn update_duty_cycles(&mut self, raw_overlaps: &[u32], active: &[u32]) {
        let period = self.cfg.duty_cycle_period.max(1.0);
        let alpha = 1.0 / period;
        // Column is "overlapping enough" if raw overlap >= stimulus_threshold.
        // Numenta uses min_overlap; we use 1 as a conservative floor.
        let stimulus_threshold = 1.0_f32;
        let mut active_mask = vec![false; self.cfg.n_columns];
        for &ci in active {
            active_mask[ci as usize] = true;
        }
        let samples = active_mask.iter().zip(raw_overlaps);
        for ((active_dc, overlap_dc), (&won, &raw)) in self
            .active_duty_cycle
            .iter_mut()
            .zip(self.overlap_duty_cycle.iter_mut())
            .zip(samples)
        {
            let active_sample = if won { 1.0 } else { 0.0 };
            let overlap_sample = if (raw as f32) >= stimulus_threshold { 1.0 } else { 0.0 };
            *active_dc = (1.0 - alpha) * *active_dc + alpha * active_sample;
            *overlap_dc = (1.0 - alpha) * *overlap_dc + alpha * overlap_sample;
        }
    }

    /// Recomputes each boost as `b_i = exp(-boost_strength * (duty_i - mean_duty))`.
    ///
    /// Under-used columns (`duty < mean`) get a boost above 1.
    fn update_boosts(&mut self) {
        let n = self.cfg.n_columns;
        let mean_duty: f32 = self.active_duty_cycle.iter().sum::<f32>() / (n as f32);
        let strength = self.cfg.boost_strength;
        for (b, &duty) in self.boost.iter_mut().zip(&self.active_duty_cycle) {
            *b = (-strength * (duty - mean_duty)).exp();
        }
    }

    /// Raises every permanence of chronically under-stimulated columns.
    ///
    /// A column is weak when its overlap duty cycle is below 0.1% of the
    /// maximum over all columns. With global inhibition, the neighbourhood is
    /// the whole pooler. Each permanence of a weak column rises by
    /// `0.1 * syn_perm_active_inc`, capped at 1.0. Nothing happens while every
    /// overlap duty cycle is still 0.
    fn bump_weak_columns(&mut self) {
        let max_overlap_duty = self
            .overlap_duty_cycle
            .iter()
            .copied()
            .fold(0.0_f32, f32::max);
        if max_overlap_duty <= 0.0 {
            return;
        }
        let min_pct_overlap_duty = 0.001_f32 * max_overlap_duty;
        let bump = self.cfg.syn_perm_active_inc * 0.1;
        for (col, &duty) in self.columns.iter_mut().zip(&self.overlap_duty_cycle) {
            if duty < min_pct_overlap_duty {
                for p in &mut col.perms {
                    *p = (*p + bump).min(1.0);
                }
            }
        }
    }
}
/// Returns the indices of the `k` largest values in `scores`, sorted ascending.
///
/// Scores are ranked in descending order, and equal scores go to the lower
/// index. NaN scores rank as `-inf`. If `k` exceeds `scores.len()`, every
/// index is returned; `k == 0` or empty `scores` give an empty vector.
fn top_k(scores: &[f32], k: usize) -> Vec<u32> {
    // `select_nth_unstable_by` panics when the pivot index is out of range.
    let k = k.min(scores.len());
    if k == 0 {
        return Vec::new();
    }
    // NaN has no ordering. Mapping it to -inf keeps the comparator a total
    // order, which `select_nth_unstable_by` requires.
    let rank = |i: u32| -> f32 {
        let s = scores[i as usize];
        if s.is_nan() {
            f32::NEG_INFINITY
        } else {
            s
        }
    };
    let mut idx: Vec<u32> = (0..scores.len() as u32).collect();
    // Partition so that the k best under (desc score, asc index) are in idx[..k].
    idx.select_nth_unstable_by(k - 1, |&a, &b| {
        rank(b)
            .partial_cmp(&rank(a))
            // Unreachable: NaN was mapped away above.
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.cmp(&b))
    });
    let mut winners: Vec<u32> = idx[..k].to_vec();
    winners.sort_unstable();
    winners
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;

    /// Builds a dense input with `on_bits` random positions set. Positions are
    /// sampled with replacement, so fewer than `on_bits` may end up true.
    fn random_sdr(rng: &mut Xoshiro256PlusPlus, input_bits: usize, on_bits: usize) -> Vec<bool> {
        let mut sdr = vec![false; input_bits];
        (0..on_bits).for_each(|_| sdr[rng.gen_range(0..input_bits)] = true);
        sdr
    }

    /// BAMI asks for "top ~2%". With 2048 columns that is
    /// round(0.02 * 2048) = 41, and every step must produce exactly that many
    /// distinct winners.
    #[test]
    fn sp_sparsity_exact_2pct() {
        let cfg = SpatialPoolerConfig::default();
        let expected_k = (cfg.sparsity * cfg.n_columns as f32).round() as usize;
        assert!(expected_k > 0);
        let input_bits = cfg.input_bits;
        let on_bits = (0.02 * input_bits as f32) as usize;
        let mut pooler = SpatialPooler::new(cfg, 42);
        let mut input_rng = Xoshiro256PlusPlus::seed_from_u64(7);
        for _step in 0..100 {
            let sdr = random_sdr(&mut input_rng, input_bits, on_bits);
            let active = pooler.compute(&sdr, true);
            assert_eq!(
                active.len(),
                expected_k,
                "SP must emit exactly {expected_k} active columns"
            );
            let mut unique = active;
            unique.sort_unstable();
            unique.dedup();
            assert_eq!(unique.len(), expected_k);
        }
    }
}