| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| use std::sync::Arc; |
|
|
| use cudarc::driver::{CudaDevice, CudaSlice, DeviceSlice, DriverError, LaunchAsync, LaunchConfig}; |
| use cudarc::nvrtc::Ptx; |
|
|
| use crate::sp::SpatialPooler; |
|
|
| |
// Precompiled PTX kernel sources embedded at build time. HTM_GPU_PTX_DIR is
// resolved at compile time via env!() — presumably exported by the build
// script that compiles the .cu sources; the build fails if it is unset.
const PTX_SP_OVERLAP: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_overlap.ptx"));
const PTX_SP_TOPK: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_topk.ptx"));
const PTX_SP_LEARN: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_learn.ptx"));
const PTX_SP_DUTY: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_duty.ptx"));
const PTX_SP_BOOST_FUSED: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_boost_fused.ptx"));
|
|
/// GPU-resident spatial pooler: device copies of a CPU `SpatialPooler`'s
/// synapse, permanence, boost and duty-cycle state, plus reusable scratch
/// buffers so the per-step hot path performs no device allocations.
pub struct SpatialPoolerGpu {
    dev: Arc<CudaDevice>,

    // Configuration scalars copied from `SpatialPooler::cfg` at build time.
    input_bits: usize,       // bits per encoded input
    n_columns: usize,        // number of SP columns
    synapses_per_col: usize, // potential synapses per column
    conn_thr: f32,           // permanence threshold for a connected synapse
    inc: f32,                // permanence increment for active synapses
    dec: f32,                // permanence decrement for inactive synapses
    sparsity: f32,           // fraction of columns that win each step
    duty_period: f32,        // EMA period for duty-cycle updates
    boost_strength: f32,     // 0.0 disables boost updates entirely

    // Persistent learned state. syn_bit/syn_perm are flattened row-major
    // [n_columns * synapses_per_col]; the rest are one entry per column.
    syn_bit: CudaSlice<u32>,      // input-bit index targeted by each synapse
    syn_perm: CudaSlice<f32>,     // synapse permanences
    boost: CudaSlice<f32>,        // per-column boost factors
    active_duty: CudaSlice<f32>,  // per-column active duty cycles
    overlap_duty: CudaSlice<f32>, // per-column overlap duty cycles

    // Per-step scratch buffers, overwritten on every step.
    inp_dev: CudaSlice<u8>,     // staged input bits for the current step
    raw: CudaSlice<u32>,        // raw overlap counts
    boosted: CudaSlice<f32>,    // boosted overlap scores
    active_mask: CudaSlice<u8>, // 1 for winning columns, else 0

    // Host-side copy of `active_mask`, reused to avoid reallocating.
    host_mask: Vec<u8>,

