//! Fused HTM megakernel launcher.
//!
//! Collapses the 12-kernel per-timestep pipeline (and the outer T-loop) into
//! a single kernel launch per forward. See `kernels/htm_fused_step.cu` for
//! the kernel design and the cross-block coherence strategy (cluster barrier
//! on Hopper, cooperative `grid.sync()` on earlier architectures).
//!
//! Launch invariant: `grid_dim.x <= concurrent-block capacity`. Host code
//! probes the device SM count at construction and caps grid_dim.x
//! accordingly — otherwise the grid barrier deadlocks.
//!
//! Semantic change from the top-K pipeline: activation is per-column
//! threshold-based (local lateral inhibition) instead of global top-K.
//! A per-column `inhibition_threshold` is tracked and EMA-steered to hit
//! the sparsity target. This is a real architectural change and is
//! documented in `docs/GPU_HTM.md`.

#![cfg(feature = "gpu")]

use std::ffi::CString;
use std::sync::Arc;

use cudarc::driver::{
    result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError, LaunchConfig,
};
use cudarc::nvrtc::Ptx;

use super::sp_gpu::SpatialPoolerGpu;
use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};

const PTX_HTM_FUSED: &str =
    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));

/// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
///
/// NOTE: `barrier_counters` is kept as an ABI-compat dummy (always 0). The
/// C-side `FusedPtrs` still has the field at the same byte offset; removing
/// it here would shift all subsequent fields and break the layout. Worker A
/// will eventually delete the field from both sides once the kernel is
/// updated; until then we zero it.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct FusedPtrs {
    pub syn_bit: u64,
    pub syn_perm: u64,
    pub boost: u64,
    pub active_duty: u64,
    pub inhibition_threshold: u64,
    pub seg_cell_id: u64,
    pub seg_syn_count: u64,
    pub syn_presyn: u64,
    pub tm_syn_perm: u64,
    pub cell_seg_count: u64,
    pub cell_active_a: u64,
    pub cell_active_b: u64,
    pub cell_winner_a: u64,
    pub cell_winner_b: u64,
    pub inputs: u64,
    pub cols_out: u64,
    pub anom_out: u64,
    /// ABI-compat dummy — always 0. No device memory is allocated for this
    /// field; the cluster barrier replaces the old software DLB barrier.
    pub barrier_counters: u64,
    pub step_scratch: u64,
}

unsafe impl DeviceRepr for FusedPtrs {}

/// Launch-time config — matches C-side `FusedConfig` 1:1.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct FusedConfig {
    pub input_bits: u32,
    pub n_columns: u32,
    pub synapses_per_col: u32,
    pub conn_thr: f32,
    pub sp_inc: f32,
    pub sp_dec: f32,
    pub sparsity_target: f32,
    pub duty_alpha: f32,
    pub thr_adapt_rate: f32,
    pub cells_per_column: u32,
    pub n_cells: u32,
    pub bits_words: u32,
    pub max_segments_per_cell: u32,
    pub synapses_per_segment: u32,
    pub activation_threshold: u32,
    pub learning_threshold: u32,
    pub max_new_synapses: u32,
    pub conn_thr_i16: i32,
    pub perm_inc_i16: i32,
    pub perm_dec_i16: i32,
    pub predicted_seg_dec_i16: i32,
    pub initial_perm_i16: i32,
    pub t: u32,
    pub learn: u32,
    pub iter_seed: u32,
    pub cooperative_grid_sync: u32,
}

unsafe impl DeviceRepr for FusedConfig {}

/// Cluster launch parameters probed at construction time.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct ClusterInfo {
    /// Maximum cluster size supported by this device (0 = cluster unsupported).
    pub max_cluster_size: u32,
}
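// A minimal layout sanity check for the two #[repr(C)] structs above. The byte
// counts are derived from the field lists as written here (19 x u64 and
// 26 x 4-byte fields), and the offset assertions document why
// `barrier_counters` cannot be deleted on the Rust side alone: every field
// after it would shift by 8 bytes. `std::mem::offset_of!` assumes Rust 1.77+.
#[cfg(test)]
mod abi_layout_checks {
    use super::{FusedConfig, FusedPtrs};

    #[test]
    fn fused_ptrs_layout_matches_the_c_side_field_count() {
        // 19 u64 pointer fields, no padding under repr(C).
        assert_eq!(std::mem::size_of::<FusedPtrs>(), 19 * 8);
        // The ABI-compat dummy sits at offset 17 * 8; `step_scratch` follows it.
        assert_eq!(std::mem::offset_of!(FusedPtrs, barrier_counters), 17 * 8);
        assert_eq!(std::mem::offset_of!(FusedPtrs, step_scratch), 18 * 8);
    }

    #[test]
    fn fused_config_is_a_dense_pack_of_26_four_byte_fields() {
        assert_eq!(std::mem::size_of::<FusedConfig>(), 26 * 4);
    }
}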
// The preferred launch mode is a non-cooperative launch with the Hopper Thread
// Block Cluster attribute (`CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`); cluster
// barriers replace the old software DLB barrier. Devices without cluster
// support fall back to a cooperative launch so the kernel's grid.sync()
// remains valid (see `launch_fused` below).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct FusedLaunchPlan {
    pub grid_dim_x: u32,
    pub block_dim_x: u32,
    pub cooperative_grid_limit: u32,
    pub sm_count: u32,
}

fn fused_grid_cap_override() -> Option<u32> {
    std::env::var("HTM_FUSED_GRID_CAP")
        .ok()
        .and_then(|s| s.parse::<u32>().ok())
        .map(|v| v.max(1))
}

pub(crate) fn plan_fused_launch(
    sm_count: u32,
    cooperative_supported: bool,
    cooperative_grid_limit: u32,
    grid_cap_override: Option<u32>,
) -> Result<FusedLaunchPlan, String> {
    let sm_count = sm_count.max(1);
    // 1024 threads/block exceeds the register file on Ampere (sm_86: 65536
    // regs/SM ÷ 1024 = 64 regs/thread; fused kernel needs ~80+). 256 gives
    // 256 regs/thread, which is ample. Compensate with more blocks via
    // cooperative launch. On Hopper (228 KB smem, 255 regs/thread baseline),
    // 1024 works fine, but 256 is safe everywhere.
    let block_dim_x = 256u32;

    // Cluster launch path: cooperative launch is not required. Keep the probe
    // result for residency estimation only.
    if !cooperative_supported {
        eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
    }

    // Tested grid_cap: 4 blocks = 30ms (too serial), 16 blocks = 10.8ms (parallel wins).
    // Parallelism in SP overlap + TM predict stages outweighs grid.sync() cost.
    let default_grid_cap = 16u32;
    let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
    let resident_bound = if cooperative_grid_limit > 0 {
        cooperative_grid_limit.max(sm_count * 2)
    } else {
        sm_count * 2
    };
    Ok(FusedLaunchPlan {
        grid_dim_x: resident_bound.min(grid_cap).max(1),
        block_dim_x,
        cooperative_grid_limit: resident_bound,
        sm_count,
    })
}
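// A minimal host-side sanity check for the planning logic above. It only
// exercises `plan_fused_launch` (no GPU required); the device numbers
// (132 SMs, a 264-block cooperative limit) are illustrative values, not
// probed from hardware.
#[cfg(test)]
mod plan_fused_launch_tests {
    use super::plan_fused_launch;

    #[test]
    fn grid_is_capped_at_the_default_of_16() {
        let plan = plan_fused_launch(132, true, 264, None).unwrap();
        assert_eq!(plan.grid_dim_x, 16);
        assert_eq!(plan.block_dim_x, 256);
        assert_eq!(plan.cooperative_grid_limit, 264);
    }

    #[test]
    fn explicit_override_wins_over_the_default_cap() {
        let plan = plan_fused_launch(132, true, 264, Some(4)).unwrap();
        assert_eq!(plan.grid_dim_x, 4);
    }

    #[test]
    fn tiny_device_without_cooperative_launch_still_gets_a_valid_plan() {
        let plan = plan_fused_launch(1, false, 0, None).unwrap();
        assert_eq!(plan.grid_dim_x, 2); // bounded by sm_count * 2
        assert_eq!(plan.sm_count, 1);
    }
}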
pub(super) struct RawFusedKernel {
    module: sys::CUmodule,
    pub(super) function: sys::CUfunction,
    pub(super) function_batched: sys::CUfunction,
}

unsafe impl Send for RawFusedKernel {}
unsafe impl Sync for RawFusedKernel {}

impl Drop for RawFusedKernel {
    fn drop(&mut self) {
        unsafe {
            let _ = result::module::unload(self.module);
        }
    }
}

/// Owns fused-path-only device state:
/// - per-column inhibition threshold (replaces global top-K)
/// - ping-pong cell_active/cell_winner bitsets
/// - step_scratch (n_active, n_unpred per timestep)
/// - cluster launch capability info
pub struct FusedState {
    dev: Arc<CudaDevice>,
    pub(super) raw_kernel: RawFusedKernel,
    pub inhibition_threshold: CudaSlice<f32>,
    pub cell_active_bits_a: CudaSlice<u32>,
    pub cell_active_bits_b: CudaSlice<u32>,
    pub cell_winner_bits_a: CudaSlice<u32>,
    pub cell_winner_bits_b: CudaSlice<u32>,
    pub step_scratch: CudaSlice<u32>, // length 6
    pub grid_dim_x: u32,
    pub block_dim_x: u32,
    pub cooperative_grid_limit: u32,
    pub iter_counter: u32,
    /// Hopper cluster launch capability (0 = unsupported).
    pub cluster_info: ClusterInfo,
    // Config mirror (read-only after init).
    #[allow(dead_code)]
    pub initial_threshold: f32,
}

impl FusedState {
    pub fn new(
        dev: Arc<CudaDevice>,
        n_columns: usize,
        cells_per_column: usize,
        initial_threshold: f32,
    ) -> Result<Self, DriverError> {
        let n_cells = n_columns * cells_per_column;
        assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
        let bits_words = n_cells / 32;

        let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
        let init_vec = vec![initial_threshold; n_columns];
        dev.htod_sync_copy_into(&init_vec, &mut inhibition_threshold)?;

        let cell_active_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
        let cell_active_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
        let cell_winner_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
        let cell_winner_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
        let step_scratch = dev.alloc_zeros::<u32>(6)?;

        unsafe {
            result::ctx::set_current(*dev.cu_primary_ctx())?;
        }
        if dev.get_func("htm_fused", "htm_fused_step").is_none() {
            dev.load_ptx(
                Ptx::from_src(PTX_HTM_FUSED),
                "htm_fused",
                &["htm_fused_step", "htm_fused_step_batched"],
            )?;
        }
        let ptx = CString::new(PTX_HTM_FUSED).expect("PTX contains no interior nul bytes");
        let module = unsafe { result::module::load_data(ptx.as_ptr().cast()) }?;
        let function = unsafe {
            result::module::get_function(module, CString::new("htm_fused_step").unwrap())
        }?;
        let function_batched = unsafe {
            result::module::get_function(module, CString::new("htm_fused_step_batched").unwrap())
        }?;

        // Cluster size 16 on Hopper is "non-portable" (> 8 requires opt-in).
        // Must set CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED=1 on
        // every launched kernel function, otherwise cuLaunchKernelEx rejects
        // the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
        unsafe {
            let attr =
                sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
            // Ignore errors: older CUDA may lack the attribute, in which case
            // only portable cluster sizes (<= 8) work; cap the grid via
            // HTM_FUSED_GRID_CAP on such setups.
            let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
            let _ = sys::lib().cuFuncSetAttribute(function_batched, attr, 1);
        }

        // Probe SM count.
        let sm_count = match dev.attribute(
            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
        ) {
            Ok(v) => v as u32,
            Err(_) => 16u32,
        };

        // T1: Probe Hopper cluster launch capability.
        let max_cluster_size = match dev.attribute(
            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH,
        ) {
            Ok(v) if v > 0 => {
                // H200/sm_90a supports up to 16 blocks per cluster.
                // There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
                // Hopper maximum of 16 (8 SMs × 2 blocks/SM = 16 blocks/cluster).
                16u32
            }
            _ => 0u32,
        };
        eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
        let cluster_info = ClusterInfo { max_cluster_size };

        let cooperative_supported = matches!(
            dev.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
            Ok(v) if v > 0
        );
        let cooperative_grid_limit = if cooperative_supported {
            let blocks_per_sm = unsafe {
                result::occupancy::max_active_block_per_multiprocessor(function, 1024, 0)
            }
            .ok()
            .map(|v| v.max(0) as u32)
            .unwrap_or(0);
            sm_count.saturating_mul(blocks_per_sm)
        } else {
            0
        };

        let launch_plan = plan_fused_launch(
            sm_count,
            cooperative_supported,
            cooperative_grid_limit,
            fused_grid_cap_override(),
        )
        .map_err(|msg| {
            // Surface as a CUDA-ish error so callers can propagate.
eprintln!("[htm_rust] FATAL: {msg}"); DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED) })?; eprintln!( "[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}", launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit, cluster_info.max_cluster_size, ); Ok(Self { dev, raw_kernel: RawFusedKernel { module, function, function_batched }, inhibition_threshold, cell_active_bits_a, cell_active_bits_b, cell_winner_bits_a, cell_winner_bits_b, step_scratch, grid_dim_x: launch_plan.grid_dim_x, block_dim_x: launch_plan.block_dim_x, cooperative_grid_limit: launch_plan.cooperative_grid_limit, iter_counter: 0, cluster_info, initial_threshold, }) } /// Reset fused state. Called at region.reset(). pub fn reset(&mut self) -> Result<(), DriverError> { self.dev.memset_zeros(&mut self.cell_active_bits_a)?; self.dev.memset_zeros(&mut self.cell_active_bits_b)?; self.dev.memset_zeros(&mut self.cell_winner_bits_a)?; self.dev.memset_zeros(&mut self.cell_winner_bits_b)?; self.dev.memset_zeros(&mut self.step_scratch)?; // Do NOT reset inhibition_threshold — it's learned state. A hard // reset of TM state should NOT forget the sparsity calibration. Ok(()) } } /// Launch the fused megakernel. Processes all T timesteps in one kernel. /// /// Uses `cuLaunchKernelEx` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION=(16,1,1)` /// when the device supports cluster launch, otherwise falls back to a plain /// `launch_kernel`. For single-region launches, grid_dim_x <= 16 ensures the /// entire grid fits in one cluster. #[allow(clippy::too_many_arguments)] pub fn launch_fused( sp: &mut SpatialPoolerGpu, tm: &mut TemporalMemoryGpu, fused: &mut FusedState, inputs_flat: &CudaSlice, cols_out: &mut CudaSlice, anom_out: &mut CudaSlice, t: usize, input_bits: usize, learn: bool, ) -> Result<(), DriverError> { // Reset step_scratch before each launch (safe re-entry). 
/// Launch the fused megakernel. Processes all T timesteps in one kernel.
///
/// Uses `cuLaunchKernelEx` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION=(16,1,1)`
/// when the device supports cluster launch, otherwise falls back to a
/// cooperative launch (`cuLaunchCooperativeKernel`). For single-region
/// launches, the planned grid_dim_x of 16 matches the cluster dimension, so
/// the entire grid forms a single cluster (grid dimensions must be a multiple
/// of the cluster dimensions).
#[allow(clippy::too_many_arguments)]
pub fn launch_fused(
    sp: &mut SpatialPoolerGpu,
    tm: &mut TemporalMemoryGpu,
    fused: &mut FusedState,
    inputs_flat: &CudaSlice<u32>,
    cols_out: &mut CudaSlice<u32>,
    anom_out: &mut CudaSlice<f32>,
    t: usize,
    input_bits: usize,
    learn: bool,
) -> Result<(), DriverError> {
    // Reset step_scratch before each launch (safe re-entry).
    sp.dev_ref().memset_zeros(&mut fused.step_scratch)?;
    fused.iter_counter = fused.iter_counter.wrapping_add(1);

    let cfg = FusedConfig {
        input_bits: input_bits as u32,
        n_columns: sp.n_columns_accessor() as u32,
        synapses_per_col: sp.synapses_per_col_accessor() as u32,
        conn_thr: sp.conn_thr_accessor(),
        sp_inc: sp.inc_accessor(),
        sp_dec: sp.dec_accessor(),
        sparsity_target: sp.sparsity_accessor(),
        duty_alpha: 1.0f32 / sp.duty_period_accessor().max(1.0),
        thr_adapt_rate: 0.001f32,
        cells_per_column: tm.cells_per_column as u32,
        n_cells: tm.n_cells as u32,
        bits_words: tm.bits_words as u32,
        max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
        synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
        activation_threshold: tm.activation_threshold,
        learning_threshold: tm.learning_threshold,
        max_new_synapses: tm.max_new_synapse_count,
        conn_thr_i16: tm.conn_thr_i16 as i32,
        perm_inc_i16: tm.perm_inc_i16 as i32,
        perm_dec_i16: tm.perm_dec_i16 as i32,
        predicted_seg_dec_i16: tm.predicted_seg_dec_i16 as i32,
        initial_perm_i16: tm.initial_perm_i16 as i32,
        t: t as u32,
        learn: if learn { 1 } else { 0 },
        iter_seed: fused.iter_counter,
        cooperative_grid_sync: 1,
    };

    let ptrs = FusedPtrs {
        syn_bit: *sp.syn_bit_accessor().device_ptr(),
        syn_perm: *sp.syn_perm_accessor().device_ptr(),
        boost: *sp.boost_accessor().device_ptr(),
        active_duty: *sp.active_duty_accessor().device_ptr(),
        inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
        seg_cell_id: *tm.seg_cell_id_accessor().device_ptr(),
        seg_syn_count: *tm.seg_syn_count_accessor().device_ptr(),
        syn_presyn: *tm.syn_presyn_accessor().device_ptr(),
        tm_syn_perm: *tm.syn_perm_accessor().device_ptr(),
        cell_seg_count: *tm.cell_seg_count_accessor().device_ptr(),
        cell_active_a: *fused.cell_active_bits_a.device_ptr(),
        cell_active_b: *fused.cell_active_bits_b.device_ptr(),
        cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
        cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
        inputs: *inputs_flat.device_ptr(),
        cols_out: *cols_out.device_ptr(),
        anom_out: *anom_out.device_ptr(),
        barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
        step_scratch: *fused.step_scratch.device_ptr(),
    };

    let grid_x = fused.grid_dim_x;
    let block_x = fused.block_dim_x;
    let cu_stream = *sp.dev_ref().cu_stream();
    let use_cluster = fused.cluster_info.max_cluster_size > 0;

    unsafe {
        result::ctx::set_current(*sp.dev_ref().cu_primary_ctx())?;
        let mut kernel_params: [*mut std::ffi::c_void; 2] = [
            (&ptrs as *const FusedPtrs).cast_mut().cast(),
            (&cfg as *const FusedConfig).cast_mut().cast(),
        ];

        if use_cluster {
            // T10: Hopper cluster launch with CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION.
            // cluster_dim=(16,1,1) maps the entire single-region grid into one cluster.
            let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
            attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
            attr.value.clusterDim.x = 16;
            attr.value.clusterDim.y = 1;
            attr.value.clusterDim.z = 1;
            let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
            launch_cfg.gridDimX = grid_x;
            launch_cfg.gridDimY = 1;
            launch_cfg.gridDimZ = 1;
            launch_cfg.blockDimX = block_x;
            launch_cfg.blockDimY = 1;
            launch_cfg.blockDimZ = 1;
            launch_cfg.sharedMemBytes = 0;
            launch_cfg.hStream = cu_stream;
            launch_cfg.numAttrs = 1;
            launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
            let ret = sys::lib().cuLaunchKernelEx(
                &launch_cfg as *const sys::CUlaunchConfig,
                fused.raw_kernel.function,
                kernel_params.as_mut_ptr(),
                std::ptr::null_mut(),
            );
            if ret != sys::CUresult::CUDA_SUCCESS {
                return Err(DriverError(ret));
            }
        } else {
            // Pre-Hopper: cooperative kernel launch. The fused kernel uses
            // grid.sync() for cross-block synchronization, which REQUIRES
            // cuLaunchCooperativeKernel (a plain launch crashes at the first
            // grid.sync() call).
            let ret = sys::lib().cuLaunchCooperativeKernel(
                fused.raw_kernel.function,
                grid_x,
                1,
                1,
                block_x,
                1,
                1,
                0, // sharedMemBytes
                cu_stream,
                kernel_params.as_mut_ptr(),
            );
            if ret != sys::CUresult::CUDA_SUCCESS {
                return Err(DriverError(ret));
            }
        }
    }
    Ok(())
}

/// Single batched launch for B regions (cluster launch on Hopper, cooperative
/// launch otherwise). Uses the same kernel body; each block reads its
/// region's FusedPtrs from a device-side array indexed by blockIdx.y. All
/// regions share the same config (same input_bits/n_columns/etc.) so we pass
/// one FusedConfig.
///
/// This breaks through the CUDA cooperative-kernel device-level
/// serialization: multiple cooperative launches are serialized regardless
/// of stream, but a single launch with grid.y=B processes all regions in
/// one invocation — ~B× speedup vs B sequential launches.
///
/// Low-level raw-pointer entry, called by the PyO3 binding which holds the
/// mutable borrows. Safety: each `*mut HTMRegionGpu` must point to a live,
/// uniquely-borrowed region. All regions must be distinct.
#[allow(clippy::too_many_arguments)]
pub(super) fn launch_fused_batched_raw(
    region_ptrs: &[*mut super::HTMRegionGpu],
    inputs_per_region: &[u64],
    cols_per_region: &[u64],
    anom_per_region: &[u64],
    t: usize,
    input_bits: usize,
    learn: bool,
) -> Result<(), DriverError> {
    let b = region_ptrs.len();
    assert_eq!(inputs_per_region.len(), b);
    assert_eq!(cols_per_region.len(), b);
    assert_eq!(anom_per_region.len(), b);
    assert!(b >= 1, "need at least one region");

    // Reset per-region step_scratch before each launch.
    for &rp in region_ptrs.iter() {
        let r = unsafe { &mut *rp };
        let dev = r.sp_gpu.dev_ref().clone();
        dev.memset_zeros(&mut r.fused_state.step_scratch)?;
        r.fused_state.iter_counter = r.fused_state.iter_counter.wrapping_add(1);
    }

    // Shared config — all regions use identical sp/tm parameters.
    let (grid_x, block_x, function_batched, cu_stream, cu_ctx) = {
        let r0 = unsafe { &*region_ptrs[0] };
        (
            r0.fused_state.grid_dim_x,
            r0.fused_state.block_dim_x,
            r0.fused_state.raw_kernel.function_batched,
            *r0.sp_gpu.dev_ref().cu_stream(),
            *r0.sp_gpu.dev_ref().cu_primary_ctx(),
        )
    };

    let cfg = {
        let r = unsafe { &*region_ptrs[0] };
        FusedConfig {
            input_bits: input_bits as u32,
            n_columns: r.sp_gpu.n_columns_accessor() as u32,
            synapses_per_col: r.sp_gpu.synapses_per_col_accessor() as u32,
            conn_thr: r.sp_gpu.conn_thr_accessor(),
            sp_inc: r.sp_gpu.inc_accessor(),
            sp_dec: r.sp_gpu.dec_accessor(),
            sparsity_target: r.sp_gpu.sparsity_accessor(),
            duty_alpha: 1.0f32 / r.sp_gpu.duty_period_accessor().max(1.0),
            thr_adapt_rate: 0.001f32,
            cells_per_column: r.tm_gpu.cells_per_column as u32,
            n_cells: r.tm_gpu.n_cells as u32,
            bits_words: r.tm_gpu.bits_words as u32,
            max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
            synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
            activation_threshold: r.tm_gpu.activation_threshold,
            learning_threshold: r.tm_gpu.learning_threshold,
            max_new_synapses: r.tm_gpu.max_new_synapse_count,
            conn_thr_i16: r.tm_gpu.conn_thr_i16 as i32,
            perm_inc_i16: r.tm_gpu.perm_inc_i16 as i32,
            perm_dec_i16: r.tm_gpu.perm_dec_i16 as i32,
            predicted_seg_dec_i16: r.tm_gpu.predicted_seg_dec_i16 as i32,
            initial_perm_i16: r.tm_gpu.initial_perm_i16 as i32,
            t: t as u32,
            learn: if learn { 1 } else { 0 },
            iter_seed: r.fused_state.iter_counter,
            cooperative_grid_sync: 1,
        }
    };

    // Build B FusedPtrs, one per region.
    let ptrs_vec: Vec<FusedPtrs> = (0..b)
        .map(|i| {
            let r = unsafe { &*region_ptrs[i] };
            FusedPtrs {
                syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
                syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
                boost: *r.sp_gpu.boost_accessor().device_ptr(),
                active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
                inhibition_threshold: *r.fused_state.inhibition_threshold.device_ptr(),
                seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
                seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
                syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
                tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
                cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
                cell_active_a: *r.fused_state.cell_active_bits_a.device_ptr(),
                cell_active_b: *r.fused_state.cell_active_bits_b.device_ptr(),
                cell_winner_a: *r.fused_state.cell_winner_bits_a.device_ptr(),
                cell_winner_b: *r.fused_state.cell_winner_bits_b.device_ptr(),
                inputs: inputs_per_region[i],
                cols_out: cols_per_region[i],
                anom_out: anom_per_region[i],
                barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
                step_scratch: *r.fused_state.step_scratch.device_ptr(),
            }
        })
        .collect();

    // Upload the FusedPtrs array to the device (B * sizeof(FusedPtrs) bytes).
    // FusedPtrs is repr(C) + DeviceRepr, so htod_sync_copy handles it.
    let dev = unsafe { &*region_ptrs[0] }.sp_gpu.dev_ref().clone();
    let ptrs_dev: CudaSlice<FusedPtrs> = dev.htod_sync_copy(&ptrs_vec)?;
    let ptrs_dev_ptr: u64 = *ptrs_dev.device_ptr();

    // T10: Cluster launch for batched regions.
    // Grid = (grid_x, B, 1) with cluster_dim=(16,1,1): each region (Y slice)
    // occupies exactly one cluster of 16 blocks. With B = 8 regions, all eight
    // clusters run concurrently on the H200's 132 SMs (8 × 16 = 128 blocks ≤ 132 SMs).
    let use_cluster = {
        let r0 = unsafe { &*region_ptrs[0] };
        r0.fused_state.cluster_info.max_cluster_size > 0
    };

    unsafe {
        result::ctx::set_current(cu_ctx)?;
        let mut kernel_params: [*mut std::ffi::c_void; 2] = [
            (&ptrs_dev_ptr as *const u64).cast_mut().cast(),
            (&cfg as *const FusedConfig).cast_mut().cast(),
        ];

        if use_cluster {
            let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
            attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
            attr.value.clusterDim.x = 16;
            attr.value.clusterDim.y = 1;
            attr.value.clusterDim.z = 1;
            let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
            launch_cfg.gridDimX = grid_x;
            launch_cfg.gridDimY = b as u32;
            launch_cfg.gridDimZ = 1;
            launch_cfg.blockDimX = block_x;
            launch_cfg.blockDimY = 1;
            launch_cfg.blockDimZ = 1;
            launch_cfg.sharedMemBytes = 0;
            launch_cfg.hStream = cu_stream;
            launch_cfg.numAttrs = 1;
            launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
            let ret = sys::lib().cuLaunchKernelEx(
                &launch_cfg as *const sys::CUlaunchConfig,
                function_batched,
                kernel_params.as_mut_ptr(),
                std::ptr::null_mut(),
            );
            if ret != sys::CUresult::CUDA_SUCCESS {
                return Err(DriverError(ret));
            }
        } else {
            // Pre-Hopper: cooperative kernel launch (grid.sync() requires it).
            let ret = sys::lib().cuLaunchCooperativeKernel(
                function_batched,
                grid_x,
                b as u32,
                1,
                block_x,
                1,
                1,
                0, // sharedMemBytes
                cu_stream,
                kernel_params.as_mut_ptr(),
            );
            if ret != sys::CUresult::CUDA_SUCCESS {
                return Err(DriverError(ret));
            }
        }
    }
    Ok(())
}
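// Hedged caller-side sketch for the batched entry point above. The real caller
// is the PyO3 binding, which is not reproduced here; this illustrative helper
// (`example_batched_forward` is a made-up name, not part of the real API) only
// shows how the raw pointer/handle lists are assembled. It assumes u32
// input/column buffers and f32 anomaly buffers, matching the single-region
// `launch_fused` signature.
#[allow(dead_code)]
fn example_batched_forward(
    regions: &mut [super::HTMRegionGpu],
    inputs: &[CudaSlice<u32>],
    cols: &[CudaSlice<u32>],
    anoms: &[CudaSlice<f32>],
    t: usize,
    input_bits: usize,
    learn: bool,
) -> Result<(), DriverError> {
    // One raw region pointer per region; each must be live and uniquely borrowed.
    let region_ptrs: Vec<*mut super::HTMRegionGpu> = regions
        .iter_mut()
        .map(|r| r as *mut super::HTMRegionGpu)
        .collect();
    // Raw device addresses of the per-region input/output buffers.
    let input_ptrs: Vec<u64> = inputs.iter().map(|s| *s.device_ptr()).collect();
    let col_ptrs: Vec<u64> = cols.iter().map(|s| *s.device_ptr()).collect();
    let anom_ptrs: Vec<u64> = anoms.iter().map(|s| *s.device_ptr()).collect();
    launch_fused_batched_raw(
        &region_ptrs,
        &input_ptrs,
        &col_ptrs,
        &anom_ptrs,
        t,
        input_bits,
        learn,
    )
}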