Spaces:
Runtime error
Runtime error
| //! Numenta BAMI-spec Temporal Memory. | |
| //! | |
| //! Key parameters (Numenta defaults): | |
| //! - cells_per_column = 32 | |
| //! - max_segments_per_cell = 255 | |
| //! - max_synapses_per_segment = 32 | |
| //! - activation_threshold = 15 (CONNECTED synapses onto active cells) | |
| //! - learning_threshold = 13 (POTENTIAL synapses onto active cells) | |
| //! (often called `minThreshold` / match threshold in BAMI) | |
| //! - initial_permanence = 0.21 | |
| //! - connected_permanence = 0.50 | |
| //! - permanence_increment = 0.10 | |
| //! - permanence_decrement = 0.10 | |
| //! - predicted_segment_decrement = 0.10 (decay for segments that predicted | |
| //! inactive columns; called `predictedSegmentDecrement` in BAMI) | |
| //! - max_new_synapse_count = 20 (max synapses to grow on a new/reinforced seg) | |
| //! | |
| //! Algorithm (one step): | |
| //! Given `active_columns` from the Spatial Pooler, and segment activity | |
| //! caches `active_segments` and `matching_segments` computed *at the end of | |
| //! the previous step*: | |
| //! | |
| //! 1. For each active column: | |
| //! - If it contains any predicted cell (any cell with an active segment | |
| //! from the previous depolarization), mark those cells active and | |
| //! learn on the segment that predicted it. | |
| //! - Else BURST the column: mark all cells in it active, then learn on | |
| //! the best-matching segment in the column (reinforce it and grow | |
| //! synapses), or, if no segment matches, grow a new segment on the | |
| //! cell with the fewest segments. | |
| //! 2. For every column that was predicted but did NOT become active | |
| //! (matching segments on inactive columns), apply the | |
| //! `predicted_segment_decrement` decay so spurious predictions fade. | |
| //! 3. Winner cells = active cells chosen for learning: one per bursting | |
| //! column, plus every predicted cell in a correctly predicted column. | |
| //! 4. Compute segment activity for NEXT step: | |
| //! - A segment's CONNECTED activity = #synapses with perm >= connected_perm | |
| //! whose presynaptic cell is in `active_cells`. If >= activation_threshold | |
| //! -> segment is "active" -> its cell is "predicted". | |
| //! - A segment's POTENTIAL activity = #synapses whose presynaptic cell is | |
| //! in `active_cells` (regardless of permanence). If >= learning_threshold | |
| //! -> segment is "matching". | |
| //! | |
| //! Anomaly score = (active columns with no prior predicted cells) | |
| //! / (# active columns). | |
| use rand::Rng; | |
| use rand::SeedableRng; | |
| use rand_xoshiro::Xoshiro256PlusPlus; | |
/// Index of a cell in the region; cells of column `c` occupy
/// `c * cells_per_column .. (c + 1) * cells_per_column`.
type CellIdx = u32;
/// Index into `TemporalMemory::segments`.
type SegmentIdx = u32;
| pub struct Synapse { | |
| pub presynaptic_cell: CellIdx, | |
| pub permanence: f32, | |
| } | |
| pub struct Segment { | |
| pub cell: CellIdx, | |
| pub synapses: Vec<Synapse>, | |
| /// Cached counters; recomputed each step. | |
| pub num_active_connected: u32, | |
| pub num_active_potential: u32, | |
| /// Simple "last iter touched" stat for least-used cell selection. | |
| pub last_used_iteration: u64, | |
| } | |
/// Parameters of a [`TemporalMemory`] region (defaults follow Numenta).
#[derive(Debug, Clone, PartialEq)]
pub struct TemporalMemoryConfig {
    /// Number of columns (must match the Spatial Pooler's output width).
    pub n_columns: usize,
    /// Cells per column.
    pub cells_per_column: usize,
    /// Connected active synapses needed for a segment to be "active".
    pub activation_threshold: u32,
    /// Potential active synapses needed for a segment to be "matching".
    pub learning_threshold: u32,
    /// Permanence assigned to newly grown synapses.
    pub initial_permanence: f32,
    /// Permanence at or above which a synapse is connected.
    pub connected_permanence: f32,
    /// Permanence added to synapses onto previously active cells when learning.
    pub permanence_increment: f32,
    /// Permanence removed from synapses onto previously inactive cells when learning.
    pub permanence_decrement: f32,
    /// Permanence removed from matching segments whose column did not activate.
    pub predicted_segment_decrement: f32,
    /// Per-cell segment cap; beyond it the least-recently-used segment is reused.
    pub max_segments_per_cell: usize,
    /// Per-segment synapse cap.
    pub max_synapses_per_segment: usize,
    /// Upper bound on synapses grown onto a segment in one learning step.
    pub max_new_synapse_count: usize,
}

