| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| use rand::Rng; |
| use rand::SeedableRng; |
| use rand_xoshiro::Xoshiro256PlusPlus; |
|
|
/// Index of a cell in the flat, column-major cell arrays
/// (`column * cells_per_column + offset_within_column`).
type CellIdx = u32;

/// Index of a segment in the flat `TemporalMemory::segments` store.
type SegmentIdx = u32;

/// A single connection from a segment to a presynaptic cell.
#[derive(Debug, Clone, PartialEq)]
pub struct Synapse {
    /// Cell whose activity this synapse listens to.
    pub presynaptic_cell: CellIdx,
    /// Connection strength; the learning code keeps it clamped to `[0.0, 1.0]`.
    pub permanence: f32,
}

/// A dendritic segment: a set of synapses owned by one cell, plus the activity
/// counters cached by the most recent `compute` call.
#[derive(Debug, Clone, PartialEq)]
pub struct Segment {
    /// Cell this segment belongs to.
    pub cell: CellIdx,
    /// Synapses grown on this segment.
    pub synapses: Vec<Synapse>,

    /// Number of synapses whose presynaptic cell was active and whose
    /// permanence was at or above the connected threshold, as of the last
    /// `compute` call.
    pub num_active_connected: u32,
    /// Number of synapses whose presynaptic cell was active, regardless of
    /// permanence, as of the last `compute` call.
    pub num_active_potential: u32,

    /// Value of the iteration counter when the segment was last created,
    /// recycled or reinforced; used to pick the least-recently-used segment
    /// when a cell runs out of segment slots.
    pub last_used_iteration: u64,
}
|
|
/// Tunable parameters of a [`TemporalMemory`].
#[derive(Debug, Clone, PartialEq)]
pub struct TemporalMemoryConfig {
    /// Number of columns; valid column indices are `0..n_columns`.
    pub n_columns: usize,
    /// Number of cells in every column.
    pub cells_per_column: usize,
    /// Minimum number of active *connected* synapses for a segment to become
    /// active (and make its cell predictive).
    pub activation_threshold: u32,
    /// Minimum number of active synapses, connected or not, for a segment to
    /// count as "matching" and be eligible for learning.
    pub learning_threshold: u32,
    /// Permanence assigned to newly grown synapses.
    pub initial_permanence: f32,
    /// Permanence at or above which a synapse counts as connected.
    pub connected_permanence: f32,
    /// Amount added to synapses with an active presynaptic cell when a segment
    /// is reinforced.
    pub permanence_increment: f32,
    /// Amount subtracted from synapses with an inactive presynaptic cell when a
    /// segment is reinforced.
    pub permanence_decrement: f32,
    /// Amount subtracted from the active synapses of matching segments whose
    /// column did not become active; values `<= 0.0` disable this punishment.
    pub predicted_segment_decrement: f32,
    /// Segment slots per cell; once a cell is full, its least-recently-used
    /// segment is recycled instead of allocating a new one.
    pub max_segments_per_cell: usize,
    /// Upper bound on the number of synapses a segment may hold.
    pub max_synapses_per_segment: usize,
    /// Upper bound on synapses added to a segment by one growth step; also the
    /// initial capacity reserved for a new segment's synapse list.
    pub max_new_synapse_count: usize,
}

