//! Fused HTM megakernel launcher.
//!
//! Collapses the 12-kernel per-timestep pipeline (and the outer T-loop) into
//! a single kernel launch per forward. See `kernels/htm_fused_step.cu` for
//! the kernel design and the cross-block coherence strategy (cluster barriers
//! on Hopper, cooperative `grid.sync()` pre-Hopper, with all blocks
//! concurrently resident).
//!
//! Launch invariant: `grid_dim.x <= concurrent-block capacity`. Host code
//! probes the device SM count at construction and caps `grid_dim.x`
//! accordingly — otherwise the grid barrier deadlocks.
//!
//! Semantic change from the top-K pipeline: activation is per-column
//! threshold-based (local lateral inhibition) instead of global top-K.
//! A per-column `inhibition_threshold` is tracked and EMA-steered to hit
//! the sparsity target. This is a real architectural change and is
//! documented in `docs/GPU_HTM.md`.
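//!
//! A host-side sketch of the steering rule (an assumption for illustration;
//! the real update lives in the kernel, and the names mirror `FusedConfig`
//! fields):
//!
//! ```ignore
//! fn steer_threshold(thr: f32, active_duty: f32, sparsity_target: f32, rate: f32) -> f32 {
//!     // Raise the threshold when the column fires more often than the target,
//!     // lower it when it fires less; `rate` corresponds to `thr_adapt_rate`.
//!     thr + rate * (active_duty - sparsity_target)
//! }
//! ```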
use std::ffi::CString;
use std::sync::Arc;

use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DevicePtr, DeviceRepr, DriverError};
use cudarc::nvrtc::Ptx;

use super::sp_gpu::SpatialPoolerGpu;
use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};

const PTX_HTM_FUSED: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
/// Struct-by-value pointer pack — matches the C-side `FusedPtrs`.
///
/// NOTE: `barrier_counters` is kept as an ABI-compat dummy (always 0). The
/// C-side `FusedPtrs` still has the field at the same byte offset; removing
/// it here would shift all subsequent fields and break the layout. Worker A
/// will eventually delete the field from both sides once the kernel is
/// updated; until then we zero it.
#[repr(C)]
pub struct FusedPtrs {
    pub syn_bit: u64,
    pub syn_perm: u64,
    pub boost: u64,
    pub active_duty: u64,
    pub inhibition_threshold: u64,
    pub seg_cell_id: u64,
    pub seg_syn_count: u64,
    pub syn_presyn: u64,
    pub tm_syn_perm: u64,
    pub cell_seg_count: u64,
    pub cell_active_a: u64,
    pub cell_active_b: u64,
    pub cell_winner_a: u64,
    pub cell_winner_b: u64,
    pub inputs: u64,
    pub cols_out: u64,
    pub anom_out: u64,
    /// ABI-compat dummy — always 0. No device memory is allocated for this
    /// field; the cluster barrier replaces the old software DLB barrier.
    pub barrier_counters: u64,
    pub step_scratch: u64,
}

unsafe impl DeviceRepr for FusedPtrs {}
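
// Layout guard (a sketch of the ABI contract above, assuming the C side packs
// 19 pointer-sized fields with no padding): any accidental field removal or
// reorder on the Rust side changes the size and fails this compile-time check.
const _: () = assert!(std::mem::size_of::<FusedPtrs>() == 19 * 8);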
/// Launch-time config — matches the C-side `FusedConfig` 1:1.
#[repr(C)]
pub struct FusedConfig {
    pub input_bits: u32,
    pub n_columns: u32,
    pub synapses_per_col: u32,
    pub conn_thr: f32,
    pub sp_inc: f32,
    pub sp_dec: f32,
    pub sparsity_target: f32,
    pub duty_alpha: f32,
    pub thr_adapt_rate: f32,
    pub cells_per_column: u32,
    pub n_cells: u32,
    pub bits_words: u32,
    pub max_segments_per_cell: u32,
    pub synapses_per_segment: u32,
    pub activation_threshold: u32,
    pub learning_threshold: u32,
    pub max_new_synapses: u32,
    pub conn_thr_i16: i32,
    pub perm_inc_i16: i32,
    pub perm_dec_i16: i32,
    pub predicted_seg_dec_i16: i32,
    pub initial_perm_i16: i32,
    pub t: u32,
    pub learn: u32,
    pub iter_seed: u32,
    pub cooperative_grid_sync: u32,
}

unsafe impl DeviceRepr for FusedConfig {}
/// Cluster launch parameters probed at construction time.
pub(crate) struct ClusterInfo {
    /// Maximum cluster size supported by this device (0 = cluster unsupported).
    pub max_cluster_size: u32,
}

// There are two launch modes. On Hopper we use a non-cooperative launch with
// the Thread Block Cluster attribute (`CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`)
// and cluster barriers; pre-Hopper we fall back to a cooperative launch with
// `grid.sync()`. The old software DLB barrier is removed in both modes.
pub(crate) struct FusedLaunchPlan {
    pub grid_dim_x: u32,
    pub block_dim_x: u32,
    pub cooperative_grid_limit: u32,
    pub sm_count: u32,
}
fn fused_grid_cap_override() -> Option<u32> {
    std::env::var("HTM_FUSED_GRID_CAP")
        .ok()
        .and_then(|s| s.parse::<u32>().ok())
        .map(|v| v.max(1))
}

pub(crate) fn plan_fused_launch(
    sm_count: u32,
    cooperative_supported: bool,
    cooperative_grid_limit: u32,
    grid_cap_override: Option<u32>,
) -> Result<FusedLaunchPlan, String> {
    let sm_count = sm_count.max(1);
    // 1024 threads/block exceeds the register file on Ampere (sm_86: 65536
    // regs/SM ÷ 1024 = 64 regs/thread; the fused kernel needs ~80+). 256 gives
    // 256 regs/thread, which is ample. Compensate with more blocks via
    // cooperative launch. On Hopper (228 KB smem, 255 regs/thread baseline),
    // 1024 works fine, but 256 is safe everywhere.
    let block_dim_x = 256u32;
    // Cluster launch path: cooperative launch is not required. Keep the probe
    // result for residency estimation only.
    if !cooperative_supported {
        eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
    }
    // Tested grid_cap: 4 blocks = 30 ms (too serial), 16 blocks = 10.8 ms (parallel wins).
    // Parallelism in the SP overlap + TM predict stages outweighs the grid.sync() cost.
    let default_grid_cap = 16u32;
    let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
    let resident_bound = if cooperative_grid_limit > 0 {
        cooperative_grid_limit.max(sm_count * 2)
    } else {
        sm_count * 2
    };
    Ok(FusedLaunchPlan {
        grid_dim_x: resident_bound.min(grid_cap).max(1),
        block_dim_x,
        cooperative_grid_limit: resident_bound,
        sm_count,
    })
}
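
// A minimal check of the planner's capping behavior (illustrative values,
// not from a real device probe).
#[cfg(test)]
mod plan_fused_launch_tests {
    use super::plan_fused_launch;

    #[test]
    fn grid_cap_bounds_grid_dim() {
        // With 132 SMs and a cooperative grid limit of 264, the resident bound
        // is 264, so the default cap of 16 decides grid_dim_x.
        let plan = plan_fused_launch(132, true, 264, None).unwrap();
        assert_eq!(plan.grid_dim_x, 16);
        // An explicit HTM_FUSED_GRID_CAP-style override below the bound is honored.
        let plan = plan_fused_launch(132, true, 264, Some(8)).unwrap();
        assert_eq!(plan.grid_dim_x, 8);
    }
}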
pub(super) struct RawFusedKernel {
    module: sys::CUmodule,
    pub(super) function: sys::CUfunction,
    pub(super) function_batched: sys::CUfunction,
}

unsafe impl Send for RawFusedKernel {}
unsafe impl Sync for RawFusedKernel {}

impl Drop for RawFusedKernel {
    fn drop(&mut self) {
        unsafe {
            let _ = result::module::unload(self.module);
        }
    }
}
/// Owns fused-path-only device state:
/// - per-column inhibition threshold (replaces global top-K)
/// - ping-pong cell_active/cell_winner bitsets
/// - step_scratch (n_active, n_unpred per timestep)
/// - cluster launch capability info
pub struct FusedState {
    dev: Arc<CudaDevice>,
    pub(super) raw_kernel: RawFusedKernel,
    pub inhibition_threshold: CudaSlice<f32>,
    pub cell_active_bits_a: CudaSlice<u32>,
    pub cell_active_bits_b: CudaSlice<u32>,
    pub cell_winner_bits_a: CudaSlice<u32>,
    pub cell_winner_bits_b: CudaSlice<u32>,
    pub step_scratch: CudaSlice<u32>, // length 6
    pub grid_dim_x: u32,
    pub block_dim_x: u32,
    pub cooperative_grid_limit: u32,
    pub iter_counter: u32,
    /// Hopper cluster launch capability (0 = unsupported).
    pub cluster_info: ClusterInfo,
    // Config mirror (read-only after init).
    pub initial_threshold: f32,
}
impl FusedState {
    pub fn new(
        dev: Arc<CudaDevice>,
        n_columns: usize,
        cells_per_column: usize,
        initial_threshold: f32,
    ) -> Result<Self, DriverError> {
        let n_cells = n_columns * cells_per_column;
        assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
        let bits_words = n_cells / 32;
        let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
        let init_vec = vec![initial_threshold; n_columns];
        dev.htod_sync_copy_into(&init_vec, &mut inhibition_threshold)?;
        let cell_active_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
        let cell_active_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
        let cell_winner_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
        let cell_winner_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
        let step_scratch = dev.alloc_zeros::<u32>(6)?;
        unsafe {
            result::ctx::set_current(*dev.cu_primary_ctx())?;
        }
        if dev.get_func("htm_fused", "htm_fused_step").is_none() {
            dev.load_ptx(
                Ptx::from_src(PTX_HTM_FUSED),
                "htm_fused",
                &["htm_fused_step", "htm_fused_step_batched"],
            )?;
        }
        let ptx = CString::new(PTX_HTM_FUSED).expect("PTX contains no interior nul bytes");
        let module = unsafe { result::module::load_data(ptx.as_ptr().cast()) }?;
        let function = unsafe {
            result::module::get_function(module, CString::new("htm_fused_step").unwrap())
        }?;
        let function_batched = unsafe {
            result::module::get_function(module, CString::new("htm_fused_step_batched").unwrap())
        }?;
        // Cluster size 16 on Hopper is "non-portable" (> 8 requires opt-in).
        // Must set CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED=1 on
        // every launched kernel function, otherwise cuLaunchKernelEx rejects
        // the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
        unsafe {
            let attr = sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
            // Ignore errors: older CUDA may lack the attribute, in which case
            // only portable cluster sizes (<= 8) are accepted and a 16-wide
            // cluster launch is rejected at launch time.
            let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
            let _ = sys::lib().cuFuncSetAttribute(function_batched, attr, 1);
        }
        // Probe SM count.
        let sm_count = match dev.attribute(
            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
        ) {
            Ok(v) => v as u32,
            Err(_) => 16u32,
        };
        // T1: Probe Hopper cluster launch capability.
        let max_cluster_size = match dev.attribute(
            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH,
        ) {
            Ok(v) if v > 0 => {
                // H200/sm_90a supports up to 16 blocks per cluster.
                // There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
                // Hopper maximum, which is 16 (8 SMs × 2 blocks/SM = 16 blocks/cluster).
                16u32
            }
            _ => 0u32,
        };
        eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
        let cluster_info = ClusterInfo { max_cluster_size };
        let cooperative_supported = matches!(
            dev.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
            Ok(v) if v > 0
        );
        let cooperative_grid_limit = if cooperative_supported {
            let blocks_per_sm = unsafe {
                result::occupancy::max_active_block_per_multiprocessor(function, 1024, 0)
            }
            .ok()
            .map(|v| v.max(0) as u32)
            .unwrap_or(0);
            sm_count.saturating_mul(blocks_per_sm)
        } else {
            0
        };
        let launch_plan = plan_fused_launch(
            sm_count,
            cooperative_supported,
            cooperative_grid_limit,
            fused_grid_cap_override(),
        )
        .map_err(|msg| {
            // Surface as a CUDA-ish error so callers can propagate.
            eprintln!("[htm_rust] FATAL: {msg}");
            DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
        })?;
        eprintln!(
            "[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}",
            launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit,
            cluster_info.max_cluster_size,
        );
        Ok(Self {
            dev,
            raw_kernel: RawFusedKernel { module, function, function_batched },
            inhibition_threshold,
            cell_active_bits_a,
            cell_active_bits_b,
            cell_winner_bits_a,
            cell_winner_bits_b,
            step_scratch,
            grid_dim_x: launch_plan.grid_dim_x,
            block_dim_x: launch_plan.block_dim_x,
            cooperative_grid_limit: launch_plan.cooperative_grid_limit,
            iter_counter: 0,
            cluster_info,
            initial_threshold,
        })
    }

    /// Reset fused state. Called at region.reset().
    pub fn reset(&mut self) -> Result<(), DriverError> {
        self.dev.memset_zeros(&mut self.cell_active_bits_a)?;
        self.dev.memset_zeros(&mut self.cell_active_bits_b)?;
        self.dev.memset_zeros(&mut self.cell_winner_bits_a)?;
        self.dev.memset_zeros(&mut self.cell_winner_bits_b)?;
        self.dev.memset_zeros(&mut self.step_scratch)?;
        // Do NOT reset inhibition_threshold — it's learned state. A hard
        // reset of TM state should NOT forget the sparsity calibration.
        Ok(())
    }
}
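
// Illustrative construction (a sketch; the column/cell counts and the initial
// threshold are assumptions, and `CudaDevice::new(0)` picks the first GPU):
//
// let dev = cudarc::driver::CudaDevice::new(0)?;
// let mut fused = FusedState::new(dev, 2048, 32, 0.5)?;
// fused.reset()?; // clears bitsets and scratch, keeps the learned thresholds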
/// Launch the fused megakernel. Processes all T timesteps in one kernel.
///
/// Uses `cuLaunchKernelEx` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION=(16,1,1)`
/// when the device supports cluster launch, otherwise falls back to a
/// cooperative launch (`cuLaunchCooperativeKernel`), which `grid.sync()`
/// requires. For single-region launches, `grid_dim_x <= 16` ensures the
/// entire grid fits in one cluster.
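///
/// Illustrative call pattern (a sketch; `region`, the device buffers, and
/// `t` are assumptions supplied by the caller, not defined in this module):
///
/// ```ignore
/// launch_fused(&mut region.sp_gpu, &mut region.tm_gpu, &mut region.fused_state,
///              &inputs_dev, &mut cols_dev, &mut anom_dev, t, input_bits, /*learn=*/ true)?;
/// region.sp_gpu.dev_ref().synchronize()?; // one launch covers all T timesteps
/// ```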
pub fn launch_fused(
    sp: &mut SpatialPoolerGpu,
    tm: &mut TemporalMemoryGpu,
    fused: &mut FusedState,
    inputs_flat: &CudaSlice<u8>,
    cols_out: &mut CudaSlice<u8>,
    anom_out: &mut CudaSlice<f32>,
    t: usize,
    input_bits: usize,
    learn: bool,
) -> Result<(), DriverError> {
    // Reset step_scratch before each launch (safe re-entry).
    sp.dev_ref().memset_zeros(&mut fused.step_scratch)?;
    fused.iter_counter = fused.iter_counter.wrapping_add(1);
    let cfg = FusedConfig {
        input_bits: input_bits as u32,
        n_columns: sp.n_columns_accessor() as u32,
        synapses_per_col: sp.synapses_per_col_accessor() as u32,
        conn_thr: sp.conn_thr_accessor(),
        sp_inc: sp.inc_accessor(),
        sp_dec: sp.dec_accessor(),
        sparsity_target: sp.sparsity_accessor(),
        duty_alpha: 1.0f32 / sp.duty_period_accessor().max(1.0),
        thr_adapt_rate: 0.001f32,
        cells_per_column: tm.cells_per_column as u32,
        n_cells: tm.n_cells as u32,
        bits_words: tm.bits_words as u32,
        max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
        synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
        activation_threshold: tm.activation_threshold,
        learning_threshold: tm.learning_threshold,
        max_new_synapses: tm.max_new_synapse_count,
        conn_thr_i16: tm.conn_thr_i16 as i32,
        perm_inc_i16: tm.perm_inc_i16 as i32,
        perm_dec_i16: tm.perm_dec_i16 as i32,
        predicted_seg_dec_i16: tm.predicted_seg_dec_i16 as i32,
        initial_perm_i16: tm.initial_perm_i16 as i32,
        t: t as u32,
        learn: if learn { 1 } else { 0 },
        iter_seed: fused.iter_counter,
        cooperative_grid_sync: 1,
    };
    let ptrs = FusedPtrs {
        syn_bit: *sp.syn_bit_accessor().device_ptr(),
        syn_perm: *sp.syn_perm_accessor().device_ptr(),
        boost: *sp.boost_accessor().device_ptr(),
        active_duty: *sp.active_duty_accessor().device_ptr(),
        inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
        seg_cell_id: *tm.seg_cell_id_accessor().device_ptr(),
        seg_syn_count: *tm.seg_syn_count_accessor().device_ptr(),
        syn_presyn: *tm.syn_presyn_accessor().device_ptr(),
        tm_syn_perm: *tm.syn_perm_accessor().device_ptr(),
        cell_seg_count: *tm.cell_seg_count_accessor().device_ptr(),
        cell_active_a: *fused.cell_active_bits_a.device_ptr(),
        cell_active_b: *fused.cell_active_bits_b.device_ptr(),
        cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
        cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
        inputs: *inputs_flat.device_ptr(),
        cols_out: *cols_out.device_ptr(),
        anom_out: *anom_out.device_ptr(),
        barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
        step_scratch: *fused.step_scratch.device_ptr(),
    };
    let grid_x = fused.grid_dim_x;
    let block_x = fused.block_dim_x;
    let cu_stream = *sp.dev_ref().cu_stream();
    let use_cluster = fused.cluster_info.max_cluster_size > 0;
    unsafe {
        result::ctx::set_current(*sp.dev_ref().cu_primary_ctx())?;
        let mut kernel_params: [*mut std::ffi::c_void; 2] = [
            (&ptrs as *const FusedPtrs).cast_mut().cast(),
            (&cfg as *const FusedConfig).cast_mut().cast(),
        ];
        if use_cluster {
            // T10: Hopper cluster launch with CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION.
            // cluster_dim=(16,1,1) maps the entire single-region grid into one cluster.
            let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
            attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
            attr.value.clusterDim.x = 16;
            attr.value.clusterDim.y = 1;
            attr.value.clusterDim.z = 1;
            let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
            launch_cfg.gridDimX = grid_x;
            launch_cfg.gridDimY = 1;
            launch_cfg.gridDimZ = 1;
            launch_cfg.blockDimX = block_x;
            launch_cfg.blockDimY = 1;
            launch_cfg.blockDimZ = 1;
            launch_cfg.sharedMemBytes = 0;
            launch_cfg.hStream = cu_stream;
            launch_cfg.numAttrs = 1;
            launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
            let ret = sys::lib().cuLaunchKernelEx(
                &launch_cfg as *const sys::CUlaunchConfig,
                fused.raw_kernel.function,
                kernel_params.as_mut_ptr(),
                std::ptr::null_mut(),
            );
            if ret != sys::CUresult::CUDA_SUCCESS {
                return Err(DriverError(ret));
            }
        } else {
            // Pre-Hopper: cooperative kernel launch. The fused kernel uses
            // grid.sync() for cross-block synchronization, which REQUIRES
            // cuLaunchCooperativeKernel (a normal launch silently crashes on
            // the first grid.sync() call).
            let ret = sys::lib().cuLaunchCooperativeKernel(
                fused.raw_kernel.function,
                grid_x, 1, 1,
                block_x, 1, 1,
                0, // sharedMemBytes
                cu_stream,
                kernel_params.as_mut_ptr(),
            );
            if ret != sys::CUresult::CUDA_SUCCESS {
                return Err(DriverError(ret));
            }
        }
    }
    Ok(())
}
/// Single batched launch for B regions. Uses the same kernel body; each block
/// reads its region's `FusedPtrs` from a device-side array indexed by
/// `blockIdx.y`. All regions share the same config (same
/// input_bits/n_columns/etc.) so we pass one `FusedConfig`.
///
/// This breaks through CUDA's cooperative-kernel device-level serialization:
/// multiple cooperative launches are serialized regardless of stream, but one
/// launch with `grid.y = B` processes all regions in a single invocation —
/// ~B× speedup vs B sequential launches.
///
/// Low-level raw-pointer entry, called by the PyO3 binding which holds the
/// mutable borrows. Safety: each `*mut HTMRegionGpu` must point to a live,
/// uniquely-borrowed region. All regions must be distinct.
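///
/// Illustrative call pattern (a sketch; `regions` and the raw device
/// addresses are assumptions supplied by the binding layer):
///
/// ```ignore
/// let mut ptrs: Vec<*mut HTMRegionGpu> =
///     regions.iter_mut().map(|r| r as *mut HTMRegionGpu).collect();
/// launch_fused_batched_raw(&ptrs, &input_addrs, &col_addrs, &anom_addrs,
///                          t, input_bits, /*learn=*/ true)?;
/// ```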
pub(super) fn launch_fused_batched_raw(
    region_ptrs: &[*mut super::HTMRegionGpu],
    inputs_per_region: &[u64],
    cols_per_region: &[u64],
    anom_per_region: &[u64],
    t: usize,
    input_bits: usize,
    learn: bool,
) -> Result<(), DriverError> {
    let b = region_ptrs.len();
    assert_eq!(inputs_per_region.len(), b);
    assert_eq!(cols_per_region.len(), b);
    assert_eq!(anom_per_region.len(), b);
    assert!(b >= 1, "need at least one region");
    // Reset per-region step_scratch before each launch.
    for &rp in region_ptrs.iter() {
        let r = unsafe { &mut *rp };
        let dev = r.sp_gpu.dev_ref().clone();
        dev.memset_zeros(&mut r.fused_state.step_scratch)?;
        r.fused_state.iter_counter = r.fused_state.iter_counter.wrapping_add(1);
    }
    // Shared config — all regions use identical sp/tm parameters.
    let (grid_x, block_x, function_batched, cu_stream, cu_ctx) = {
        let r0 = unsafe { &*region_ptrs[0] };
        (
            r0.fused_state.grid_dim_x,
            r0.fused_state.block_dim_x,
            r0.fused_state.raw_kernel.function_batched,
            *r0.sp_gpu.dev_ref().cu_stream(),
            *r0.sp_gpu.dev_ref().cu_primary_ctx(),
        )
    };
    let cfg = {
        let r = unsafe { &*region_ptrs[0] };
        FusedConfig {
            input_bits: input_bits as u32,
            n_columns: r.sp_gpu.n_columns_accessor() as u32,
            synapses_per_col: r.sp_gpu.synapses_per_col_accessor() as u32,
            conn_thr: r.sp_gpu.conn_thr_accessor(),
            sp_inc: r.sp_gpu.inc_accessor(),
            sp_dec: r.sp_gpu.dec_accessor(),
            sparsity_target: r.sp_gpu.sparsity_accessor(),
            duty_alpha: 1.0f32 / r.sp_gpu.duty_period_accessor().max(1.0),
            thr_adapt_rate: 0.001f32,
            cells_per_column: r.tm_gpu.cells_per_column as u32,
            n_cells: r.tm_gpu.n_cells as u32,
            bits_words: r.tm_gpu.bits_words as u32,
            max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
            synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
            activation_threshold: r.tm_gpu.activation_threshold,
            learning_threshold: r.tm_gpu.learning_threshold,
            max_new_synapses: r.tm_gpu.max_new_synapse_count,
            conn_thr_i16: r.tm_gpu.conn_thr_i16 as i32,
            perm_inc_i16: r.tm_gpu.perm_inc_i16 as i32,
            perm_dec_i16: r.tm_gpu.perm_dec_i16 as i32,
            predicted_seg_dec_i16: r.tm_gpu.predicted_seg_dec_i16 as i32,
            initial_perm_i16: r.tm_gpu.initial_perm_i16 as i32,
            t: t as u32,
            learn: if learn { 1 } else { 0 },
            iter_seed: r.fused_state.iter_counter,
            cooperative_grid_sync: 1,
        }
    };
    // Build B FusedPtrs, one per region.
    let ptrs_vec: Vec<FusedPtrs> = (0..b)
        .map(|i| {
            let r = unsafe { &*region_ptrs[i] };
            FusedPtrs {
                syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
                syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
                boost: *r.sp_gpu.boost_accessor().device_ptr(),
                active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
                inhibition_threshold: *r.fused_state.inhibition_threshold.device_ptr(),
                seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
                seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
                syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
                tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
                cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
                cell_active_a: *r.fused_state.cell_active_bits_a.device_ptr(),
                cell_active_b: *r.fused_state.cell_active_bits_b.device_ptr(),
                cell_winner_a: *r.fused_state.cell_winner_bits_a.device_ptr(),
                cell_winner_b: *r.fused_state.cell_winner_bits_b.device_ptr(),
                inputs: inputs_per_region[i],
                cols_out: cols_per_region[i],
                anom_out: anom_per_region[i],
                barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
                step_scratch: *r.fused_state.step_scratch.device_ptr(),
            }
        })
        .collect();
    // Upload the FusedPtrs array to device (B * sizeof(FusedPtrs) bytes).
    // FusedPtrs is repr(C) + DeviceRepr, so htod_sync_copy handles it.
    let dev = unsafe { &*region_ptrs[0] }.sp_gpu.dev_ref().clone();
    let ptrs_dev: CudaSlice<FusedPtrs> = dev.htod_sync_copy(&ptrs_vec)?;
    let ptrs_dev_ptr: u64 = *ptrs_dev.device_ptr();
    // T10: Cluster launch for batched regions.
    // Grid = (grid_x, B, 1) with cluster_dim=(16,1,1): each region (Y slice)
    // occupies exactly one cluster of 16 blocks. E.g. with B = 8 regions, all
    // 8 clusters run concurrently on the H200's 132 SMs (8 × 16 = 128 blocks ≤ 132 SMs).
    let use_cluster = {
        let r0 = unsafe { &*region_ptrs[0] };
        r0.fused_state.cluster_info.max_cluster_size > 0
    };
    unsafe {
        result::ctx::set_current(cu_ctx)?;
        let mut kernel_params: [*mut std::ffi::c_void; 2] = [
            (&ptrs_dev_ptr as *const u64).cast_mut().cast(),
            (&cfg as *const FusedConfig).cast_mut().cast(),
        ];
        if use_cluster {
            let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
            attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
            attr.value.clusterDim.x = 16;
            attr.value.clusterDim.y = 1;
            attr.value.clusterDim.z = 1;
            let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
            launch_cfg.gridDimX = grid_x;
            launch_cfg.gridDimY = b as u32;
            launch_cfg.gridDimZ = 1;
            launch_cfg.blockDimX = block_x;
            launch_cfg.blockDimY = 1;
            launch_cfg.blockDimZ = 1;
            launch_cfg.sharedMemBytes = 0;
            launch_cfg.hStream = cu_stream;
            launch_cfg.numAttrs = 1;
            launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
            let ret = sys::lib().cuLaunchKernelEx(
                &launch_cfg as *const sys::CUlaunchConfig,
                function_batched,
                kernel_params.as_mut_ptr(),
                std::ptr::null_mut(),
            );
            if ret != sys::CUresult::CUDA_SUCCESS {
                return Err(DriverError(ret));
            }
        } else {
            // Pre-Hopper: cooperative kernel launch (grid.sync() requires it).
            let ret = sys::lib().cuLaunchCooperativeKernel(
                function_batched,
                grid_x, b as u32, 1,
                block_x, 1, 1,
                0, // sharedMemBytes
                cu_stream,
                kernel_params.as_mut_ptr(),
            );
            if ret != sys::CUresult::CUDA_SUCCESS {
                return Err(DriverError(ret));
            }
        }
    }
    Ok(())
}