//! Numenta BAMI-spec Temporal Memory.
//!
//! Key parameters (Numenta defaults):
//! - cells_per_column = 32
//! - max_segments_per_cell = 255
//! - max_synapses_per_segment = 32
//! - activation_threshold = 15 (CONNECTED synapses onto active cells)
//! - learning_threshold = 13 (POTENTIAL synapses onto active cells)
//! (often called `minThreshold` / match threshold in BAMI)
//! - initial_permanence = 0.21
//! - connected_permanence = 0.50
//! - permanence_increment = 0.10
//! - permanence_decrement = 0.10
//! - predicted_segment_decrement = 0.10 (decay for segments that predicted
//! inactive columns; called `predictedSegmentDecrement` in BAMI)
//! - max_new_synapse_count = 20 (max synapses to grow on a new/reinforced seg)
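//!
//! Defaults can be overridden with struct-update syntax; for instance
//! (the values below are illustrative, not Numenta defaults):
//!
//! ```ignore
//! let cfg = TemporalMemoryConfig {
//!     n_columns: 1024,
//!     cells_per_column: 16,
//!     ..Default::default()
//! };
//! ```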
//!
//! Algorithm (one step):
//! Given `active_columns` from the Spatial Pooler, and segment activity
//! caches `active_segments` and `matching_segments` computed *at the end of
//! the previous step*:
//!
//! 1. For each active column:
//!    - If it contains any predicted cell (any cell with an active segment
//!      from the previous depolarization), mark those cells active and
//!      learn on the segments that predicted them.
//!    - Else BURST the column: mark all cells in it active, then reinforce
//!      the best-matching segment in the column or, if no segment matches,
//!      grow a new segment on the cell with the fewest segments.
//! 2. For every column that was predicted but did NOT become active
//! (matching segments on inactive columns), apply the
//! `predicted_segment_decrement` decay so spurious predictions fade.
//! 3. Winner cells = the cells chosen for learning: every correctly
//!    predicted cell in a predicted column, or exactly one cell in a
//!    bursting column.
//! 4. Compute segment activity for NEXT step:
//! - A segment's CONNECTED activity = #synapses with perm >= connected_perm
//! whose presynaptic cell is in `active_cells`. If >= activation_threshold
//! -> segment is "active" -> its cell is "predicted".
//! - A segment's POTENTIAL activity = #synapses whose presynaptic cell is
//! in `active_cells` (regardless of permanence). If >= learning_threshold
//! -> segment is "matching".
//!
//! Anomaly score = (# active columns with no previously predicted cells)
//! / (# active columns). E.g., if 8 of 40 active columns burst, the score
//! is 0.2.
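//!
//! A minimal usage sketch (the column indices here are made up; in
//! practice they come from a Spatial Pooler each step):
//!
//! ```ignore
//! let mut tm = TemporalMemory::new(TemporalMemoryConfig::default(), 42);
//! for _ in 0..100 {
//!     let active_columns: Vec<u32> = vec![3, 17, 42, 101]; // from the SP
//!     let anomaly = tm.compute(&active_columns, true);
//!     assert!((0.0..=1.0).contains(&anomaly));
//! }
//! tm.reset(); // clear prediction state between unrelated sequences
//! ```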
use rand::Rng;
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256PlusPlus;
type CellIdx = u32;
type SegmentIdx = u32;
#[derive(Clone)]
pub struct Synapse {
pub presynaptic_cell: CellIdx,
pub permanence: f32,
}
#[derive(Clone)]
pub struct Segment {
pub cell: CellIdx,
pub synapses: Vec<Synapse>,
/// Cached counters; recomputed each step.
pub num_active_connected: u32,
pub num_active_potential: u32,
    /// "Last iteration touched" stamp, used for LRU segment eviction.
pub last_used_iteration: u64,
}
pub struct TemporalMemoryConfig {
pub n_columns: usize,
pub cells_per_column: usize,
pub activation_threshold: u32,
pub learning_threshold: u32,
pub initial_permanence: f32,
pub connected_permanence: f32,
pub permanence_increment: f32,
pub permanence_decrement: f32,
pub predicted_segment_decrement: f32,
pub max_segments_per_cell: usize,
pub max_synapses_per_segment: usize,
pub max_new_synapse_count: usize,
}
impl Default for TemporalMemoryConfig {
fn default() -> Self {
Self {
n_columns: 2048,
cells_per_column: 32,
activation_threshold: 15,
learning_threshold: 13,
initial_permanence: 0.21,
connected_permanence: 0.50,
permanence_increment: 0.10,
permanence_decrement: 0.10,
predicted_segment_decrement: 0.10,
max_segments_per_cell: 255,
max_synapses_per_segment: 32,
max_new_synapse_count: 20,
}
}
}
pub struct TemporalMemory {
pub cfg: TemporalMemoryConfig,
/// All segments in the region. Indexed by SegmentIdx.
pub segments: Vec<Segment>,
/// For each cell, the list of segments that belong to it.
pub cell_segments: Vec<Vec<SegmentIdx>>,
/// Active cells in the current step.
pub active_cells: Vec<bool>,
/// Winner cells (subset of active_cells, 1 per active column) for learning.
pub winner_cells: Vec<bool>,
/// Predictive cells for the current step = cells whose segment became
/// active at the end of the previous step.
pub predictive_cells: Vec<bool>,
/// Cached list of segment indices that were "active" last compute().
active_segments_prev: Vec<SegmentIdx>,
/// Cached list of segment indices that were "matching" last compute().
matching_segments_prev: Vec<SegmentIdx>,
rng: Xoshiro256PlusPlus,
iter_count: u64,
}
impl TemporalMemory {
pub fn new(cfg: TemporalMemoryConfig, seed: u64) -> Self {
let total = cfg.n_columns * cfg.cells_per_column;
Self {
cell_segments: vec![Vec::new(); total],
active_cells: vec![false; total],
winner_cells: vec![false; total],
predictive_cells: vec![false; total],
cfg,
segments: Vec::new(),
active_segments_prev: Vec::new(),
matching_segments_prev: Vec::new(),
rng: Xoshiro256PlusPlus::seed_from_u64(seed),
iter_count: 0,
}
}
pub fn reset(&mut self) {
for v in self.active_cells.iter_mut() { *v = false; }
for v in self.winner_cells.iter_mut() { *v = false; }
for v in self.predictive_cells.iter_mut() { *v = false; }
self.active_segments_prev.clear();
self.matching_segments_prev.clear();
}
#[inline]
fn col_of(&self, cell: CellIdx) -> usize {
(cell as usize) / self.cfg.cells_per_column
}
#[inline]
fn cells_in_col(&self, col: usize) -> std::ops::Range<CellIdx> {
let base = (col * self.cfg.cells_per_column) as CellIdx;
base..(base + self.cfg.cells_per_column as CellIdx)
}
/// Process one step.
///
/// `active_columns` is the set of column indices activated by the Spatial
/// Pooler this step. Returns the anomaly score in [0, 1].
pub fn compute(&mut self, active_columns: &[u32], learn: bool) -> f32 {
self.iter_count = self.iter_count.wrapping_add(1);
// Snapshot previous-step cell activity (for learning on segments).
let prev_active_cells = self.active_cells.clone();
let prev_winner_cells = self.winner_cells.clone();
        // This step's predictions are consumed via the segment caches
        // (active_segments_prev / matching_segments_prev) computed at the
        // end of the previous step; predictive_cells is recomputed in
        // step 4 below.
// Group active segments and matching segments by column of their
// owning cell, for the columns that are active this step.
let n_cols = self.cfg.n_columns;
// active_segs_by_col[col] = segment indices whose cell is in col and
// which were "active" in the previous depolarization.
// matching_segs_by_col[col] = similarly for "matching".
let mut active_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols];
let mut matching_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols];
for &seg in &self.active_segments_prev {
let col = self.col_of(self.segments[seg as usize].cell);
active_segs_by_col[col].push(seg);
}
for &seg in &self.matching_segments_prev {
let col = self.col_of(self.segments[seg as usize].cell);
matching_segs_by_col[col].push(seg);
}
// Columns that are active this step (for O(1) lookup).
let mut active_col_mask = vec![false; n_cols];
for &c in active_columns { active_col_mask[c as usize] = true; }
// Zero out current cell activations.
for v in self.active_cells.iter_mut() { *v = false; }
for v in self.winner_cells.iter_mut() { *v = false; }
// Track anomaly.
let mut unpredicted_cols = 0u32;
        // Collect learning operations so permanence adjustments can be
        // batch-applied against prev_active_cells after all columns are
        // processed: reinforce correctly predicted segments, grow on
        // bursting columns, and punish matching segments on inactive
        // columns.
enum LearnOp {
Reinforce(SegmentIdx), // correctly predicted
Grow { // bursting column: grow on chosen segment
segment: SegmentIdx,
#[allow(dead_code)]
winner_cell: CellIdx,
},
Punish(SegmentIdx), // matching segment on inactive column
}
let mut ops: Vec<LearnOp> = Vec::new();
// ---- 1) Process active columns ----
for &col in active_columns {
let col = col as usize;
let active_segs = &active_segs_by_col[col];
if !active_segs.is_empty() {
// "Activate predicted column": each cell with an active segment
// becomes active and is a winner; reinforce that segment.
let mut seen_cells: Vec<CellIdx> = Vec::new();
for &seg_i in active_segs {
let seg = &self.segments[seg_i as usize];
let cell = seg.cell;
if !seen_cells.contains(&cell) {
self.active_cells[cell as usize] = true;
self.winner_cells[cell as usize] = true;
seen_cells.push(cell);
}
if learn {
ops.push(LearnOp::Reinforce(seg_i));
}
}
} else {
// ----- BURST -----
unpredicted_cols += 1;
for c in self.cells_in_col(col) {
self.active_cells[c as usize] = true;
}
// Pick a winner cell + segment for learning.
if learn {
let matching = &matching_segs_by_col[col];
let (winner_cell, target_segment) = if !matching.is_empty() {
// Best-matching segment = highest num_active_potential.
let mut best = matching[0];
let mut best_score = self.segments[best as usize].num_active_potential;
for &s in &matching[1..] {
let score = self.segments[s as usize].num_active_potential;
if score > best_score {
best_score = score;
best = s;
}
}
let wc = self.segments[best as usize].cell;
(wc, Some(best))
} else {
// Least-used cell in column, then grow a new segment.
let winner = self.least_used_cell(col);
(winner, None)
};
self.winner_cells[winner_cell as usize] = true;
let segment_id = match target_segment {
Some(s) => s,
None => {
// Create a fresh empty segment on winner cell.
self.create_segment(winner_cell)
}
};
ops.push(LearnOp::Grow { segment: segment_id, winner_cell });
} else {
// No learning: still pick some winner cell (arbitrary)
// so downstream code that inspects winner_cells isn't empty.
let matching = &matching_segs_by_col[col];
let winner_cell = if !matching.is_empty() {
self.segments[matching[0] as usize].cell
} else {
self.least_used_cell(col)
};
self.winner_cells[winner_cell as usize] = true;
}
}
}
// ---- 2) Punish matching segments on INACTIVE columns ----
if learn && self.cfg.predicted_segment_decrement > 0.0 {
for &seg_i in &self.matching_segments_prev {
let col = self.col_of(self.segments[seg_i as usize].cell);
if !active_col_mask[col] {
ops.push(LearnOp::Punish(seg_i));
}
}
}
// ---- 3) Apply learning ----
if learn {
for op in ops {
match op {
LearnOp::Reinforce(seg_i) => {
self.reinforce_segment(seg_i, &prev_active_cells);
                        // Also grow new synapses to winner cells of the
                        // previous step (capped inside
                        // grow_synapses_on_segment).
self.grow_synapses_on_segment(seg_i, &prev_winner_cells);
}
LearnOp::Grow { segment, winner_cell: _ } => {
self.reinforce_segment(segment, &prev_active_cells);
self.grow_synapses_on_segment(segment, &prev_winner_cells);
}
LearnOp::Punish(seg_i) => {
let dec = self.cfg.predicted_segment_decrement;
for syn in &mut self.segments[seg_i as usize].synapses {
if prev_active_cells[syn.presynaptic_cell as usize] {
syn.permanence = (syn.permanence - dec).max(0.0);
}
}
}
}
}
}
// ---- 4) Compute segment activity & predictive cells for NEXT step ----
// We have to use the *current* active_cells (just set above).
let mut next_active_segs: Vec<SegmentIdx> = Vec::new();
let mut next_matching_segs: Vec<SegmentIdx> = Vec::new();
for v in self.predictive_cells.iter_mut() { *v = false; }
let conn = self.cfg.connected_permanence;
let act_thr = self.cfg.activation_threshold;
let learn_thr = self.cfg.learning_threshold;
for (seg_i, seg) in self.segments.iter_mut().enumerate() {
let mut n_conn: u32 = 0;
let mut n_pot: u32 = 0;
for syn in &seg.synapses {
if self.active_cells[syn.presynaptic_cell as usize] {
n_pot += 1;
if syn.permanence >= conn { n_conn += 1; }
}
}
seg.num_active_connected = n_conn;
seg.num_active_potential = n_pot;
if n_conn >= act_thr {
next_active_segs.push(seg_i as SegmentIdx);
self.predictive_cells[seg.cell as usize] = true;
}
if n_pot >= learn_thr {
next_matching_segs.push(seg_i as SegmentIdx);
}
}
self.active_segments_prev = next_active_segs;
self.matching_segments_prev = next_matching_segs;
// Anomaly.
if active_columns.is_empty() {
0.0
} else {
(unpredicted_cols as f32) / (active_columns.len() as f32)
}
}
    /// Reinforce synapses on `seg`: +inc if the presynaptic cell was active
    /// last step, -dec otherwise.
fn reinforce_segment(&mut self, seg_i: SegmentIdx, prev_active_cells: &[bool]) {
let inc = self.cfg.permanence_increment;
let dec = self.cfg.permanence_decrement;
let seg = &mut self.segments[seg_i as usize];
seg.last_used_iteration = self.iter_count;
for syn in &mut seg.synapses {
if prev_active_cells[syn.presynaptic_cell as usize] {
syn.permanence = (syn.permanence + inc).min(1.0);
} else {
syn.permanence = (syn.permanence - dec).max(0.0);
}
}
}
    /// Grow up to `max_new_synapse_count - num_active_potential` new
    /// synapses from previous winner cells that are not already presynaptic
    /// to this segment, capped by `max_synapses_per_segment`.
fn grow_synapses_on_segment(
&mut self,
seg_i: SegmentIdx,
prev_winner_cells: &[bool],
) {
let initial_perm = self.cfg.initial_permanence;
let cap = self.cfg.max_synapses_per_segment;
let max_new = self.cfg.max_new_synapse_count;
// Gather candidate cells (prev winners not already presynaptic to this seg).
let already: Vec<CellIdx> = self.segments[seg_i as usize]
.synapses
.iter()
.map(|s| s.presynaptic_cell)
.collect();
let mut candidates: Vec<CellIdx> = Vec::new();
for (cell_i, &b) in prev_winner_cells.iter().enumerate() {
if b && !already.contains(&(cell_i as CellIdx)) {
candidates.push(cell_i as CellIdx);
}
}
        // How many can we add? Per BAMI, grow up to
        // `max_new_synapse_count - num_active_potential` new synapses,
        // further capped by segment capacity and the candidate count.
        let current_len = self.segments[seg_i as usize].synapses.len();
        let room = cap.saturating_sub(current_len);
        let desired = max_new
            .saturating_sub(self.segments[seg_i as usize].num_active_potential as usize);
        let mut to_add = desired.min(candidates.len()).min(room);
// Random sample without replacement from candidates.
while to_add > 0 {
let idx = self.rng.gen_range(0..candidates.len());
let pre = candidates.swap_remove(idx);
self.segments[seg_i as usize].synapses.push(Synapse {
presynaptic_cell: pre,
permanence: initial_perm,
});
to_add -= 1;
}
}
fn create_segment(&mut self, cell: CellIdx) -> SegmentIdx {
// Enforce per-cell segment cap by evicting least-recently-used segment
// if necessary.
let cell_segs = &mut self.cell_segments[cell as usize];
if cell_segs.len() >= self.cfg.max_segments_per_cell {
// Find LRU segment.
            let &lru_id = cell_segs
                .iter()
                .min_by_key(|&&sid| self.segments[sid as usize].last_used_iteration)
                .expect("cell_segs non-empty");
            // Clear that segment in place and reuse its index; it keeps its
            // slot in cell_segs.
            self.segments[lru_id as usize].synapses.clear();
            self.segments[lru_id as usize].num_active_connected = 0;
            self.segments[lru_id as usize].num_active_potential = 0;
            self.segments[lru_id as usize].last_used_iteration = self.iter_count;
            return lru_id;
}
let new_id = self.segments.len() as SegmentIdx;
self.segments.push(Segment {
cell,
synapses: Vec::with_capacity(self.cfg.max_new_synapse_count),
num_active_connected: 0,
num_active_potential: 0,
last_used_iteration: self.iter_count,
});
cell_segs.push(new_id);
new_id
}
fn least_used_cell(&mut self, col: usize) -> CellIdx {
// Cell with the fewest segments; break ties randomly.
let mut min_segs = usize::MAX;
let mut candidates: Vec<CellIdx> = Vec::new();
for c in self.cells_in_col(col) {
let n = self.cell_segments[c as usize].len();
if n < min_segs {
min_segs = n;
candidates.clear();
candidates.push(c);
} else if n == min_segs {
candidates.push(c);
}
}
let idx = self.rng.gen_range(0..candidates.len());
candidates[idx]
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use crate::sp::{SpatialPooler, SpatialPoolerConfig};
use rand::Rng;
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256PlusPlus;
#[test]
fn tm_learns_repeating_sequence() {
// Sequence A -> B -> C -> A -> B -> C -> ... should drive anomaly down.
let cfg = SpatialPoolerConfig::default();
let mut sp = SpatialPooler::new(cfg, 123);
let mut tm = TemporalMemory::new(TemporalMemoryConfig::default(), 456);
// Build 3 fixed random SDRs of 2% sparsity.
let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
let input_bits = sp.cfg.input_bits;
let make_sdr = |rng: &mut Xoshiro256PlusPlus| {
let mut v = vec![false; input_bits];
let on = (0.02 * input_bits as f32) as usize;
let mut placed = 0;
while placed < on {
let i = rng.gen_range(0..input_bits);
if !v[i] {
v[i] = true;
placed += 1;
}
}
v
};
let seqs = [make_sdr(&mut rng), make_sdr(&mut rng), make_sdr(&mut rng)];
// Warm up SP first so that columns are reliable for each symbol.
for _ in 0..200 {
for s in &seqs {
sp.compute(s, true);
}
}
// Reset TM so prediction state is clean.
tm.reset();
        // Record per-symbol anomaly at an early iteration and at the last one.
let mut early_anoms: Vec<f32> = Vec::new();
let mut late_anoms: Vec<f32> = Vec::new();
for iter in 0..250 {
for s in &seqs {
let active = sp.compute(s, false);
let anomaly = tm.compute(&active, true);
if iter == 10 { early_anoms.push(anomaly); }
if iter == 249 { late_anoms.push(anomaly); }
}
}
let mean = |v: &[f32]| v.iter().sum::<f32>() / (v.len() as f32);
let early = mean(&early_anoms);
let late = mean(&late_anoms);
println!("early_anomaly={early}, late_anomaly={late}");
assert!(
late < 0.5 * early + 1e-6,
"late anomaly ({late}) should be < 0.5 * early anomaly ({early})"
);
}
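    // Two smaller sanity checks implied by the implementation above (not
    // taken from any external spec): a fresh or reset TM has no cached
    // segment activity, so every active column bursts and the step's
    // anomaly is 1.0; and create_segment enforces max_segments_per_cell by
    // reusing the least-recently-used segment slot.
    #[test]
    fn fresh_or_reset_tm_is_fully_anomalous() {
        let mut tm = TemporalMemory::new(TemporalMemoryConfig::default(), 7);
        let anomaly = tm.compute(&[0, 1, 2], true);
        assert!((anomaly - 1.0).abs() < 1e-6);
        // After a reset the prediction caches are cleared, so the next
        // step bursts again regardless of what was learned.
        tm.compute(&[0, 1, 2], true);
        tm.reset();
        let anomaly = tm.compute(&[0, 1, 2], true);
        assert!((anomaly - 1.0).abs() < 1e-6);
    }
    #[test]
    fn create_segment_respects_per_cell_cap() {
        let cfg = TemporalMemoryConfig { max_segments_per_cell: 2, ..Default::default() };
        let mut tm = TemporalMemory::new(cfg, 7);
        let a = tm.create_segment(0);
        let b = tm.create_segment(0);
        // The third request must reuse one of the first two slots.
        let c = tm.create_segment(0);
        assert_eq!(tm.cell_segments[0].len(), 2);
        assert!(c == a || c == b);
    }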
}