    // When true, boost/bump updates in `step_batch`/`compute` are replayed
    // on the host — presumably for exact parity with the CPU pooler;
    // confirm against the CPU implementation.
    strict_parity: bool,
}
|
|
| impl SpatialPoolerGpu { |
| |
| |
| |
| pub fn from_cpu(cpu: &SpatialPooler) -> Result<Self, DriverError> { |
| let dev = CudaDevice::new(0)?; |
| let cfg = &cpu.cfg; |
| let n = cfg.n_columns; |
| let s = cfg.potential_synapses; |
|
|
| |
| let mut syn_bit_h: Vec<u32> = Vec::with_capacity(n * s); |
| let mut syn_perm_h: Vec<f32> = Vec::with_capacity(n * s); |
| for col in &cpu.columns { |
| debug_assert_eq!(col.inputs.len(), s); |
| debug_assert_eq!(col.perms.len(), s); |
| syn_bit_h.extend_from_slice(&col.inputs); |
| syn_perm_h.extend_from_slice(&col.perms); |
| } |
|
|
| let syn_bit = dev.htod_sync_copy(&syn_bit_h)?; |
| let syn_perm = dev.htod_sync_copy(&syn_perm_h)?; |
| let boost = dev.htod_sync_copy(&cpu.boost)?; |
| let active_duty = dev.htod_sync_copy(&cpu.active_duty_cycle)?; |
| let overlap_duty = dev.htod_sync_copy(&cpu.overlap_duty_cycle)?; |
|
|
| let inp_dev: CudaSlice<u8> = dev.alloc_zeros(cfg.input_bits)?; |
| let raw: CudaSlice<u32> = dev.alloc_zeros(n)?; |
| let boosted: CudaSlice<f32> = dev.alloc_zeros(n)?; |
| let active_mask: CudaSlice<u8> = dev.alloc_zeros(n)?; |
|
|
| |
| |
| |
| |
| |
| let modules = [ |
| ("htm_sp_overlap", PTX_SP_OVERLAP, "sp_overlap"), |
| ("htm_sp_topk", PTX_SP_TOPK, "sp_topk_select"), |
| ("htm_sp_learn", PTX_SP_LEARN, "sp_learn"), |
| ("htm_sp_duty", PTX_SP_DUTY, "sp_duty_update"), |
| ("htm_sp_boost_fused", PTX_SP_BOOST_FUSED, "sp_boost_from_duty"), |
| ]; |
| for (modname, ptx, fnname) in modules { |
| |
| |
| if dev.get_func(modname, fnname).is_none() { |
| dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?; |
| } |
| } |
|
|
| Ok(Self { |
| dev, |
| input_bits: cfg.input_bits, |
| n_columns: n, |
| synapses_per_col: s, |
| conn_thr: cfg.connected_threshold, |
| inc: cfg.syn_perm_active_inc, |
| dec: cfg.syn_perm_inactive_dec, |
| sparsity: cfg.sparsity, |
| duty_period: cfg.duty_cycle_period, |
| boost_strength: cfg.boost_strength, |
| syn_bit, |
| syn_perm, |
| boost, |
| active_duty, |
| overlap_duty, |
| inp_dev, |
| raw, |
| boosted, |
| active_mask, |
| host_mask: vec![0u8; n], |
| strict_parity: false, |
| }) |
| } |
|
|
| |
    /// Enables or disables strict-parity mode: when set, boost and bump
    /// updates in `step_batch`/`compute` run on the host instead of the
    /// fused GPU kernel (presumably to match the CPU pooler bit-for-bit —
    /// confirm against the CPU implementation).
    pub fn set_strict_parity(&mut self, strict: bool) {
        self.strict_parity = strict;
    }

    /// Shared handle to the underlying CUDA device.
    pub fn dev_ref(&self) -> &Arc<CudaDevice> {
        &self.dev
    }

    /// Number of SP columns.
    pub fn n_columns_accessor(&self) -> usize { self.n_columns }
    /// Bits per encoded input.
    #[allow(dead_code)]
    pub fn input_bits_accessor(&self) -> usize { self.input_bits }
    /// Potential synapses per column.
    pub fn synapses_per_col_accessor(&self) -> usize { self.synapses_per_col }
    /// Connected-synapse permanence threshold.
    pub fn conn_thr_accessor(&self) -> f32 { self.conn_thr }
    /// Permanence increment for active synapses.
    pub fn inc_accessor(&self) -> f32 { self.inc }
    /// Permanence decrement for inactive synapses.
    pub fn dec_accessor(&self) -> f32 { self.dec }
    /// Target fraction of winning columns per step.
    pub fn sparsity_accessor(&self) -> f32 { self.sparsity }
    /// Duty-cycle EMA period.
    pub fn duty_period_accessor(&self) -> f32 { self.duty_period }
    /// Boost strength (0.0 disables boosting).
    #[allow(dead_code)]
    pub fn boost_strength_accessor(&self) -> f32 { self.boost_strength }

    /// Device buffer of synapse input-bit indices.
    pub fn syn_bit_accessor(&self) -> &CudaSlice<u32> { &self.syn_bit }
    /// Device buffer of synapse permanences.
    pub fn syn_perm_accessor(&self) -> &CudaSlice<f32> { &self.syn_perm }
    /// Device buffer of per-column boost factors.
    pub fn boost_accessor(&self) -> &CudaSlice<f32> { &self.boost }
    /// Device buffer of per-column active duty cycles.
    pub fn active_duty_accessor(&self) -> &CudaSlice<f32> { &self.active_duty }
|
|
| |
| |
| |
| |
| |
    /// Returns an initial threshold estimate.
    ///
    /// NOTE(review): this is a hardcoded placeholder (2.0) — it inspects no
    /// pooler state. Confirm against callers what the estimate should be
    /// derived from (e.g. expected overlap distribution) before relying on it.
    pub fn initial_threshold_estimate(&self) -> f32 {
        2.0f32
    }
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| #[allow(clippy::too_many_arguments)] |
| pub fn step_batch( |
| &mut self, |
| inputs_flat_dev: &CudaSlice<u8>, |
| t: usize, |
| input_bits: usize, |
| learn: bool, |
| cols_out: &mut [u8], |
| active_indices_host: &mut Vec<u32>, |
| ) -> Result<(), DriverError> { |
| let n = self.n_columns; |
| let k = ((self.sparsity * n as f32).round() as usize).max(1); |
| debug_assert_eq!(cols_out.len(), t * n); |
|
|
| let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap(); |
| let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap(); |
| let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap(); |
| let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap(); |
|
|
| let overlap_cfg = LaunchConfig { |
| grid_dim: (n as u32, 1, 1), |
| block_dim: (128, 1, 1), |
| shared_mem_bytes: 0, |
| }; |
| let topk_cfg = LaunchConfig { |
| grid_dim: (1, 1, 1), |
| block_dim: (256, 1, 1), |
| shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32, |
| }; |
| let learn_cfg = overlap_cfg; |
| let duty_cfg = LaunchConfig { |
| grid_dim: ((n as u32 + 255) / 256, 1, 1), |
| block_dim: (256, 1, 1), |
| shared_mem_bytes: 0, |
| }; |
| let alpha = 1.0f32 / self.duty_period.max(1.0); |
|
|
| |
| self.host_mask.resize(n, 0); |
|
|
| active_indices_host.clear(); |
|
|
| for ti in 0..t { |
| |
| |
| |
| |
| |
| let in_off = ti * input_bits; |
| |
| let sub = inputs_flat_dev.slice(in_off..in_off + input_bits); |
| self.dev.dtod_copy(&sub, &mut self.inp_dev)?; |
|
|
| |
| unsafe { |
| overlap_fn.clone().launch( |
| overlap_cfg, |
| ( |
| &self.inp_dev, |
| &self.syn_bit, |
| &self.syn_perm, |
| &self.boost, |
| self.conn_thr, |
| self.synapses_per_col as u32, |
| n as u32, |
| &mut self.raw, |
| &mut self.boosted, |
| ), |
| )?; |
| } |
|
|
| |
| self.dev.memset_zeros(&mut self.active_mask)?; |
| unsafe { |
| topk_fn.clone().launch( |
| topk_cfg, |
| (&self.boosted, n as u32, k as u32, &mut self.active_mask), |
| )?; |
| } |
|
|
| |
| if learn { |
| unsafe { |
| learn_fn.clone().launch( |
| learn_cfg, |
| ( |
| &self.active_mask, |
| &self.inp_dev, |
| &self.syn_bit, |
| &mut self.syn_perm, |
| self.inc, |
| self.dec, |
| self.synapses_per_col as u32, |
| n as u32, |
| ), |
| )?; |
| } |
| } |
|
|
| |
| unsafe { |
| duty_fn.clone().launch( |
| duty_cfg, |
| ( |
| &self.active_mask, |
| &self.raw, |
| &mut self.active_duty, |
| &mut self.overlap_duty, |
| &mut self.boost, |
| alpha, |
| 1.0f32, |
| 0.0f32, |
| 0.0f32, |
| 0u32, |
| n as u32, |
| ), |
| )?; |
| } |
|
|
| |
| |
| |
| |
| if learn && self.boost_strength > 0.0 { |
| if self.strict_parity { |
| let mut duty_host = vec![0f32; n]; |
| self.dev |
| .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?; |
| let sum: f32 = duty_host.iter().sum(); |
| let mean = sum / (n as f32); |
| let mut boost_host = vec![0f32; n]; |
| for i in 0..n { |
| boost_host[i] = |
| (-self.boost_strength * (duty_host[i] - mean)).exp(); |
| } |
| self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?; |
|
|
| |
| let mut ov_host = vec![0f32; n]; |
| self.dev |
| .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?; |
| let max_ov = ov_host.iter().cloned().fold(0f32, f32::max); |
| if max_ov > 0.0 { |
| let thr = 0.001f32 * max_ov; |
| let bump = self.inc * 0.1f32; |
| let bump_cols: Vec<u32> = ov_host |
| .iter() |
| .enumerate() |
| .filter_map(|(i, &o)| { |
| if o < thr { Some(i as u32) } else { None } |
| }) |
| .collect(); |
| if !bump_cols.is_empty() { |
| let s = self.synapses_per_col; |
| let mut perm_host = vec![0f32; n * s]; |
| self.dev |
| .dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?; |
| for &c in &bump_cols { |
| let base = (c as usize) * s; |
| for p in &mut perm_host[base..base + s] { |
| *p = (*p + bump).min(1.0); |
| } |
| } |
| self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?; |
| } |
| } |
| } else { |
| |
| |
| let boost_fn = self |
| .dev |
| .get_func("htm_sp_boost_fused", "sp_boost_from_duty") |
| .expect("sp_boost_fused not loaded"); |
| let boost_cfg = LaunchConfig { |
| grid_dim: (1, 1, 1), |
| block_dim: (1024, 1, 1), |
| shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32, |
| }; |
| unsafe { |
| boost_fn.launch( |
| boost_cfg, |
| ( |
| &self.active_duty, |
| &mut self.boost, |
| self.boost_strength, |
| n as u32, |
| ), |
| )?; |
| } |
| } |
| } |
|
|
| |
| |
| |
| |
| self.dev |
| .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?; |
| let co = ti * n; |
| cols_out[co..co + n].copy_from_slice(&self.host_mask); |
| |
| for (i, &b) in self.host_mask.iter().enumerate() { |
| if b != 0 { |
| active_indices_host.push(i as u32); |
| } |
| } |
| |
| active_indices_host.push(u32::MAX); |
| } |
|
|
| Ok(()) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| #[allow(clippy::too_many_arguments)] |
| pub fn step_batch_with_tm( |
| &mut self, |
| inputs_flat_dev: &CudaSlice<u8>, |
| t: usize, |
| input_bits: usize, |
| learn: bool, |
| cols_dev: &mut CudaSlice<u8>, |
| anom_dev: &mut CudaSlice<f32>, |
| tm: &mut crate::gpu::tm_gpu::TemporalMemoryGpu, |
| ) -> Result<(), DriverError> { |
| let n = self.n_columns; |
| let k = ((self.sparsity * n as f32).round() as usize).max(1); |
| debug_assert_eq!(cols_dev.len(), t * n); |
| debug_assert_eq!(anom_dev.len(), t); |
|
|
| let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap(); |
| let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap(); |
| let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap(); |
| let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap(); |
|
|
| let overlap_cfg = LaunchConfig { |
| grid_dim: (n as u32, 1, 1), |
| block_dim: (128, 1, 1), |
| shared_mem_bytes: 0, |
| }; |
| let topk_cfg = LaunchConfig { |
| grid_dim: (1, 1, 1), |
| block_dim: (256, 1, 1), |
| shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32, |
| }; |
| let learn_cfg = overlap_cfg; |
| let duty_cfg = LaunchConfig { |
| grid_dim: ((n as u32 + 255) / 256, 1, 1), |
| block_dim: (256, 1, 1), |
| shared_mem_bytes: 0, |
| }; |
| let alpha = 1.0f32 / self.duty_period.max(1.0); |
|
|
| for ti in 0..t { |
| let in_off = ti * input_bits; |
| let sub = inputs_flat_dev.slice(in_off..in_off + input_bits); |
| self.dev.dtod_copy(&sub, &mut self.inp_dev)?; |
|
|
| |
| unsafe { |
| overlap_fn.clone().launch( |
| overlap_cfg, |
| ( |
| &self.inp_dev, |
| &self.syn_bit, |
| &self.syn_perm, |
| &self.boost, |
| self.conn_thr, |
| self.synapses_per_col as u32, |
| n as u32, |
| &mut self.raw, |
| &mut self.boosted, |
| ), |
| )?; |
| } |
|
|
| |
| self.dev.memset_zeros(&mut self.active_mask)?; |
| unsafe { |
| topk_fn.clone().launch( |
| topk_cfg, |
| (&self.boosted, n as u32, k as u32, &mut self.active_mask), |
| )?; |
| } |
|
|
| |
| if learn { |
| unsafe { |
| learn_fn.clone().launch( |
| learn_cfg, |
| ( |
| &self.active_mask, |
| &self.inp_dev, |
| &self.syn_bit, |
| &mut self.syn_perm, |
| self.inc, |
| self.dec, |
| self.synapses_per_col as u32, |
| n as u32, |
| ), |
| )?; |
| } |
| } |
|
|
| |
| unsafe { |
| duty_fn.clone().launch( |
| duty_cfg, |
| ( |
| &self.active_mask, |
| &self.raw, |
| &mut self.active_duty, |
| &mut self.overlap_duty, |
| &mut self.boost, |
| alpha, |
| 1.0f32, |
| 0.0f32, |
| 0.0f32, |
| 0u32, |
| n as u32, |
| ), |
| )?; |
| } |
|
|
| |
| if learn && self.boost_strength > 0.0 { |
| let boost_fn = self.dev |
| .get_func("htm_sp_boost_fused", "sp_boost_from_duty") |
| .expect("sp_boost_fused not loaded"); |
| let boost_cfg = LaunchConfig { |
| grid_dim: (1, 1, 1), |
| block_dim: (1024, 1, 1), |
| shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32, |
| }; |
| unsafe { |
| boost_fn.launch( |
| boost_cfg, |
| ( |
| &self.active_duty, |
| &mut self.boost, |
| self.boost_strength, |
| n as u32, |
| ), |
| )?; |
| } |
| } |
|
|
| |
| let mut dst_slice = cols_dev.slice_mut(ti * n..(ti + 1) * n); |
| self.dev.dtod_copy(&self.active_mask, &mut dst_slice)?; |
|
|
| |
| tm.step(&self.active_mask, anom_dev, ti as u32, learn)?; |
| } |
|
|
| Ok(()) |
| } |
|
|
| |
| pub fn compute(&mut self, input: &[u8], learn: bool) -> Result<Vec<u32>, DriverError> { |
| debug_assert_eq!(input.len(), self.input_bits); |
| let n = self.n_columns; |
| let k = ((self.sparsity * n as f32).round() as usize).max(1); |
|
|
| |
| self.dev.htod_sync_copy_into(input, &mut self.inp_dev)?; |
|
|
| |
| let overlap_fn = self |
| .dev |
| .get_func("htm_sp_overlap", "sp_overlap") |
| .expect("sp_overlap not loaded"); |
| let overlap_cfg = LaunchConfig { |
| grid_dim: (n as u32, 1, 1), |
| block_dim: (128, 1, 1), |
| shared_mem_bytes: 0, |
| }; |
| unsafe { |
| overlap_fn.launch( |
| overlap_cfg, |
| ( |
| &self.inp_dev, |
| &self.syn_bit, |
| &self.syn_perm, |
| &self.boost, |
| self.conn_thr, |
| self.synapses_per_col as u32, |
| n as u32, |
| &mut self.raw, |
| &mut self.boosted, |
| ), |
| )?; |
| } |
|
|
| |
| let topk_fn = self |
| .dev |
| .get_func("htm_sp_topk", "sp_topk_select") |
| .expect("sp_topk not loaded"); |
| let topk_cfg = LaunchConfig { |
| grid_dim: (1, 1, 1), |
| block_dim: (256, 1, 1), |
| shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32, |
| }; |
| |
| |
| self.dev.memset_zeros(&mut self.active_mask)?; |
| unsafe { |
| topk_fn.launch( |
| topk_cfg, |
| ( |
| &self.boosted, |
| n as u32, |
| k as u32, |
| &mut self.active_mask, |
| ), |
| )?; |
| } |
|
|
| |
| if learn { |
| let learn_fn = self |
| .dev |
| .get_func("htm_sp_learn", "sp_learn") |
| .expect("sp_learn not loaded"); |
| let learn_cfg = LaunchConfig { |
| grid_dim: (n as u32, 1, 1), |
| block_dim: (128, 1, 1), |
| shared_mem_bytes: 0, |
| }; |
| unsafe { |
| learn_fn.launch( |
| learn_cfg, |
| ( |
| &self.active_mask, |
| &self.inp_dev, |
| &self.syn_bit, |
| &mut self.syn_perm, |
| self.inc, |
| self.dec, |
| self.synapses_per_col as u32, |
| n as u32, |
| ), |
| )?; |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| let alpha = 1.0f32 / self.duty_period.max(1.0); |
| let duty_fn = self |
| .dev |
| .get_func("htm_sp_duty", "sp_duty_update") |
| .expect("sp_duty not loaded"); |
| let duty_cfg = LaunchConfig { |
| grid_dim: ((n as u32 + 255) / 256, 1, 1), |
| block_dim: (256, 1, 1), |
| shared_mem_bytes: 0, |
| }; |
| |
| unsafe { |
| duty_fn.launch( |
| duty_cfg, |
| ( |
| &self.active_mask, |
| &self.raw, |
| &mut self.active_duty, |
| &mut self.overlap_duty, |
| &mut self.boost, |
| alpha, |
| 1.0f32, |
| 0.0f32, |
| 0.0f32, |
| 0u32, |
| n as u32, |
| ), |
| )?; |
| } |
|
|
| if learn && self.boost_strength > 0.0 && self.strict_parity { |
| |
| |
| |
| |
| let mut duty_host = vec![0f32; n]; |
| self.dev |
| .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?; |
| let sum: f32 = duty_host.iter().sum(); |
| let mean = sum / (n as f32); |
| let mut boost_host = vec![0f32; n]; |
| for i in 0..n { |
| boost_host[i] = (-self.boost_strength * (duty_host[i] - mean)).exp(); |
| } |
| self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?; |
|
|
| |
| |
| |
| |
| |
| let mut ov_host = vec![0f32; n]; |
| self.dev |
| .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?; |
| let max_ov = ov_host.iter().cloned().fold(0f32, f32::max); |
| if max_ov > 0.0 { |
| let thr = 0.001f32 * max_ov; |
| let bump = self.inc * 0.1f32; |
| |
| |
| let bump_cols: Vec<u32> = ov_host |
| .iter() |
| .enumerate() |
| .filter_map(|(i, &o)| if o < thr { Some(i as u32) } else { None }) |
| .collect(); |
| if !bump_cols.is_empty() { |
| |
| |
| let s = self.synapses_per_col; |
| let mut perm_host = vec![0f32; n * s]; |
| self.dev.dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?; |
| for &c in &bump_cols { |
| let base = (c as usize) * s; |
| for p in &mut perm_host[base..base + s] { |
| *p = (*p + bump).min(1.0); |
| } |
| } |
| self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?; |
| } |
| } |
| } else if learn && self.boost_strength > 0.0 { |
| |
| let mut duty_host = vec![0f32; n]; |
| self.dev |
| .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?; |
| let sum: f32 = duty_host.iter().sum(); |
| let mean = sum / (n as f32); |
| let boost_fn = self |
| .dev |
| .get_func("htm_sp_duty", "sp_duty_update") |
| .expect("sp_duty not loaded"); |
| unsafe { |
| boost_fn.launch( |
| duty_cfg, |
| ( |
| &self.active_mask, |
| &self.raw, |
| &mut self.active_duty, |
| &mut self.overlap_duty, |
| &mut self.boost, |
| 0.0f32, |
| 1.0f32, |
| self.boost_strength, |
| mean, |
| 1u32, |
| n as u32, |
| ), |
| )?; |
| } |
| } |
|
|
| |
| self.dev |
| .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?; |
| let mut active: Vec<u32> = Vec::with_capacity(k); |
| for (i, &b) in self.host_mask.iter().enumerate() { |
| if b != 0 { |
| active.push(i as u32); |
| } |
| } |
| debug_assert_eq!(active.len(), k, "SP must emit exactly k winners"); |
| Ok(active) |
| } |
| } |
|
|