//! GPU implementation of the Spatial Pooler.
//!
//! One `SpatialPoolerGpu` owns a set of persistent device buffers plus five
//! PTX kernels. `compute(input, learn)` performs one SP step and returns the
//! sorted active-column indices (host `Vec<u32>`) — this is what the CPU
//! TemporalMemory consumes.
//!
//! Persistent state on device (per region):
//!   syn_bit     : u32 [n_columns × S]  (constant after init)
//!   syn_perm    : f32 [n_columns × S]  (updated by sp_learn)
//!   boost       : f32 [n_columns]
//!   active_duty : f32 [n_columns]
//!   overlap_duty: f32 [n_columns]
//!
//! Per-step transient state:
//!   inp_dev     : u8  [input_bits]  (H2D copy each step)
//!   raw         : u32 [n_columns]
//!   boosted     : f32 [n_columns]
//!   active_mask : u8  [n_columns]   (topk output, D2H at the end)
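//!
//! A minimal usage sketch (illustrative only; `cpu_sp` and `input_sdr` are
//! assumed to be built elsewhere and are not part of this module):
//!
//! ```ignore
//! // Mirror the seeded CPU pooler onto the GPU once...
//! let mut sp_gpu = SpatialPoolerGpu::from_cpu(&cpu_sp)?;
//! // ...then run one SP step per input SDR. `compute` uploads the SDR,
//! // runs the kernel pipeline, and returns sorted active-column indices.
//! let active: Vec<u32> = sp_gpu.compute(&input_sdr, /* learn = */ true)?;
//! assert!(active.windows(2).all(|w| w[0] < w[1]));
//! ```
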
use std::sync::Arc;

use cudarc::driver::{CudaDevice, CudaSlice, DeviceSlice, DriverError, LaunchAsync, LaunchConfig};
use cudarc::nvrtc::Ptx;

use crate::sp::SpatialPooler;
// Embed PTX at compile time. HTM_GPU_PTX_DIR is set by build.rs.
const PTX_SP_OVERLAP: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_overlap.ptx"));
const PTX_SP_TOPK: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_topk.ptx"));
const PTX_SP_LEARN: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_learn.ptx"));
const PTX_SP_DUTY: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_duty.ptx"));
const PTX_SP_BOOST_FUSED: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_boost_fused.ptx"));
pub struct SpatialPoolerGpu {
    dev: Arc<CudaDevice>,
    // Config mirror (we don't touch the CPU SpatialPooler after init).
    input_bits: usize,
    n_columns: usize,
    synapses_per_col: usize,
    conn_thr: f32,
    inc: f32,
    dec: f32,
    sparsity: f32,
    duty_period: f32,
    boost_strength: f32,
    // Persistent device state.
    syn_bit: CudaSlice<u32>,
    syn_perm: CudaSlice<f32>,
    boost: CudaSlice<f32>,
    active_duty: CudaSlice<f32>,
    overlap_duty: CudaSlice<f32>,
    // Transient scratch (reused each step).
    inp_dev: CudaSlice<u8>,
    raw: CudaSlice<u32>,
    boosted: CudaSlice<f32>,
    active_mask: CudaSlice<u8>,
    // Reusable host buffer for D2H of active_mask.
    host_mask: Vec<u8>,
    /// Strict bit-parity with the CPU reference. Enabled for tests.
    /// Forces host-side boost/exp computation and the overlap-duty bump check
    /// every step. Default false for max throughput.
    strict_parity: bool,
}
impl SpatialPoolerGpu {
    /// Copy CPU SpatialPooler state onto the device. This preserves the
    /// exact seeded proximal synapse layout + initial permanences, so the
    /// GPU SP is a bit-identical parallel implementation of the CPU SP.
    pub fn from_cpu(cpu: &SpatialPooler) -> Result<Self, DriverError> {
        let dev = CudaDevice::new(0)?;
        let cfg = &cpu.cfg;
        let n = cfg.n_columns;
        let s = cfg.potential_synapses;
        // Flatten proximal dendrites into column-major arrays.
        let mut syn_bit_h: Vec<u32> = Vec::with_capacity(n * s);
        let mut syn_perm_h: Vec<f32> = Vec::with_capacity(n * s);
        for col in &cpu.columns {
            debug_assert_eq!(col.inputs.len(), s);
            debug_assert_eq!(col.perms.len(), s);
            syn_bit_h.extend_from_slice(&col.inputs);
            syn_perm_h.extend_from_slice(&col.perms);
        }
        let syn_bit = dev.htod_sync_copy(&syn_bit_h)?;
        let syn_perm = dev.htod_sync_copy(&syn_perm_h)?;
        let boost = dev.htod_sync_copy(&cpu.boost)?;
        let active_duty = dev.htod_sync_copy(&cpu.active_duty_cycle)?;
        let overlap_duty = dev.htod_sync_copy(&cpu.overlap_duty_cycle)?;
        let inp_dev: CudaSlice<u8> = dev.alloc_zeros(cfg.input_bits)?;
        let raw: CudaSlice<u32> = dev.alloc_zeros(n)?;
        let boosted: CudaSlice<f32> = dev.alloc_zeros(n)?;
        let active_mask: CudaSlice<u8> = dev.alloc_zeros(n)?;
        // Load PTX modules. Each .ptx contains exactly one `extern "C"`
        // kernel. CudaDevice::load_ptx registers the module under the given
        // name globally on the device and is not idempotent (loading the
        // same name twice errors), so for multi-region support we use a
        // deterministic naming scheme and only load when the function is
        // not already present.
        let modules = [
            ("htm_sp_overlap", PTX_SP_OVERLAP, "sp_overlap"),
            ("htm_sp_topk", PTX_SP_TOPK, "sp_topk_select"),
            ("htm_sp_learn", PTX_SP_LEARN, "sp_learn"),
            ("htm_sp_duty", PTX_SP_DUTY, "sp_duty_update"),
            ("htm_sp_boost_fused", PTX_SP_BOOST_FUSED, "sp_boost_from_duty"),
        ];
        for (modname, ptx, fnname) in modules {
            if dev.get_func(modname, fnname).is_none() {
                dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
            }
        }
        Ok(Self {
            dev,
            input_bits: cfg.input_bits,
            n_columns: n,
            synapses_per_col: s,
            conn_thr: cfg.connected_threshold,
            inc: cfg.syn_perm_active_inc,
            dec: cfg.syn_perm_inactive_dec,
            sparsity: cfg.sparsity,
            duty_period: cfg.duty_cycle_period,
            boost_strength: cfg.boost_strength,
            syn_bit,
            syn_perm,
            boost,
            active_duty,
            overlap_duty,
            inp_dev,
            raw,
            boosted,
            active_mask,
            host_mask: vec![0u8; n],
            strict_parity: false,
        })
    }

    /// Enable strict bit-parity mode. Parity tests use this.
    pub fn set_strict_parity(&mut self, strict: bool) {
        self.strict_parity = strict;
    }

    /// Access to the underlying CudaDevice for host-side orchestration.
    pub fn dev_ref(&self) -> &Arc<CudaDevice> {
        &self.dev
    }
    // --- Fused-path accessors (immutable state reads + pointer-grabs). ---
    pub fn n_columns_accessor(&self) -> usize { self.n_columns }
    pub fn input_bits_accessor(&self) -> usize { self.input_bits }
    pub fn synapses_per_col_accessor(&self) -> usize { self.synapses_per_col }
    pub fn conn_thr_accessor(&self) -> f32 { self.conn_thr }
    pub fn inc_accessor(&self) -> f32 { self.inc }
    pub fn dec_accessor(&self) -> f32 { self.dec }
    pub fn sparsity_accessor(&self) -> f32 { self.sparsity }
    pub fn duty_period_accessor(&self) -> f32 { self.duty_period }
    pub fn boost_strength_accessor(&self) -> f32 { self.boost_strength }
    pub fn syn_bit_accessor(&self) -> &CudaSlice<u32> { &self.syn_bit }
    pub fn syn_perm_accessor(&self) -> &CudaSlice<f32> { &self.syn_perm }
    pub fn boost_accessor(&self) -> &CudaSlice<f32> { &self.boost }
    pub fn active_duty_accessor(&self) -> &CudaSlice<f32> { &self.active_duty }
    /// Percentile-based initial threshold estimate from raw overlaps after a
    /// short warmup pass. Used to seed `inhibition_threshold` so that the
    /// activation rate starts near the sparsity target.
    /// Placeholder (returns a conservative constant); the real warmup pass
    /// happens on the Rust orchestrator side.
    pub fn initial_threshold_estimate(&self) -> f32 {
        // With conn_thr=0.5, init_perm around 0.5±0.1, S=40, and a 2%-sparse
        // input SDR: expected overlap ~ 40 * 0.02 = 0.8 connected hits, so
        // boosted overlap ~ 0.8. Top-K selects the top 2% of columns, so the
        // activation threshold is roughly the 98th percentile of the boosted
        // overlaps. Conservative start: 2.0.
        // Per-column adaptation will quickly steer each column's threshold.
        2.0f32
    }
    /// Batched multi-step SP on the GPU. Processes `t` timesteps from a
    /// pre-uploaded device input buffer. Writes a `(t, n_cols)` u8
    /// active-column mask into `cols_out` and appends each step's active
    /// column indices to `active_indices_host`, terminating every step's
    /// group with a `u32::MAX` separator.
    ///
    /// Each step runs the same 5-kernel pipeline as `compute`. In the
    /// default (non-strict) mode the boost update stays on the device via
    /// the fused kernel, so there is no per-step boost D2H/H2D round-trip;
    /// `strict_parity` restores the host-side computation for bit-exact
    /// parity tests.
    ///
    /// This is the fast path used by `HTMRegionGpu.step_many_gpu`.
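    ///
    /// Sketch (illustrative only) of how a caller can recover per-step index
    /// groups from the flat list using the `u32::MAX` separators:
    ///
    /// ```ignore
    /// let mut per_step: Vec<Vec<u32>> = active_indices_host
    ///     .split(|&v| v == u32::MAX)
    ///     .map(|group| group.to_vec())
    ///     .collect();
    /// per_step.truncate(t); // the trailing separator yields one empty group
    /// ```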
    pub fn step_batch(
        &mut self,
        inputs_flat_dev: &CudaSlice<u8>,
        t: usize,
        input_bits: usize,
        learn: bool,
        cols_out: &mut [u8],
        active_indices_host: &mut Vec<u32>,
    ) -> Result<(), DriverError> {
        let n = self.n_columns;
        let k = ((self.sparsity * n as f32).round() as usize).max(1);
        debug_assert_eq!(cols_out.len(), t * n);
        let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
        let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
        let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
        let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
        let overlap_cfg = LaunchConfig {
            grid_dim: (n as u32, 1, 1),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };
        let topk_cfg = LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
        };
        let learn_cfg = overlap_cfg;
        let duty_cfg = LaunchConfig {
            grid_dim: ((n as u32 + 255) / 256, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 0,
        };
        let alpha = 1.0f32 / self.duty_period.max(1.0);
        // Reusable host buffer for the per-step active_mask D2H.
        self.host_mask.resize(n, 0);
        active_indices_host.clear();
        for ti in 0..t {
            // Point the overlap kernel at the ti-th slice of the pre-uploaded
            // input. The kernel takes a whole-buffer pointer, so we D2D-copy
            // the slice (via `slice()`) into the reusable inp_dev buffer; this
            // is much cheaper than an H2D upload each step.
            // (Alternative: rewrite the kernel to accept an offset; deferred.)
            let in_off = ti * input_bits;
            let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
            self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
            // 1. sp_overlap
            unsafe {
                overlap_fn.clone().launch(
                    overlap_cfg,
                    (
                        &self.inp_dev,
                        &self.syn_bit,
                        &self.syn_perm,
                        &self.boost,
                        self.conn_thr,
                        self.synapses_per_col as u32,
                        n as u32,
                        &mut self.raw,
                        &mut self.boosted,
                    ),
                )?;
            }
            // 2. Clear active_mask, then sp_topk
            self.dev.memset_zeros(&mut self.active_mask)?;
            unsafe {
                topk_fn.clone().launch(
                    topk_cfg,
                    (&self.boosted, n as u32, k as u32, &mut self.active_mask),
                )?;
            }
            // 3. sp_learn
            if learn {
                unsafe {
                    learn_fn.clone().launch(
                        learn_cfg,
                        (
                            &self.active_mask,
                            &self.inp_dev,
                            &self.syn_bit,
                            &mut self.syn_perm,
                            self.inc,
                            self.dec,
                            self.synapses_per_col as u32,
                            n as u32,
                        ),
                    )?;
                }
            }
            // 4. duty update (device)
            unsafe {
                duty_fn.clone().launch(
                    duty_cfg,
                    (
                        &self.active_mask,
                        &self.raw,
                        &mut self.active_duty,
                        &mut self.overlap_duty,
                        &mut self.boost,
                        alpha,
                        1.0f32,
                        0.0f32,
                        0.0f32,
                        0u32,
                        n as u32,
                    ),
                )?;
            }
            // 5. Boost update. Two modes:
            //   * strict_parity (tests): host-side exp for bit-exact match.
            //   * default (production): GPU expf is close enough and ~10x
            //     faster since we skip the D2H/H2D round-trip.
            if learn && self.boost_strength > 0.0 {
                if self.strict_parity {
                    let mut duty_host = vec![0f32; n];
                    self.dev
                        .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
                    let sum: f32 = duty_host.iter().sum();
                    let mean = sum / (n as f32);
                    let mut boost_host = vec![0f32; n];
                    for i in 0..n {
                        boost_host[i] =
                            (-self.boost_strength * (duty_host[i] - mean)).exp();
                    }
                    self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
                    // Permanence bump (rare). Only evaluated in strict mode.
                    let mut ov_host = vec![0f32; n];
                    self.dev
                        .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
                    let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
                    if max_ov > 0.0 {
                        let thr = 0.001f32 * max_ov;
                        let bump = self.inc * 0.1f32;
                        let bump_cols: Vec<u32> = ov_host
                            .iter()
                            .enumerate()
                            .filter_map(|(i, &o)| {
                                if o < thr { Some(i as u32) } else { None }
                            })
                            .collect();
                        if !bump_cols.is_empty() {
                            let s = self.synapses_per_col;
                            let mut perm_host = vec![0f32; n * s];
                            self.dev
                                .dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
                            for &c in &bump_cols {
                                let base = (c as usize) * s;
                                for p in &mut perm_host[base..base + s] {
                                    *p = (*p + bump).min(1.0);
                                }
                            }
                            self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
                        }
                    }
                } else {
                    // Fast path: fused mean + boost = expf(-strength*(ad-mean))
                    // in a single GPU block. Zero D2H, zero H2D — fully async.
                    let boost_fn = self
                        .dev
                        .get_func("htm_sp_boost_fused", "sp_boost_from_duty")
                        .expect("sp_boost_fused not loaded");
                    let boost_cfg = LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (1024, 1, 1),
                        shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
                    };
                    unsafe {
                        boost_fn.launch(
                            boost_cfg,
                            (
                                &self.active_duty,
                                &mut self.boost,
                                self.boost_strength,
                                n as u32,
                            ),
                        )?;
                    }
                }
            }
            // D2H the active_mask for this step. This is the single
            // unavoidable sync point per step — the CPU TM needs the active
            // indices for its next state update. At 2048 bytes / step this
            // is tiny in bandwidth but costs a full synchronize (~5-10μs).
            self.dev
                .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
            let co = ti * n;
            cols_out[co..co + n].copy_from_slice(&self.host_mask);
            // Extract active indices.
            for (i, &b) in self.host_mask.iter().enumerate() {
                if b != 0 {
                    active_indices_host.push(i as u32);
                }
            }
            // Separator (u32::MAX) marks the end of this step's index group.
            active_indices_host.push(u32::MAX);
        }
        Ok(())
    }
    /// Fully-on-GPU batched SP + TM. Zero per-step host sync.
    ///
    /// Inputs:
    ///   inputs_flat_dev : (T * input_bits) u8 already uploaded
    ///   cols_dev        : (T * n_cols) u8 output — active-column mask per step
    ///   anom_dev        : (T,) f32 output — anomaly score per step
    ///   tm              : persistent GPU TemporalMemory for this region
    pub fn step_batch_with_tm(
        &mut self,
        inputs_flat_dev: &CudaSlice<u8>,
        t: usize,
        input_bits: usize,
        learn: bool,
        cols_dev: &mut CudaSlice<u8>,
        anom_dev: &mut CudaSlice<f32>,
        tm: &mut crate::gpu::tm_gpu::TemporalMemoryGpu,
    ) -> Result<(), DriverError> {
        let n = self.n_columns;
        let k = ((self.sparsity * n as f32).round() as usize).max(1);
        debug_assert_eq!(cols_dev.len(), t * n);
        debug_assert_eq!(anom_dev.len(), t);
        let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
        let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
        let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
        let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
        let overlap_cfg = LaunchConfig {
            grid_dim: (n as u32, 1, 1),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };
        let topk_cfg = LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
        };
        let learn_cfg = overlap_cfg;
        let duty_cfg = LaunchConfig {
            grid_dim: ((n as u32 + 255) / 256, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 0,
        };
        let alpha = 1.0f32 / self.duty_period.max(1.0);
        for ti in 0..t {
            let in_off = ti * input_bits;
            let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
            self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
            // 1. sp_overlap
            unsafe {
                overlap_fn.clone().launch(
                    overlap_cfg,
                    (
                        &self.inp_dev,
                        &self.syn_bit,
                        &self.syn_perm,
                        &self.boost,
                        self.conn_thr,
                        self.synapses_per_col as u32,
                        n as u32,
                        &mut self.raw,
                        &mut self.boosted,
                    ),
                )?;
            }
            // 2. clear + sp_topk
            self.dev.memset_zeros(&mut self.active_mask)?;
            unsafe {
                topk_fn.clone().launch(
                    topk_cfg,
                    (&self.boosted, n as u32, k as u32, &mut self.active_mask),
                )?;
            }
            // 3. sp_learn
            if learn {
                unsafe {
                    learn_fn.clone().launch(
                        learn_cfg,
                        (
                            &self.active_mask,
                            &self.inp_dev,
                            &self.syn_bit,
                            &mut self.syn_perm,
                            self.inc,
                            self.dec,
                            self.synapses_per_col as u32,
                            n as u32,
                        ),
                    )?;
                }
            }
            // 4. duty update (stage 1: no-boost write)
            unsafe {
                duty_fn.clone().launch(
                    duty_cfg,
                    (
                        &self.active_mask,
                        &self.raw,
                        &mut self.active_duty,
                        &mut self.overlap_duty,
                        &mut self.boost,
                        alpha,
                        1.0f32,
                        0.0f32,
                        0.0f32,
                        0u32,
                        n as u32,
                    ),
                )?;
            }
            // 5. Boost update: fused GPU kernel (no D2H).
            if learn && self.boost_strength > 0.0 {
                let boost_fn = self.dev
                    .get_func("htm_sp_boost_fused", "sp_boost_from_duty")
                    .expect("sp_boost_fused not loaded");
                let boost_cfg = LaunchConfig {
                    grid_dim: (1, 1, 1),
                    block_dim: (1024, 1, 1),
                    shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
                };
                unsafe {
                    boost_fn.launch(
                        boost_cfg,
                        (
                            &self.active_duty,
                            &mut self.boost,
                            self.boost_strength,
                            n as u32,
                        ),
                    )?;
                }
            }
            // 6. Copy active_mask slice into cols_dev[ti*n .. (ti+1)*n].
            let mut dst_slice = cols_dev.slice_mut(ti * n..(ti + 1) * n);
            self.dev.dtod_copy(&self.active_mask, &mut dst_slice)?;
            // 7. GPU TM step: predict + activate + anomaly + learn, all on device.
            tm.step(&self.active_mask, anom_dev, ti as u32, learn)?;
        }
        Ok(())
    }
    /// One SP step on the GPU. Returns sorted active-column indices.
    pub fn compute(&mut self, input: &[u8], learn: bool) -> Result<Vec<u32>, DriverError> {
        debug_assert_eq!(input.len(), self.input_bits);
        let n = self.n_columns;
        let k = ((self.sparsity * n as f32).round() as usize).max(1);
        // 1. H2D input SDR.
        self.dev.htod_sync_copy_into(input, &mut self.inp_dev)?;
        // 2. Launch sp_overlap: grid=n_columns, block=128.
        let overlap_fn = self
            .dev
            .get_func("htm_sp_overlap", "sp_overlap")
            .expect("sp_overlap not loaded");
        let overlap_cfg = LaunchConfig {
            grid_dim: (n as u32, 1, 1),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };
        unsafe {
            overlap_fn.launch(
                overlap_cfg,
                (
                    &self.inp_dev,
                    &self.syn_bit,
                    &self.syn_perm,
                    &self.boost,
                    self.conn_thr,
                    self.synapses_per_col as u32,
                    n as u32,
                    &mut self.raw,
                    &mut self.boosted,
                ),
            )?;
        }
        // 3. Launch sp_topk: single block, shared mem = n_columns * f32.
        let topk_fn = self
            .dev
            .get_func("htm_sp_topk", "sp_topk_select")
            .expect("sp_topk not loaded");
        let topk_cfg = LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
        };
        // Clear active_mask first. memset_zeros avoids an H2D of a host
        // zeroes vector every step.
        self.dev.memset_zeros(&mut self.active_mask)?;
        unsafe {
            topk_fn.launch(
                topk_cfg,
                (
                    &self.boosted,
                    n as u32,
                    k as u32,
                    &mut self.active_mask,
                ),
            )?;
        }
        // 4. Optional: sp_learn on active columns.
        if learn {
            let learn_fn = self
                .dev
                .get_func("htm_sp_learn", "sp_learn")
                .expect("sp_learn not loaded");
            let learn_cfg = LaunchConfig {
                grid_dim: (n as u32, 1, 1),
                block_dim: (128, 1, 1),
                shared_mem_bytes: 0,
            };
            unsafe {
                learn_fn.launch(
                    learn_cfg,
                    (
                        &self.active_mask,
                        &self.inp_dev,
                        &self.syn_bit,
                        &mut self.syn_perm,
                        self.inc,
                        self.dec,
                        self.synapses_per_col as u32,
                        n as u32,
                    ),
                )?;
            }
        }
        // 5. Duty cycle + boost update. Always runs (matches CPU).
        // The CPU reference (sp.rs, lines 186-205) updates the duty cycles
        // first, then computes the mean over the POST-update active duty
        // cycles, and finally sets
        //   boost[i] = exp(-strength * (active_duty[i] - mean)).
        // To match that ordering on the GPU:
        //   1) run the duty kernel with boost_strength=0 (update only, no
        //      boost write),
        //   2) D2H active_duty and compute the mean on the host,
        //   3) either compute the boost on the host (strict parity) or
        //      re-launch the duty kernel with the true mean (fast path).
        // Two launches plus one tiny D2H (n × f32). At n=2048 this is 8KB per
        // step — negligible.
        let alpha = 1.0f32 / self.duty_period.max(1.0);
        let duty_fn = self
            .dev
            .get_func("htm_sp_duty", "sp_duty_update")
            .expect("sp_duty not loaded");
        let duty_cfg = LaunchConfig {
            grid_dim: ((n as u32 + 255) / 256, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 0,
        };
        // Stage 1: update duty cycles (boost_strength=0 -> no write).
        unsafe {
            duty_fn.launch(
                duty_cfg,
                (
                    &self.active_mask,
                    &self.raw,
                    &mut self.active_duty,
                    &mut self.overlap_duty,
                    &mut self.boost,
                    alpha,
                    1.0f32, // stim_thr
                    0.0f32, // boost_strength = 0 -> skip write
                    0.0f32, // mean_duty (unused)
                    0u32,   // learn_flag = 0
                    n as u32,
                ),
            )?;
        }
        if learn && self.boost_strength > 0.0 && self.strict_parity {
            // Boost update must bit-match CPU `f32::exp`, so we compute it on
            // the host and copy back. Cost per step: 8KB D2H + 8KB H2D at n=2048.
            // Critical for learning parity — CUDA expf (even without fast-math)
            // uses different rounding for some inputs than host libm.
            let mut duty_host = vec![0f32; n];
            self.dev
                .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
            let sum: f32 = duty_host.iter().sum();
            let mean = sum / (n as f32);
            let mut boost_host = vec![0f32; n];
            for i in 0..n {
                boost_host[i] = (-self.boost_strength * (duty_host[i] - mean)).exp();
            }
            self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
            // CPU sp.rs 210-226: permanence bump for chronically under-stimulated
            // columns. If overlap_duty_cycle[i] < 0.001 * max(overlap_duty_cycle),
            // add inc*0.1 to every synapse of column i (clamped to 1.0).
            // This runs only once per step and only for the rare cases, but we
            // need it for bit-exact parity with CPU learn.
            let mut ov_host = vec![0f32; n];
            self.dev
                .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
            let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
            if max_ov > 0.0 {
                let thr = 0.001f32 * max_ov;
                let bump = self.inc * 0.1f32;
                // Find columns needing a bump. Usually empty. Rare → D2H/H2D
                // of syn_perm is cheap (n*S*4 = 320KB at n=2048, S=40).
                let bump_cols: Vec<u32> = ov_host
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &o)| if o < thr { Some(i as u32) } else { None })
                    .collect();
                if !bump_cols.is_empty() {
                    // Download, bump, upload. (Keeps implementation simple and
                    // bit-exact. Could kernelize later.)
                    let s = self.synapses_per_col;
                    let mut perm_host = vec![0f32; n * s];
                    self.dev.dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
                    for &c in &bump_cols {
                        let base = (c as usize) * s;
                        for p in &mut perm_host[base..base + s] {
                            *p = (*p + bump).min(1.0);
                        }
                    }
                    self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
                }
            }
        } else if learn && self.boost_strength > 0.0 {
            // Fast path: GPU-side boost using the already-loaded duty kernel,
            // re-launched with the true mean (stage 2).
            let mut duty_host = vec![0f32; n];
            self.dev
                .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
            let sum: f32 = duty_host.iter().sum();
            let mean = sum / (n as f32);
            let boost_fn = self
                .dev
                .get_func("htm_sp_duty", "sp_duty_update")
                .expect("sp_duty not loaded");
            unsafe {
                boost_fn.launch(
                    duty_cfg,
                    (
                        &self.active_mask,
                        &self.raw,
                        &mut self.active_duty,
                        &mut self.overlap_duty,
                        &mut self.boost,
                        0.0f32, // alpha = 0 (duty cycles already updated in stage 1)
                        1.0f32, // stim_thr
                        self.boost_strength,
                        mean,   // mean_duty
                        1u32,   // learn_flag
                        n as u32,
                    ),
                )?;
            }
        }
        // 6. D2H active_mask and convert to sorted index list.
        self.dev
            .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
        let mut active: Vec<u32> = Vec::with_capacity(k);
        for (i, &b) in self.host_mask.iter().enumerate() {
            if b != 0 {
                active.push(i as u32);
            }
        }
        debug_assert_eq!(active.len(), k, "SP must emit exactly k winners");
        Ok(active)
    }
}