impl Default for TemporalMemoryConfig {
    /// Returns the stock parameter set (2048 columns of 32 cells each).
    fn default() -> Self {
        Self {
            n_columns: 2048,
            cells_per_column: 32,
            activation_threshold: 15,
            learning_threshold: 13,
            initial_permanence: 0.21,
            connected_permanence: 0.50,
            permanence_increment: 0.10,
            permanence_decrement: 0.10,
            predicted_segment_decrement: 0.10,
            max_segments_per_cell: 255,
            max_synapses_per_segment: 32,
            max_new_synapse_count: 20,
        }
    }
}
|
|
| pub struct TemporalMemory { |
| pub cfg: TemporalMemoryConfig, |
| |
| pub segments: Vec<Segment>, |
| |
| pub cell_segments: Vec<Vec<SegmentIdx>>, |
| |
| pub active_cells: Vec<bool>, |
| |
| pub winner_cells: Vec<bool>, |
| |
| |
| pub predictive_cells: Vec<bool>, |
| |
| active_segments_prev: Vec<SegmentIdx>, |
| |
| matching_segments_prev: Vec<SegmentIdx>, |
| rng: Xoshiro256PlusPlus, |
| iter_count: u64, |
| } |
|
|
| impl TemporalMemory { |
| pub fn new(cfg: TemporalMemoryConfig, seed: u64) -> Self { |
| let total = cfg.n_columns * cfg.cells_per_column; |
| Self { |
| cell_segments: vec![Vec::new(); total], |
| active_cells: vec![false; total], |
| winner_cells: vec![false; total], |
| predictive_cells: vec![false; total], |
| cfg, |
| segments: Vec::new(), |
| active_segments_prev: Vec::new(), |
| matching_segments_prev: Vec::new(), |
| rng: Xoshiro256PlusPlus::seed_from_u64(seed), |
| iter_count: 0, |
| } |
| } |
|
|
| pub fn reset(&mut self) { |
| for v in self.active_cells.iter_mut() { *v = false; } |
| for v in self.winner_cells.iter_mut() { *v = false; } |
| for v in self.predictive_cells.iter_mut() { *v = false; } |
| self.active_segments_prev.clear(); |
| self.matching_segments_prev.clear(); |
| } |
|
|
| #[inline] |
| fn col_of(&self, cell: CellIdx) -> usize { |
| (cell as usize) / self.cfg.cells_per_column |
| } |
|
|
| #[inline] |
| fn cells_in_col(&self, col: usize) -> std::ops::Range<CellIdx> { |
| let base = (col * self.cfg.cells_per_column) as CellIdx; |
| base..(base + self.cfg.cells_per_column as CellIdx) |
| } |
|
|
| |
| |
| |
| |
| pub fn compute(&mut self, active_columns: &[u32], learn: bool) -> f32 { |
| self.iter_count = self.iter_count.wrapping_add(1); |
|
|
| |
| let prev_active_cells = self.active_cells.clone(); |
| let prev_winner_cells = self.winner_cells.clone(); |
|
|
| |
| |
| let predictive_prev = self.predictive_cells.clone(); |
|
|
| |
| |
| let n_cols = self.cfg.n_columns; |
|
|
| |
| |
| |
| let mut active_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols]; |
| let mut matching_segs_by_col: Vec<Vec<SegmentIdx>> = vec![Vec::new(); n_cols]; |
| for &seg in &self.active_segments_prev { |
| let col = self.col_of(self.segments[seg as usize].cell); |
| active_segs_by_col[col].push(seg); |
| } |
| for &seg in &self.matching_segments_prev { |
| let col = self.col_of(self.segments[seg as usize].cell); |
| matching_segs_by_col[col].push(seg); |
| } |
|
|
| |
| let mut active_col_mask = vec![false; n_cols]; |
| for &c in active_columns { active_col_mask[c as usize] = true; } |
|
|
| |
| for v in self.active_cells.iter_mut() { *v = false; } |
| for v in self.winner_cells.iter_mut() { *v = false; } |
|
|
| |
| let mut unpredicted_cols = 0u32; |
|
|
| |
| |
| |
| enum LearnOp { |
| Reinforce(SegmentIdx), |
| Grow { |
| segment: SegmentIdx, |
| #[allow(dead_code)] |
| winner_cell: CellIdx, |
| }, |
| Punish(SegmentIdx), |
| } |
| let mut ops: Vec<LearnOp> = Vec::new(); |
|
|
| |
| for &col in active_columns { |
| let col = col as usize; |
| let active_segs = &active_segs_by_col[col]; |
| if !active_segs.is_empty() { |
| |
| |
| let mut seen_cells: Vec<CellIdx> = Vec::new(); |
| for &seg_i in active_segs { |
| let seg = &self.segments[seg_i as usize]; |
| let cell = seg.cell; |
| if !seen_cells.contains(&cell) { |
| self.active_cells[cell as usize] = true; |
| self.winner_cells[cell as usize] = true; |
| seen_cells.push(cell); |
| } |
| if learn { |
| ops.push(LearnOp::Reinforce(seg_i)); |
| } |
| } |
| } else { |
| |
| unpredicted_cols += 1; |
| for c in self.cells_in_col(col) { |
| self.active_cells[c as usize] = true; |
| } |
| |
| if learn { |
| let matching = &matching_segs_by_col[col]; |
| let (winner_cell, target_segment) = if !matching.is_empty() { |
| |
| let mut best = matching[0]; |
| let mut best_score = self.segments[best as usize].num_active_potential; |
| for &s in &matching[1..] { |
| let score = self.segments[s as usize].num_active_potential; |
| if score > best_score { |
| best_score = score; |
| best = s; |
| } |
| } |
| let wc = self.segments[best as usize].cell; |
| (wc, Some(best)) |
| } else { |
| |
| let winner = self.least_used_cell(col); |
| (winner, None) |
| }; |
| self.winner_cells[winner_cell as usize] = true; |
| let segment_id = match target_segment { |
| Some(s) => s, |
| None => { |
| |
| self.create_segment(winner_cell) |
| } |
| }; |
| ops.push(LearnOp::Grow { segment: segment_id, winner_cell }); |
| } else { |
| |
| |
| let matching = &matching_segs_by_col[col]; |
| let winner_cell = if !matching.is_empty() { |
| self.segments[matching[0] as usize].cell |
| } else { |
| self.least_used_cell(col) |
| }; |
| self.winner_cells[winner_cell as usize] = true; |
| } |
| } |
| } |
|
|
| |
| if learn && self.cfg.predicted_segment_decrement > 0.0 { |
| for &seg_i in &self.matching_segments_prev { |
| let col = self.col_of(self.segments[seg_i as usize].cell); |
| if !active_col_mask[col] { |
| ops.push(LearnOp::Punish(seg_i)); |
| } |
| } |
| } |
|
|
| |
| if learn { |
| for op in ops { |
| match op { |
| LearnOp::Reinforce(seg_i) => { |
| self.reinforce_segment(seg_i, &prev_active_cells); |
| |
| |
| self.grow_synapses_on_segment(seg_i, &prev_winner_cells); |
| } |
| LearnOp::Grow { segment, winner_cell: _ } => { |
| self.reinforce_segment(segment, &prev_active_cells); |
| self.grow_synapses_on_segment(segment, &prev_winner_cells); |
| } |
| LearnOp::Punish(seg_i) => { |
| let dec = self.cfg.predicted_segment_decrement; |
| for syn in &mut self.segments[seg_i as usize].synapses { |
| if prev_active_cells[syn.presynaptic_cell as usize] { |
| syn.permanence = (syn.permanence - dec).max(0.0); |
| } |
| } |
| } |
| } |
| } |
| } |
|
|
| |
| |
| let mut next_active_segs: Vec<SegmentIdx> = Vec::new(); |
| let mut next_matching_segs: Vec<SegmentIdx> = Vec::new(); |
| for v in self.predictive_cells.iter_mut() { *v = false; } |
|
|
| let conn = self.cfg.connected_permanence; |
| let act_thr = self.cfg.activation_threshold; |
| let learn_thr = self.cfg.learning_threshold; |
|
|
| for (seg_i, seg) in self.segments.iter_mut().enumerate() { |
| let mut n_conn: u32 = 0; |
| let mut n_pot: u32 = 0; |
| for syn in &seg.synapses { |
| if self.active_cells[syn.presynaptic_cell as usize] { |
| n_pot += 1; |
| if syn.permanence >= conn { n_conn += 1; } |
| } |
| } |
| seg.num_active_connected = n_conn; |
| seg.num_active_potential = n_pot; |
| if n_conn >= act_thr { |
| next_active_segs.push(seg_i as SegmentIdx); |
| self.predictive_cells[seg.cell as usize] = true; |
| } |
| if n_pot >= learn_thr { |
| next_matching_segs.push(seg_i as SegmentIdx); |
| } |
| } |
| self.active_segments_prev = next_active_segs; |
| self.matching_segments_prev = next_matching_segs; |
|
|
| |
| |
| let _ = predictive_prev; |
|
|
| |
| if active_columns.is_empty() { |
| 0.0 |
| } else { |
| (unpredicted_cols as f32) / (active_columns.len() as f32) |
| } |
| } |
|
|
| |
| |
| fn reinforce_segment(&mut self, seg_i: SegmentIdx, prev_active_cells: &[bool]) { |
| let inc = self.cfg.permanence_increment; |
| let dec = self.cfg.permanence_decrement; |
| let seg = &mut self.segments[seg_i as usize]; |
| seg.last_used_iteration = self.iter_count; |
| for syn in &mut seg.synapses { |
| if prev_active_cells[syn.presynaptic_cell as usize] { |
| syn.permanence = (syn.permanence + inc).min(1.0); |
| } else { |
| syn.permanence = (syn.permanence - dec).max(0.0); |
| } |
| } |
| } |
|
|
| |
| |
| fn grow_synapses_on_segment( |
| &mut self, |
| seg_i: SegmentIdx, |
| prev_winner_cells: &[bool], |
| ) { |
| let initial_perm = self.cfg.initial_permanence; |
| let cap = self.cfg.max_synapses_per_segment; |
| let max_new = self.cfg.max_new_synapse_count; |
|
|
| |
| let already: Vec<CellIdx> = self.segments[seg_i as usize] |
| .synapses |
| .iter() |
| .map(|s| s.presynaptic_cell) |
| .collect(); |
| let mut candidates: Vec<CellIdx> = Vec::new(); |
| for (cell_i, &b) in prev_winner_cells.iter().enumerate() { |
| if b && !already.contains(&(cell_i as CellIdx)) { |
| candidates.push(cell_i as CellIdx); |
| } |
| } |
|
|
| |
| let current_len = self.segments[seg_i as usize].synapses.len(); |
| let room = cap.saturating_sub(current_len); |
| let mut to_add = max_new.min(candidates.len()).min(room); |
|
|
| |
| while to_add > 0 { |
| let idx = self.rng.gen_range(0..candidates.len()); |
| let pre = candidates.swap_remove(idx); |
| self.segments[seg_i as usize].synapses.push(Synapse { |
| presynaptic_cell: pre, |
| permanence: initial_perm, |
| }); |
| to_add -= 1; |
| } |
| } |
|
|
| fn create_segment(&mut self, cell: CellIdx) -> SegmentIdx { |
| |
| |
| let cell_segs = &mut self.cell_segments[cell as usize]; |
| if cell_segs.len() >= self.cfg.max_segments_per_cell { |
| |
| let (lru_pos, &lru_id) = cell_segs |
| .iter() |
| .enumerate() |
| .min_by_key(|(_, &sid)| self.segments[sid as usize].last_used_iteration) |
| .expect("cell_segs non-empty"); |
| |
| self.segments[lru_id as usize].synapses.clear(); |
| self.segments[lru_id as usize].num_active_connected = 0; |
| self.segments[lru_id as usize].num_active_potential = 0; |
| self.segments[lru_id as usize].last_used_iteration = self.iter_count; |
| |
| let _ = lru_pos; |
| return lru_id; |
| } |
|
|
| let new_id = self.segments.len() as SegmentIdx; |
| self.segments.push(Segment { |
| cell, |
| synapses: Vec::with_capacity(self.cfg.max_new_synapse_count), |
| num_active_connected: 0, |
| num_active_potential: 0, |
| last_used_iteration: self.iter_count, |
| }); |
| cell_segs.push(new_id); |
| new_id |
| } |
|
|
| fn least_used_cell(&mut self, col: usize) -> CellIdx { |
| |
| let mut min_segs = usize::MAX; |
| let mut candidates: Vec<CellIdx> = Vec::new(); |
| for c in self.cells_in_col(col) { |
| let n = self.cell_segments[c as usize].len(); |
| if n < min_segs { |
| min_segs = n; |
| candidates.clear(); |
| candidates.push(c); |
| } else if n == min_segs { |
| candidates.push(c); |
| } |
| } |
| let idx = self.rng.gen_range(0..candidates.len()); |
| candidates[idx] |
| } |
| } |
|
|
| |
| |
| |
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sp::{SpatialPooler, SpatialPoolerConfig};
    use rand::Rng;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256PlusPlus;

    /// Builds a random boolean input of `input_bits` bits with 2% of them set.
    fn random_sdr(rng: &mut Xoshiro256PlusPlus, input_bits: usize) -> Vec<bool> {
        let target = (0.02 * input_bits as f32) as usize;
        let mut bits = vec![false; input_bits];
        let mut set = 0;
        while set < target {
            let pos = rng.gen_range(0..input_bits);
            if !bits[pos] {
                bits[pos] = true;
                set += 1;
            }
        }
        bits
    }

    /// Arithmetic mean of `values`.
    fn mean(values: &[f32]) -> f32 {
        values.iter().sum::<f32>() / (values.len() as f32)
    }

    /// After many repetitions of a three-step sequence, the anomaly score must
    /// fall to less than half of what it was early in training.
    #[test]
    fn tm_learns_repeating_sequence() {
        let mut sp = SpatialPooler::new(SpatialPoolerConfig::default(), 123);
        let mut tm = TemporalMemory::new(TemporalMemoryConfig::default(), 456);

        // Three fixed random inputs form the repeating sequence.
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(99);
        let input_bits = sp.cfg.input_bits;
        let seqs = [
            random_sdr(&mut rng, input_bits),
            random_sdr(&mut rng, input_bits),
            random_sdr(&mut rng, input_bits),
        ];

        // Train the spatial pooler on the inputs first.
        for _ in 0..200 {
            for s in &seqs {
                sp.compute(s, true);
            }
        }

        tm.reset();

        // Train the temporal memory with the pooler's learning switched off,
        // sampling the anomaly on pass 10 and on the final pass.
        let mut early_anoms: Vec<f32> = Vec::new();
        let mut late_anoms: Vec<f32> = Vec::new();
        for iter in 0..250 {
            for s in &seqs {
                let active = sp.compute(s, false);
                let anomaly = tm.compute(&active, true);
                match iter {
                    10 => early_anoms.push(anomaly),
                    249 => late_anoms.push(anomaly),
                    _ => {}
                }
            }
        }

        let early = mean(&early_anoms);
        let late = mean(&late_anoms);
        println!("early_anomaly={early}, late_anomaly={late}");
        assert!(
            late < 0.5 * early + 1e-6,
            "late anomaly ({late}) should be < 0.5 * early anomaly ({early})"
        );
    }
}
|
|