impl Default for TemporalMemoryConfig {
    fn default() -> Self {
        Self {
            n_columns: 2048,
            cells_per_column: 32,
            activation_threshold: 15,
            learning_threshold: 13,
            initial_permanence: 0.21,
            connected_permanence: 0.50,
            permanence_increment: 0.10,
            permanence_decrement: 0.10,
            predicted_segment_decrement: 0.10,
            max_segments_per_cell: 255,
            max_synapses_per_segment: 32,
            max_new_synapse_count: 20,
        }
    }
}
| pub struct TemporalMemory { | |
| pub cfg: TemporalMemoryConfig, | |
| /// All segments in the region. Indexed by SegmentIdx. | |
| pub segments: Vec<Segment>, | |
| /// For each cell, the list of segments that belong to it. | |
| pub cell_segments: Vec<Vec<SegmentIdx>>, | |
| /// Active cells in the current step. | |
| pub active_cells: Vec<bool>, | |
| /// Winner cells (subset of active_cells, 1 per active column) for learning. | |
| pub winner_cells: Vec<bool>, | |
| /// Predictive cells for the current step = cells whose segment became | |
| /// active at the end of the previous step. | |
| pub predictive_cells: Vec<bool>, | |
| /// Cached list of segment indices that were "active" last compute(). | |
| active_segments_prev: Vec<SegmentIdx>, | |
| /// Cached list of segment indices that were "matching" last compute(). | |
| matching_segments_prev: Vec<SegmentIdx>, | |
| rng: Xoshiro256PlusPlus, | |
| iter_count: u64, | |
| } | |
| impl TemporalMemory { | |
| pub fn new(cfg: TemporalMemoryConfig, seed: u64) -> Self { | |
| let total = cfg.n_columns * cfg.cells_per_column; | |
| Self { | |
| cell_segments: vec![Vec::new(); total], | |
| active_cells: vec![false; total], | |
| winner_cells: vec![false; total], | |
| predictive_cells: vec![false; total], | |
| cfg, | |
| segments: Vec::new(), | |
| active_segments_prev: Vec::new(), | |
| matching_segments_prev: Vec::new(), | |
| rng: Xoshiro256PlusPlus::seed_from_u64(seed), | |
| iter_count: 0, | |
| } | |
| } | |
| pub fn reset(&mut self) { | |
| for v in self.active_cells.iter_mut() { *v = false; } | |
| for v in self.winner_cells.iter_mut() { *v = false; } | |
| for v in self.predictive_cells.iter_mut() { *v = false; } | |
| self.active_segments_prev.clear(); | |
| self.matching_segments_prev.clear(); | |
| } | |
| fn col_of(&self, cell: CellIdx) -> usize { | |
| (cell as usize) / self.cfg.cells_per_column | |
| } | |
| fn cells_in_col(&self, col: usize) -> std::ops::Range<CellIdx> { | |
| let base = (col * self.cfg.cells_per_column) as CellIdx; | |
| base..(base + self.cfg.cells_per_column as CellIdx) | |
| } | |
| /// Process one step. | |
| /// | |
| /// `active_columns` is the set of column indices activated by the Spatial | |
| /// Pooler this step. Returns the anomaly score in [0, 1]. | |
| pub fn compute(&mut self, active_columns: &[u32], learn: bool) -> f32 { | |
| self.iter_count = self.iter_count.wrapping_add(1); | |
| // Snapshot previous-step cell activity (for learning on segments). | |
| let prev_active_cells = self.active_cells.clone(); | |
| let prev_winner_cells = self.winner_cells.clone(); | |
| // Move current "predictive" (computed at the end of the last step) | |
| // into local variables; we'll overwrite predictive_cells later. | |
| let predictive_prev = self.predictive_cells.clone(); | |
| // Group active segments and matching segments by column of their | |
| // owning cell, for the columns that are active this step. | |
| let n_cols = self.cfg.n_columns; | |
| // active_segs_by_col[col] = segment indices whose cell is in col and | |
| // which were "active" in the previous depolarization. | |
| // matching_segs_by_col[col] = similarly for "matching". | |
| let mut active_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols]; | |
| let mut matching_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols]; | |
| for &seg in &self.active_segments_prev { | |
| let col = self.col_of(self.segments[seg as usize].cell); | |
| active_segs_by_col[col].push(seg); | |
| } | |
| for &seg in &self.matching_segments_prev { | |
| let col = self.col_of(self.segments[seg as usize].cell); | |
| matching_segs_by_col[col].push(seg); | |
| } | |
| // Columns that are active this step (for O(1) lookup). | |
| let mut active_col_mask = vec![false; n_cols]; | |
| for &c in active_columns { active_col_mask[c as usize] = true; } | |
| // Zero out current cell activations. | |
| for v in self.active_cells.iter_mut() { *v = false; } | |
| for v in self.winner_cells.iter_mut() { *v = false; } | |
| // Track anomaly. | |
| let mut unpredicted_cols = 0u32; | |
| // We'll collect (segment, learn_mode) pairs for segment reinforcement | |
| // so we can batch-apply permanence adjustments using prev_active_cells. | |
| // learn_mode: "reinforce_correctly_predicted", "punish_incorrectly_matched" | |
| enum LearnOp { | |
| Reinforce(SegmentIdx), // correctly predicted | |
| Grow { // bursting column: grow on chosen segment | |
| segment: SegmentIdx, | |
| winner_cell: CellIdx, | |
| }, | |
| Punish(SegmentIdx), // matching segment on inactive column | |
| } | |
| let mut ops: Vec<LearnOp> = Vec::new(); | |
| // ---- 1) Process active columns ---- | |
| for &col in active_columns { | |
| let col = col as usize; | |
| let active_segs = &active_segs_by_col[col]; | |
| if !active_segs.is_empty() { | |
| // "Activate predicted column": each cell with an active segment | |
| // becomes active and is a winner; reinforce that segment. | |
| let mut seen_cells: Vec<CellIdx> = Vec::new(); | |
| for &seg_i in active_segs { | |
| let seg = &self.segments[seg_i as usize]; | |
| let cell = seg.cell; | |
| if !seen_cells.contains(&cell) { | |
| self.active_cells[cell as usize] = true; | |
| self.winner_cells[cell as usize] = true; | |
| seen_cells.push(cell); | |
| } | |
| if learn { | |
| ops.push(LearnOp::Reinforce(seg_i)); | |
| } | |
| } | |
| } else { | |
| // ----- BURST ----- | |
| unpredicted_cols += 1; | |
| for c in self.cells_in_col(col) { | |
| self.active_cells[c as usize] = true; | |
| } | |
| // Pick a winner cell + segment for learning. | |
| if learn { | |
| let matching = &matching_segs_by_col[col]; | |
| let (winner_cell, target_segment) = if !matching.is_empty() { | |
| // Best-matching segment = highest num_active_potential. | |
| let mut best = matching[0]; | |
| let mut best_score = self.segments[best as usize].num_active_potential; | |
| for &s in &matching[1..] { | |
| let score = self.segments[s as usize].num_active_potential; | |
| if score > best_score { | |
| best_score = score; | |
| best = s; | |
| } | |
| } | |
| let wc = self.segments[best as usize].cell; | |
| (wc, Some(best)) | |
| } else { | |
| // Least-used cell in column, then grow a new segment. | |
| let winner = self.least_used_cell(col); | |
| (winner, None) | |
| }; | |
| self.winner_cells[winner_cell as usize] = true; | |
| let segment_id = match target_segment { | |
| Some(s) => s, | |
| None => { | |
| // Create a fresh empty segment on winner cell. | |
| self.create_segment(winner_cell) | |
| } | |
| }; | |
| ops.push(LearnOp::Grow { segment: segment_id, winner_cell }); | |
| } else { | |
| // No learning: still pick some winner cell (arbitrary) | |
| // so downstream code that inspects winner_cells isn't empty. | |
| let matching = &matching_segs_by_col[col]; | |
| let winner_cell = if !matching.is_empty() { | |
| self.segments[matching[0] as usize].cell | |
| } else { | |
| self.least_used_cell(col) | |
| }; | |
| self.winner_cells[winner_cell as usize] = true; | |
| } | |
| } | |
| } | |
| // ---- 2) Punish matching segments on INACTIVE columns ---- | |
| if learn && self.cfg.predicted_segment_decrement > 0.0 { | |
| for &seg_i in &self.matching_segments_prev { | |
| let col = self.col_of(self.segments[seg_i as usize].cell); | |
| if !active_col_mask[col] { | |
| ops.push(LearnOp::Punish(seg_i)); | |
| } | |
| } | |
| } | |
| // ---- 3) Apply learning ---- | |
| if learn { | |
| for op in ops { | |
| match op { | |
| LearnOp::Reinforce(seg_i) => { | |
| self.reinforce_segment(seg_i, &prev_active_cells); | |
| // Optionally grow up to N new synapses to winner cells | |
| // of the previous step. | |
| self.grow_synapses_on_segment(seg_i, &prev_winner_cells); | |
| } | |
| LearnOp::Grow { segment, winner_cell: _ } => { | |
| self.reinforce_segment(segment, &prev_active_cells); | |
| self.grow_synapses_on_segment(segment, &prev_winner_cells); | |
| } | |
| LearnOp::Punish(seg_i) => { | |
| let dec = self.cfg.predicted_segment_decrement; | |
| for syn in &mut self.segments[seg_i as usize].synapses { | |
| if prev_active_cells[syn.presynaptic_cell as usize] { | |
| syn.permanence = (syn.permanence - dec).max(0.0); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| // ---- 4) Compute segment activity & predictive cells for NEXT step ---- | |
| // We have to use the *current* active_cells (just set above). | |
| let mut next_active_segs: Vec<SegmentIdx> = Vec::new(); | |
| let mut next_matching_segs: Vec<SegmentIdx> = Vec::new(); | |
| for v in self.predictive_cells.iter_mut() { *v = false; } | |
| let conn = self.cfg.connected_permanence; | |
| let act_thr = self.cfg.activation_threshold; | |
| let learn_thr = self.cfg.learning_threshold; | |
| for (seg_i, seg) in self.segments.iter_mut().enumerate() { | |
| let mut n_conn: u32 = 0; | |
| let mut n_pot: u32 = 0; | |
| for syn in &seg.synapses { | |
| if self.active_cells[syn.presynaptic_cell as usize] { | |
| n_pot += 1; | |
| if syn.permanence >= conn { n_conn += 1; } | |
| } | |
| } | |
| seg.num_active_connected = n_conn; | |
| seg.num_active_potential = n_pot; | |
| if n_conn >= act_thr { | |
| next_active_segs.push(seg_i as SegmentIdx); | |
| self.predictive_cells[seg.cell as usize] = true; | |
| } | |
| if n_pot >= learn_thr { | |
| next_matching_segs.push(seg_i as SegmentIdx); | |
| } | |
| } | |
| self.active_segments_prev = next_active_segs; | |
| self.matching_segments_prev = next_matching_segs; | |
| // Keep predictive_prev unused-guard; we no longer need it but | |
| // retained to document intent. | |
| let _ = predictive_prev; | |
| // Anomaly. | |
| if active_columns.is_empty() { | |
| 0.0 | |
| } else { | |
| (unpredicted_cols as f32) / (active_columns.len() as f32) | |
| } | |
| } | |
| /// Reinforce synapses on `seg`: +inc if presynaptic is active last step, | |
| /// -dec otherwise. | |
| fn reinforce_segment(&mut self, seg_i: SegmentIdx, prev_active_cells: &[bool]) { | |
| let inc = self.cfg.permanence_increment; | |
| let dec = self.cfg.permanence_decrement; | |
| let seg = &mut self.segments[seg_i as usize]; | |
| seg.last_used_iteration = self.iter_count; | |
| for syn in &mut seg.synapses { | |
| if prev_active_cells[syn.presynaptic_cell as usize] { | |
| syn.permanence = (syn.permanence + inc).min(1.0); | |
| } else { | |
| syn.permanence = (syn.permanence - dec).max(0.0); | |
| } | |
| } | |
| } | |
| /// Grow up to `max_new_synapse_count - current_potential` new synapses | |
| /// from previous winner cells that are not already connected to this seg. | |
| fn grow_synapses_on_segment( | |
| &mut self, | |
| seg_i: SegmentIdx, | |
| prev_winner_cells: &[bool], | |
| ) { | |
| let initial_perm = self.cfg.initial_permanence; | |
| let cap = self.cfg.max_synapses_per_segment; | |
| let max_new = self.cfg.max_new_synapse_count; | |
| // Gather candidate cells (prev winners not already presynaptic to this seg). | |
| let already: Vec<CellIdx> = self.segments[seg_i as usize] | |
| .synapses | |
| .iter() | |
| .map(|s| s.presynaptic_cell) | |
| .collect(); | |
| let mut candidates: Vec<CellIdx> = Vec::new(); | |
| for (cell_i, &b) in prev_winner_cells.iter().enumerate() { | |
| if b && !already.contains(&(cell_i as CellIdx)) { | |
| candidates.push(cell_i as CellIdx); | |
| } | |
| } | |
| // How many can we add? | |
| let current_len = self.segments[seg_i as usize].synapses.len(); | |
| let room = cap.saturating_sub(current_len); | |
| let mut to_add = max_new.min(candidates.len()).min(room); | |
| // Random sample without replacement from candidates. | |
| while to_add > 0 { | |
| let idx = self.rng.gen_range(0..candidates.len()); | |
| let pre = candidates.swap_remove(idx); | |
| self.segments[seg_i as usize].synapses.push(Synapse { | |
| presynaptic_cell: pre, | |
| permanence: initial_perm, | |
| }); | |
| to_add -= 1; | |
| } | |
| } | |
| fn create_segment(&mut self, cell: CellIdx) -> SegmentIdx { | |
| // Enforce per-cell segment cap by evicting least-recently-used segment | |
| // if necessary. | |
| let cell_segs = &mut self.cell_segments[cell as usize]; | |
| if cell_segs.len() >= self.cfg.max_segments_per_cell { | |
| // Find LRU segment. | |
| let (lru_pos, &lru_id) = cell_segs | |
| .iter() | |
| .enumerate() | |
| .min_by_key(|(_, &sid)| self.segments[sid as usize].last_used_iteration) | |
| .expect("cell_segs non-empty"); | |
| // Clear that segment in place and reuse its index. | |
| self.segments[lru_id as usize].synapses.clear(); | |
| self.segments[lru_id as usize].num_active_connected = 0; | |
| self.segments[lru_id as usize].num_active_potential = 0; | |
| self.segments[lru_id as usize].last_used_iteration = self.iter_count; | |
| // Keep at same position in cell_segs. | |
| let _ = lru_pos; | |
| return lru_id; | |
| } | |
| let new_id = self.segments.len() as SegmentIdx; | |
| self.segments.push(Segment { | |
| cell, | |
| synapses: Vec::with_capacity(self.cfg.max_new_synapse_count), | |
| num_active_connected: 0, | |
| num_active_potential: 0, | |
| last_used_iteration: self.iter_count, | |
| }); | |
| cell_segs.push(new_id); | |
| new_id | |
| } | |
| fn least_used_cell(&mut self, col: usize) -> CellIdx { | |
| // Cell with the fewest segments; break ties randomly. | |
| let mut min_segs = usize::MAX; | |
| let mut candidates: Vec<CellIdx> = Vec::new(); | |
| for c in self.cells_in_col(col) { | |
| let n = self.cell_segments[c as usize].len(); | |
| if n < min_segs { | |
| min_segs = n; | |
| candidates.clear(); | |
| candidates.push(c); | |
| } else if n == min_segs { | |
| candidates.push(c); | |
| } | |
| } | |
| let idx = self.rng.gen_range(0..candidates.len()); | |
| candidates[idx] | |
| } | |
| } | |
| // --------------------------------------------------------------------------- | |
| // Tests | |
| // --------------------------------------------------------------------------- | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sp::{SpatialPooler, SpatialPoolerConfig};
    use rand::Rng;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;

    /// A repeating A -> B -> C sequence should drive the anomaly down
    /// substantially once the TM has learned the transitions.
    #[test]
    fn tm_learns_repeating_sequence() {
        let cfg = SpatialPoolerConfig::default();
        let mut sp = SpatialPooler::new(cfg, 123);
        let mut tm = TemporalMemory::new(TemporalMemoryConfig::default(), 456);
        // Build 3 fixed random SDRs of 2% sparsity.
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
        let input_bits = sp.cfg.input_bits;
        let make_sdr = |rng: &mut Xoshiro256PlusPlus| {
            let mut v = vec![false; input_bits];
            let on = (0.02 * input_bits as f32) as usize;
            let mut placed = 0;
            while placed < on {
                let i = rng.gen_range(0..input_bits);
                if !v[i] {
                    v[i] = true;
                    placed += 1;
                }
            }
            v
        };
        let seqs = [make_sdr(&mut rng), make_sdr(&mut rng), make_sdr(&mut rng)];
        // Warm up SP first so that columns are reliable for each symbol.
        for _ in 0..200 {
            for s in &seqs {
                sp.compute(s, true);
            }
        }
        // Reset TM so prediction state is clean.
        tm.reset();
        // Record anomaly over a window early and late.
        let mut early_anoms: Vec<f32> = Vec::new();
        let mut late_anoms: Vec<f32> = Vec::new();
        for iter in 0..250 {
            for s in &seqs {
                let active = sp.compute(s, false);
                let anomaly = tm.compute(&active, true);
                if iter == 10 {
                    early_anoms.push(anomaly);
                }
                if iter == 249 {
                    late_anoms.push(anomaly);
                }
            }
        }
        let mean = |v: &[f32]| v.iter().sum::<f32>() / (v.len() as f32);
        let early = mean(&early_anoms);
        let late = mean(&late_anoms);
        println!("early_anomaly={early}, late_anomaly={late}");
        assert!(
            late < 0.5 * early + 1e-6,
            "late anomaly ({late}) should be < 0.5 * early anomaly ({early})"
        );
    }
}