//! Numenta BAMI-spec Temporal Memory. //! //! Key parameters (Numenta defaults): //! - cells_per_column = 32 //! - max_segments_per_cell = 255 //! - max_synapses_per_segment = 32 //! - activation_threshold = 15 (CONNECTED synapses onto active cells) //! - learning_threshold = 13 (POTENTIAL synapses onto active cells) //! (often called `minThreshold` / match threshold in BAMI) //! - initial_permanence = 0.21 //! - connected_permanence = 0.50 //! - permanence_increment = 0.10 //! - permanence_decrement = 0.10 //! - predicted_segment_decrement = 0.10 (decay for segments that predicted //! inactive columns; called `predictedSegmentDecrement` in BAMI) //! - max_new_synapse_count = 20 (max synapses to grow on a new/reinforced seg) //! //! Algorithm (one step): //! Given `active_columns` from the Spatial Pooler, and segment activity //! caches `active_segments` and `matching_segments` computed *at the end of //! the previous step*: //! //! 1. For each active column: //! - If it contains any predicted cell (any cell with an active segment //! from the previous depolarization), mark those cells active and //! learn on the segment that predicted it. //! - Else BURST the column: mark all cells in it active, and grow a new //! segment on the best-matching cell in the column (or, if none, //! on the cell with the fewest segments). //! 2. For every column that was predicted but did NOT become active //! (matching segments on inactive columns), apply the //! `predicted_segment_decrement` decay so spurious predictions fade. //! 3. Winner cells = active cells chosen for learning (1 per active column). //! 4. Compute segment activity for NEXT step: //! - A segment's CONNECTED activity = #synapses with perm >= connected_perm //! whose presynaptic cell is in `active_cells`. If >= activation_threshold //! -> segment is "active" -> its cell is "predicted". //! - A segment's POTENTIAL activity = #synapses whose presynaptic cell is //! in `active_cells` (regardless of permanence). If >= learning_threshold //! -> segment is "matching". //! //! Anomaly score = (active columns with no prior predicted cells) //! / (# active columns). use rand::Rng; use rand::SeedableRng; use rand_xoshiro::Xoshiro256PlusPlus; type CellIdx = u32; type SegmentIdx = u32; #[derive(Clone)] pub struct Synapse { pub presynaptic_cell: CellIdx, pub permanence: f32, } #[derive(Clone)] pub struct Segment { pub cell: CellIdx, pub synapses: Vec, /// Cached counters; recomputed each step. pub num_active_connected: u32, pub num_active_potential: u32, /// Simple "last iter touched" stat for least-used cell selection. pub last_used_iteration: u64, } pub struct TemporalMemoryConfig { pub n_columns: usize, pub cells_per_column: usize, pub activation_threshold: u32, pub learning_threshold: u32, pub initial_permanence: f32, pub connected_permanence: f32, pub permanence_increment: f32, pub permanence_decrement: f32, pub predicted_segment_decrement: f32, pub max_segments_per_cell: usize, pub max_synapses_per_segment: usize, pub max_new_synapse_count: usize, } impl Default for TemporalMemoryConfig { fn default() -> Self { Self { n_columns: 2048, cells_per_column: 32, activation_threshold: 15, learning_threshold: 13, initial_permanence: 0.21, connected_permanence: 0.50, permanence_increment: 0.10, permanence_decrement: 0.10, predicted_segment_decrement: 0.10, max_segments_per_cell: 255, max_synapses_per_segment: 32, max_new_synapse_count: 20, } } } pub struct TemporalMemory { pub cfg: TemporalMemoryConfig, /// All segments in the region. Indexed by SegmentIdx. pub segments: Vec, /// For each cell, the list of segments that belong to it. pub cell_segments: Vec>, /// Active cells in the current step. pub active_cells: Vec, /// Winner cells (subset of active_cells, 1 per active column) for learning. pub winner_cells: Vec, /// Predictive cells for the current step = cells whose segment became /// active at the end of the previous step. pub predictive_cells: Vec, /// Cached list of segment indices that were "active" last compute(). active_segments_prev: Vec, /// Cached list of segment indices that were "matching" last compute(). matching_segments_prev: Vec, rng: Xoshiro256PlusPlus, iter_count: u64, } impl TemporalMemory { pub fn new(cfg: TemporalMemoryConfig, seed: u64) -> Self { let total = cfg.n_columns * cfg.cells_per_column; Self { cell_segments: vec![Vec::new(); total], active_cells: vec![false; total], winner_cells: vec![false; total], predictive_cells: vec![false; total], cfg, segments: Vec::new(), active_segments_prev: Vec::new(), matching_segments_prev: Vec::new(), rng: Xoshiro256PlusPlus::seed_from_u64(seed), iter_count: 0, } } pub fn reset(&mut self) { for v in self.active_cells.iter_mut() { *v = false; } for v in self.winner_cells.iter_mut() { *v = false; } for v in self.predictive_cells.iter_mut() { *v = false; } self.active_segments_prev.clear(); self.matching_segments_prev.clear(); } #[inline] fn col_of(&self, cell: CellIdx) -> usize { (cell as usize) / self.cfg.cells_per_column } #[inline] fn cells_in_col(&self, col: usize) -> std::ops::Range { let base = (col * self.cfg.cells_per_column) as CellIdx; base..(base + self.cfg.cells_per_column as CellIdx) } /// Process one step. /// /// `active_columns` is the set of column indices activated by the Spatial /// Pooler this step. Returns the anomaly score in [0, 1]. pub fn compute(&mut self, active_columns: &[u32], learn: bool) -> f32 { self.iter_count = self.iter_count.wrapping_add(1); // Snapshot previous-step cell activity (for learning on segments). let prev_active_cells = self.active_cells.clone(); let prev_winner_cells = self.winner_cells.clone(); // Move current "predictive" (computed at the end of the last step) // into local variables; we'll overwrite predictive_cells later. let predictive_prev = self.predictive_cells.clone(); // Group active segments and matching segments by column of their // owning cell, for the columns that are active this step. let n_cols = self.cfg.n_columns; // active_segs_by_col[col] = segment indices whose cell is in col and // which were "active" in the previous depolarization. // matching_segs_by_col[col] = similarly for "matching". let mut active_segs_by_col: Vec> = vec![Vec::new(); n_cols]; let mut matching_segs_by_col: Vec> = vec![Vec::new(); n_cols]; for &seg in &self.active_segments_prev { let col = self.col_of(self.segments[seg as usize].cell); active_segs_by_col[col].push(seg); } for &seg in &self.matching_segments_prev { let col = self.col_of(self.segments[seg as usize].cell); matching_segs_by_col[col].push(seg); } // Columns that are active this step (for O(1) lookup). let mut active_col_mask = vec![false; n_cols]; for &c in active_columns { active_col_mask[c as usize] = true; } // Zero out current cell activations. for v in self.active_cells.iter_mut() { *v = false; } for v in self.winner_cells.iter_mut() { *v = false; } // Track anomaly. let mut unpredicted_cols = 0u32; // We'll collect (segment, learn_mode) pairs for segment reinforcement // so we can batch-apply permanence adjustments using prev_active_cells. // learn_mode: "reinforce_correctly_predicted", "punish_incorrectly_matched" enum LearnOp { Reinforce(SegmentIdx), // correctly predicted Grow { // bursting column: grow on chosen segment segment: SegmentIdx, #[allow(dead_code)] winner_cell: CellIdx, }, Punish(SegmentIdx), // matching segment on inactive column } let mut ops: Vec = Vec::new(); // ---- 1) Process active columns ---- for &col in active_columns { let col = col as usize; let active_segs = &active_segs_by_col[col]; if !active_segs.is_empty() { // "Activate predicted column": each cell with an active segment // becomes active and is a winner; reinforce that segment. let mut seen_cells: Vec = Vec::new(); for &seg_i in active_segs { let seg = &self.segments[seg_i as usize]; let cell = seg.cell; if !seen_cells.contains(&cell) { self.active_cells[cell as usize] = true; self.winner_cells[cell as usize] = true; seen_cells.push(cell); } if learn { ops.push(LearnOp::Reinforce(seg_i)); } } } else { // ----- BURST ----- unpredicted_cols += 1; for c in self.cells_in_col(col) { self.active_cells[c as usize] = true; } // Pick a winner cell + segment for learning. if learn { let matching = &matching_segs_by_col[col]; let (winner_cell, target_segment) = if !matching.is_empty() { // Best-matching segment = highest num_active_potential. let mut best = matching[0]; let mut best_score = self.segments[best as usize].num_active_potential; for &s in &matching[1..] { let score = self.segments[s as usize].num_active_potential; if score > best_score { best_score = score; best = s; } } let wc = self.segments[best as usize].cell; (wc, Some(best)) } else { // Least-used cell in column, then grow a new segment. let winner = self.least_used_cell(col); (winner, None) }; self.winner_cells[winner_cell as usize] = true; let segment_id = match target_segment { Some(s) => s, None => { // Create a fresh empty segment on winner cell. self.create_segment(winner_cell) } }; ops.push(LearnOp::Grow { segment: segment_id, winner_cell }); } else { // No learning: still pick some winner cell (arbitrary) // so downstream code that inspects winner_cells isn't empty. let matching = &matching_segs_by_col[col]; let winner_cell = if !matching.is_empty() { self.segments[matching[0] as usize].cell } else { self.least_used_cell(col) }; self.winner_cells[winner_cell as usize] = true; } } } // ---- 2) Punish matching segments on INACTIVE columns ---- if learn && self.cfg.predicted_segment_decrement > 0.0 { for &seg_i in &self.matching_segments_prev { let col = self.col_of(self.segments[seg_i as usize].cell); if !active_col_mask[col] { ops.push(LearnOp::Punish(seg_i)); } } } // ---- 3) Apply learning ---- if learn { for op in ops { match op { LearnOp::Reinforce(seg_i) => { self.reinforce_segment(seg_i, &prev_active_cells); // Optionally grow up to N new synapses to winner cells // of the previous step. self.grow_synapses_on_segment(seg_i, &prev_winner_cells); } LearnOp::Grow { segment, winner_cell: _ } => { self.reinforce_segment(segment, &prev_active_cells); self.grow_synapses_on_segment(segment, &prev_winner_cells); } LearnOp::Punish(seg_i) => { let dec = self.cfg.predicted_segment_decrement; for syn in &mut self.segments[seg_i as usize].synapses { if prev_active_cells[syn.presynaptic_cell as usize] { syn.permanence = (syn.permanence - dec).max(0.0); } } } } } } // ---- 4) Compute segment activity & predictive cells for NEXT step ---- // We have to use the *current* active_cells (just set above). let mut next_active_segs: Vec = Vec::new(); let mut next_matching_segs: Vec = Vec::new(); for v in self.predictive_cells.iter_mut() { *v = false; } let conn = self.cfg.connected_permanence; let act_thr = self.cfg.activation_threshold; let learn_thr = self.cfg.learning_threshold; for (seg_i, seg) in self.segments.iter_mut().enumerate() { let mut n_conn: u32 = 0; let mut n_pot: u32 = 0; for syn in &seg.synapses { if self.active_cells[syn.presynaptic_cell as usize] { n_pot += 1; if syn.permanence >= conn { n_conn += 1; } } } seg.num_active_connected = n_conn; seg.num_active_potential = n_pot; if n_conn >= act_thr { next_active_segs.push(seg_i as SegmentIdx); self.predictive_cells[seg.cell as usize] = true; } if n_pot >= learn_thr { next_matching_segs.push(seg_i as SegmentIdx); } } self.active_segments_prev = next_active_segs; self.matching_segments_prev = next_matching_segs; // Keep predictive_prev unused-guard; we no longer need it but // retained to document intent. let _ = predictive_prev; // Anomaly. if active_columns.is_empty() { 0.0 } else { (unpredicted_cols as f32) / (active_columns.len() as f32) } } /// Reinforce synapses on `seg`: +inc if presynaptic is active last step, /// -dec otherwise. fn reinforce_segment(&mut self, seg_i: SegmentIdx, prev_active_cells: &[bool]) { let inc = self.cfg.permanence_increment; let dec = self.cfg.permanence_decrement; let seg = &mut self.segments[seg_i as usize]; seg.last_used_iteration = self.iter_count; for syn in &mut seg.synapses { if prev_active_cells[syn.presynaptic_cell as usize] { syn.permanence = (syn.permanence + inc).min(1.0); } else { syn.permanence = (syn.permanence - dec).max(0.0); } } } /// Grow up to `max_new_synapse_count - current_potential` new synapses /// from previous winner cells that are not already connected to this seg. fn grow_synapses_on_segment( &mut self, seg_i: SegmentIdx, prev_winner_cells: &[bool], ) { let initial_perm = self.cfg.initial_permanence; let cap = self.cfg.max_synapses_per_segment; let max_new = self.cfg.max_new_synapse_count; // Gather candidate cells (prev winners not already presynaptic to this seg). let already: Vec = self.segments[seg_i as usize] .synapses .iter() .map(|s| s.presynaptic_cell) .collect(); let mut candidates: Vec = Vec::new(); for (cell_i, &b) in prev_winner_cells.iter().enumerate() { if b && !already.contains(&(cell_i as CellIdx)) { candidates.push(cell_i as CellIdx); } } // How many can we add? let current_len = self.segments[seg_i as usize].synapses.len(); let room = cap.saturating_sub(current_len); let mut to_add = max_new.min(candidates.len()).min(room); // Random sample without replacement from candidates. while to_add > 0 { let idx = self.rng.gen_range(0..candidates.len()); let pre = candidates.swap_remove(idx); self.segments[seg_i as usize].synapses.push(Synapse { presynaptic_cell: pre, permanence: initial_perm, }); to_add -= 1; } } fn create_segment(&mut self, cell: CellIdx) -> SegmentIdx { // Enforce per-cell segment cap by evicting least-recently-used segment // if necessary. let cell_segs = &mut self.cell_segments[cell as usize]; if cell_segs.len() >= self.cfg.max_segments_per_cell { // Find LRU segment. let (lru_pos, &lru_id) = cell_segs .iter() .enumerate() .min_by_key(|(_, &sid)| self.segments[sid as usize].last_used_iteration) .expect("cell_segs non-empty"); // Clear that segment in place and reuse its index. self.segments[lru_id as usize].synapses.clear(); self.segments[lru_id as usize].num_active_connected = 0; self.segments[lru_id as usize].num_active_potential = 0; self.segments[lru_id as usize].last_used_iteration = self.iter_count; // Keep at same position in cell_segs. let _ = lru_pos; return lru_id; } let new_id = self.segments.len() as SegmentIdx; self.segments.push(Segment { cell, synapses: Vec::with_capacity(self.cfg.max_new_synapse_count), num_active_connected: 0, num_active_potential: 0, last_used_iteration: self.iter_count, }); cell_segs.push(new_id); new_id } fn least_used_cell(&mut self, col: usize) -> CellIdx { // Cell with the fewest segments; break ties randomly. let mut min_segs = usize::MAX; let mut candidates: Vec = Vec::new(); for c in self.cells_in_col(col) { let n = self.cell_segments[c as usize].len(); if n < min_segs { min_segs = n; candidates.clear(); candidates.push(c); } else if n == min_segs { candidates.push(c); } } let idx = self.rng.gen_range(0..candidates.len()); candidates[idx] } } // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; use crate::sp::{SpatialPooler, SpatialPoolerConfig}; use rand::Rng; use rand::SeedableRng; use rand_xoshiro::Xoshiro256PlusPlus; #[test] fn tm_learns_repeating_sequence() { // Sequence A -> B -> C -> A -> B -> C -> ... should drive anomaly down. let cfg = SpatialPoolerConfig::default(); let mut sp = SpatialPooler::new(cfg, 123); let mut tm = TemporalMemory::new(TemporalMemoryConfig::default(), 456); // Build 3 fixed random SDRs of 2% sparsity. let mut rng = Xoshiro256PlusPlus::seed_from_u64(99); let input_bits = sp.cfg.input_bits; let make_sdr = |rng: &mut Xoshiro256PlusPlus| { let mut v = vec![false; input_bits]; let on = (0.02 * input_bits as f32) as usize; let mut placed = 0; while placed < on { let i = rng.gen_range(0..input_bits); if !v[i] { v[i] = true; placed += 1; } } v }; let seqs = [make_sdr(&mut rng), make_sdr(&mut rng), make_sdr(&mut rng)]; // Warm up SP first so that columns are reliable for each symbol. for _ in 0..200 { for s in &seqs { sp.compute(s, true); } } // Reset TM so prediction state is clean. tm.reset(); // Record anomaly over a window early and late. let mut early_anoms: Vec = Vec::new(); let mut late_anoms: Vec = Vec::new(); for iter in 0..250 { for s in &seqs { let active = sp.compute(s, false); let anomaly = tm.compute(&active, true); if iter == 10 { early_anoms.push(anomaly); } if iter == 249 { late_anoms.push(anomaly); } } } let mean = |v: &[f32]| v.iter().sum::() / (v.len() as f32); let early = mean(&early_anoms); let late = mean(&late_anoms); println!("early_anomaly={early}, late_anomaly={late}"); assert!( late < 0.5 * early + 1e-6, "late anomaly ({late}) should be < 0.5 * early anomaly ({early})" ); } }