//! Fused HTM megakernel launcher.
//!
//! Collapses the 12-kernel per-timestep pipeline (and the outer T-loop) into
//! a single kernel launch per forward. See `kernels/htm_fused_step.cu` for
//! the kernel design and the cross-block coherence strategy (a grid-wide
//! barrier: cluster barrier on Hopper, cooperative `grid.sync()` on earlier
//! devices, with all blocks concurrently resident).
//!
//! Launch invariant: `grid_dim.x <= concurrent-block capacity`. Host code
//! probes the device SM count at construction and caps grid_dim.x
//! accordingly — otherwise the grid barrier deadlocks.
//!
//! Semantic change from the top-K pipeline: activation is per-column
//! threshold-based (local lateral inhibition) instead of global top-K.
//! A per-column `inhibition_threshold` is tracked and EMA-steered to hit
//! the sparsity target. This is a real architectural change and is
//! documented in `docs/GPU_HTM.md`.
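//!
//! A minimal host-side sketch of the intended threshold steering (illustrative
//! only; the real update lives in `kernels/htm_fused_step.cu` and may differ
//! in form):
//!
//! ```
//! // Assumed form: nudge a column's threshold up when the observed column
//! // density overshoots the sparsity target, and down when it undershoots.
//! fn steer_threshold(thr: f32, observed_density: f32, target: f32, rate: f32) -> f32 {
//!     thr + rate * (observed_density - target)
//! }
//! assert!(steer_threshold(0.5, 0.04, 0.02, 0.001) > 0.5);
//! ```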
#![cfg(feature = "gpu")]
use std::ffi::CString;
use std::sync::Arc;
use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError,
LaunchConfig};
use cudarc::nvrtc::Ptx;
use super::sp_gpu::SpatialPoolerGpu;
use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
const PTX_HTM_FUSED: &str =
include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
/// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
///
/// NOTE: `barrier_counters` is kept as an ABI-compat dummy (always 0). The
/// C-side `FusedPtrs` still has the field at the same byte offset; removing
/// it here would shift all subsequent fields and break the layout. Worker A
/// will eventually delete the field from both sides once the kernel is
/// updated; until then we zero it.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct FusedPtrs {
pub syn_bit: u64,
pub syn_perm: u64,
pub boost: u64,
pub active_duty: u64,
pub inhibition_threshold: u64,
pub seg_cell_id: u64,
pub seg_syn_count: u64,
pub syn_presyn: u64,
pub tm_syn_perm: u64,
pub cell_seg_count: u64,
pub cell_active_a: u64,
pub cell_active_b: u64,
pub cell_winner_a: u64,
pub cell_winner_b: u64,
pub inputs: u64,
pub cols_out: u64,
pub anom_out: u64,
/// ABI-compat dummy — always 0. No device memory is allocated for this
/// field; the cluster barrier replaces the old software DLB barrier.
pub barrier_counters: u64,
pub step_scratch: u64,
}
unsafe impl DeviceRepr for FusedPtrs {}
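// Layout guard for the ABI note above: every field is a u64, so an accidental
// field removal or addition shows up as a size mismatch at compile time. This
// is only a sketch of a guard; per-field offsets are still checked by
// convention against the C-side struct.
const _: () = assert!(std::mem::size_of::<FusedPtrs>() == 19 * 8);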
/// Launch-time config — matches C-side `FusedConfig` 1:1.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct FusedConfig {
pub input_bits: u32,
pub n_columns: u32,
pub synapses_per_col: u32,
pub conn_thr: f32,
pub sp_inc: f32,
pub sp_dec: f32,
pub sparsity_target: f32,
pub duty_alpha: f32,
pub thr_adapt_rate: f32,
pub cells_per_column: u32,
pub n_cells: u32,
pub bits_words: u32,
pub max_segments_per_cell: u32,
pub synapses_per_segment: u32,
pub activation_threshold: u32,
pub learning_threshold: u32,
pub max_new_synapses: u32,
pub conn_thr_i16: i32,
pub perm_inc_i16: i32,
pub perm_dec_i16: i32,
pub predicted_seg_dec_i16: i32,
pub initial_perm_i16: i32,
pub t: u32,
pub learn: u32,
pub iter_seed: u32,
pub cooperative_grid_sync: u32,
}
unsafe impl DeviceRepr for FusedConfig {}
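// Same kind of layout guard for the config mirror (sketch; catches size drift
// only, not per-field offsets): 26 four-byte fields, matching the C-side
// `FusedConfig`.
const _: () = assert!(std::mem::size_of::<FusedConfig>() == 26 * 4);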
/// Cluster launch parameters probed at construction time.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct ClusterInfo {
/// Maximum cluster size supported by this device (0 = cluster unsupported).
pub max_cluster_size: u32,
}
// Primary launch mode: non-cooperative launch with the Hopper Thread Block
// Cluster attribute (`CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`); its cluster
// barrier replaces the old software DLB barrier, which is gone. Pre-Hopper
// devices fall back to a cooperative launch so that `grid.sync()` still works.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct FusedLaunchPlan {
pub grid_dim_x: u32,
pub block_dim_x: u32,
pub cooperative_grid_limit: u32,
pub sm_count: u32,
}
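/// Optional override of the fused grid cap via the `HTM_FUSED_GRID_CAP`
/// environment variable. Parsed as a `u32` and clamped to at least 1.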
fn fused_grid_cap_override() -> Option<u32> {
std::env::var("HTM_FUSED_GRID_CAP")
.ok()
.and_then(|s| s.parse::<u32>().ok())
.map(|v| v.max(1))
}
pub(crate) fn plan_fused_launch(
sm_count: u32,
cooperative_supported: bool,
cooperative_grid_limit: u32,
grid_cap_override: Option<u32>,
) -> Result<FusedLaunchPlan, String> {
let sm_count = sm_count.max(1);
// 1024 threads/block exceeds the register budget on Ampere (sm_86: 65536
// regs/SM / 1024 = 64 regs/thread; the fused kernel needs ~80+). 256
// threads/block leaves 256 regs/thread of budget (the hardware caps a thread
// at 255), which is ample; compensate with more blocks. On Hopper (228 KB
// smem, 255 regs/thread cap) 1024 would also work, but 256 is safe everywhere.
let block_dim_x = 256u32;
// On the cluster launch path, cooperative launch is not required (the
// pre-Hopper fallback still uses it). Keep the probe result for residency
// estimation only.
if !cooperative_supported {
eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
}
// Tested grid_cap: 4 blocks = 30ms (too serial), 16 blocks = 10.8ms (parallel wins).
// Parallelism in SP overlap + TM predict stages outweighs grid.sync() cost.
let default_grid_cap = 16u32;
let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
let resident_bound = if cooperative_grid_limit > 0 {
cooperative_grid_limit.max(sm_count * 2)
} else {
sm_count * 2
};
Ok(FusedLaunchPlan {
grid_dim_x: resident_bound.min(grid_cap).max(1),
block_dim_x,
cooperative_grid_limit: resident_bound,
sm_count,
})
}
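// Planner behavior sketch: the values below mirror an H200-class device
// (132 SMs) but are illustrative only, not asserted against real hardware.
#[cfg(test)]
mod plan_fused_launch_sketch {
    use super::plan_fused_launch;

    #[test]
    fn default_cap_and_block_size() {
        // Cooperative launch supported, occupancy-derived limit of 264
        // concurrent blocks, no HTM_FUSED_GRID_CAP override.
        let plan = plan_fused_launch(132, true, 264, None).unwrap();
        assert_eq!(plan.block_dim_x, 256);
        assert_eq!(plan.grid_dim_x, 16); // capped by the default grid cap
        assert_eq!(plan.cooperative_grid_limit, 264);
    }

    #[test]
    fn env_override_replaces_default_cap() {
        let plan = plan_fused_launch(132, true, 264, Some(8)).unwrap();
        assert_eq!(plan.grid_dim_x, 8);
    }
}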
pub(super) struct RawFusedKernel {
module: sys::CUmodule,
pub(super) function: sys::CUfunction,
pub(super) function_batched: sys::CUfunction,
}
unsafe impl Send for RawFusedKernel {}
unsafe impl Sync for RawFusedKernel {}
impl Drop for RawFusedKernel {
fn drop(&mut self) {
unsafe {
let _ = result::module::unload(self.module);
}
}
}
/// Owns fused-path-only device state:
/// - per-column inhibition threshold (replaces global top-K)
/// - ping-pong cell_active/cell_winner bitsets
/// - step_scratch (n_active, n_unpred per timestep)
/// - cluster launch capability info
pub struct FusedState {
dev: Arc<CudaDevice>,
pub(super) raw_kernel: RawFusedKernel,
pub inhibition_threshold: CudaSlice<f32>,
pub cell_active_bits_a: CudaSlice<u32>,
pub cell_active_bits_b: CudaSlice<u32>,
pub cell_winner_bits_a: CudaSlice<u32>,
pub cell_winner_bits_b: CudaSlice<u32>,
pub step_scratch: CudaSlice<u32>, // length 6
pub grid_dim_x: u32,
pub block_dim_x: u32,
pub cooperative_grid_limit: u32,
pub iter_counter: u32,
/// Hopper cluster launch capability (`max_cluster_size == 0` means unsupported).
pub cluster_info: ClusterInfo,
// Config mirror (read-only after init).
#[allow(dead_code)]
pub initial_threshold: f32,
}
impl FusedState {
pub fn new(
dev: Arc<CudaDevice>,
n_columns: usize,
cells_per_column: usize,
initial_threshold: f32,
) -> Result<Self, DriverError> {
let n_cells = n_columns * cells_per_column;
assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
let bits_words = n_cells / 32;
let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
let init_vec = vec![initial_threshold; n_columns];
dev.htod_sync_copy_into(&init_vec, &mut inhibition_threshold)?;
let cell_active_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
let cell_active_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
let cell_winner_bits_a = dev.alloc_zeros::<u32>(bits_words)?;
let cell_winner_bits_b = dev.alloc_zeros::<u32>(bits_words)?;
let step_scratch = dev.alloc_zeros::<u32>(6)?;
unsafe {
result::ctx::set_current(*dev.cu_primary_ctx())?;
}
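// Note: the PTX is loaded twice: once through cudarc's module cache (making
// the functions visible to `dev.get_func`) and once below as a raw module
// whose `CUfunction` handles are used by the raw launch APIs
// (`cuLaunchKernelEx` / `cuLaunchCooperativeKernel`).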
if dev.get_func("htm_fused", "htm_fused_step").is_none() {
dev.load_ptx(
Ptx::from_src(PTX_HTM_FUSED),
"htm_fused",
&["htm_fused_step", "htm_fused_step_batched"],
)?;
}
let ptx = CString::new(PTX_HTM_FUSED).expect("PTX contains no interior nul bytes");
let module = unsafe { result::module::load_data(ptx.as_ptr().cast()) }?;
let function = unsafe {
result::module::get_function(module, CString::new("htm_fused_step").unwrap())
}?;
let function_batched = unsafe {
result::module::get_function(module, CString::new("htm_fused_step_batched").unwrap())
}?;
// Cluster size 16 on Hopper is "non-portable" (> 8 requires opt-in).
// Must set CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED=1 on
// every launched kernel function, otherwise cuLaunchKernelEx rejects
// the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
unsafe {
let attr = sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
// Ignore errors: older CUDA may lack the attribute, in which case only
// portable cluster sizes (<= 8) are allowed and the hard-coded cluster
// dim of 16 is rejected at launch with CUDA_ERROR_INVALID_CLUSTER_SIZE.
let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
let _ = sys::lib().cuFuncSetAttribute(function_batched, attr, 1);
}
// Probe SM count.
let sm_count = match dev.attribute(
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
) {
Ok(v) => v as u32,
Err(_) => 16u32,
};
// T1: Probe Hopper cluster launch capability.
let max_cluster_size = match dev.attribute(
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH,
) {
Ok(v) if v > 0 => {
// H200/sm_90a supports up to 16 blocks per cluster.
// There is no max-cluster-size device attribute in CUDA 12.4; hard-code
// the Hopper non-portable maximum of 16 blocks per cluster (the portable
// maximum is 8).
16u32
}
_ => 0u32,
};
eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
let cluster_info = ClusterInfo { max_cluster_size };
let cooperative_supported = matches!(
dev.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH),
Ok(v) if v > 0
);
let cooperative_grid_limit = if cooperative_supported {
let blocks_per_sm = unsafe {
result::occupancy::max_active_block_per_multiprocessor(function, 1024, 0)
}
.ok()
.map(|v| v.max(0) as u32)
.unwrap_or(0);
sm_count.saturating_mul(blocks_per_sm)
} else {
0
};
let launch_plan = plan_fused_launch(
sm_count,
cooperative_supported,
cooperative_grid_limit,
fused_grid_cap_override(),
)
.map_err(|msg| {
// Surface as a CUDA-ish error so callers can propagate.
eprintln!("[htm_rust] FATAL: {msg}");
DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
})?;
eprintln!(
"[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}",
launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit,
cluster_info.max_cluster_size,
);
Ok(Self {
dev,
raw_kernel: RawFusedKernel { module, function, function_batched },
inhibition_threshold,
cell_active_bits_a,
cell_active_bits_b,
cell_winner_bits_a,
cell_winner_bits_b,
step_scratch,
grid_dim_x: launch_plan.grid_dim_x,
block_dim_x: launch_plan.block_dim_x,
cooperative_grid_limit: launch_plan.cooperative_grid_limit,
iter_counter: 0,
cluster_info,
initial_threshold,
})
}
/// Reset fused state. Called at region.reset().
pub fn reset(&mut self) -> Result<(), DriverError> {
self.dev.memset_zeros(&mut self.cell_active_bits_a)?;
self.dev.memset_zeros(&mut self.cell_active_bits_b)?;
self.dev.memset_zeros(&mut self.cell_winner_bits_a)?;
self.dev.memset_zeros(&mut self.cell_winner_bits_b)?;
self.dev.memset_zeros(&mut self.step_scratch)?;
// Do NOT reset inhibition_threshold — it's learned state. A hard
// reset of TM state should NOT forget the sparsity calibration.
Ok(())
}
}
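// Construction sketch (hedged: the column/cell counts and the 0.5 starting
// threshold below are placeholders, not values this crate prescribes).
#[allow(dead_code)]
fn fused_state_sketch(dev: Arc<CudaDevice>) -> Result<FusedState, DriverError> {
    // 2048 columns x 16 cells/column = 32768 cells, divisible by 32 as the
    // bitset layout in `FusedState::new` asserts.
    FusedState::new(dev, 2048, 16, 0.5)
}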
/// Launch the fused megakernel. Processes all T timesteps in one kernel.
///
/// Uses `cuLaunchKernelEx` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION=(16,1,1)`
/// when the device supports cluster launch, otherwise falls back to a plain
/// `launch_kernel`. For single-region launches, grid_dim_x <= 16 ensures the
/// entire grid fits in one cluster.
#[allow(clippy::too_many_arguments)]
pub fn launch_fused(
sp: &mut SpatialPoolerGpu,
tm: &mut TemporalMemoryGpu,
fused: &mut FusedState,
inputs_flat: &CudaSlice<u8>,
cols_out: &mut CudaSlice<u8>,
anom_out: &mut CudaSlice<f32>,
t: usize,
input_bits: usize,
learn: bool,
) -> Result<(), DriverError> {
// Reset step_scratch before each launch (safe re-entry).
sp.dev_ref().memset_zeros(&mut fused.step_scratch)?;
fused.iter_counter = fused.iter_counter.wrapping_add(1);
let cfg = FusedConfig {
input_bits: input_bits as u32,
n_columns: sp.n_columns_accessor() as u32,
synapses_per_col: sp.synapses_per_col_accessor() as u32,
conn_thr: sp.conn_thr_accessor(),
sp_inc: sp.inc_accessor(),
sp_dec: sp.dec_accessor(),
sparsity_target: sp.sparsity_accessor(),
duty_alpha: 1.0f32 / sp.duty_period_accessor().max(1.0),
thr_adapt_rate: 0.001f32,
cells_per_column: tm.cells_per_column as u32,
n_cells: tm.n_cells as u32,
bits_words: tm.bits_words as u32,
max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
activation_threshold: tm.activation_threshold,
learning_threshold: tm.learning_threshold,
max_new_synapses: tm.max_new_synapse_count,
conn_thr_i16: tm.conn_thr_i16 as i32,
perm_inc_i16: tm.perm_inc_i16 as i32,
perm_dec_i16: tm.perm_dec_i16 as i32,
predicted_seg_dec_i16: tm.predicted_seg_dec_i16 as i32,
initial_perm_i16: tm.initial_perm_i16 as i32,
t: t as u32,
learn: if learn { 1 } else { 0 },
iter_seed: fused.iter_counter,
cooperative_grid_sync: 1,
};
let ptrs = FusedPtrs {
syn_bit: *sp.syn_bit_accessor().device_ptr(),
syn_perm: *sp.syn_perm_accessor().device_ptr(),
boost: *sp.boost_accessor().device_ptr(),
active_duty: *sp.active_duty_accessor().device_ptr(),
inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
seg_cell_id: *tm.seg_cell_id_accessor().device_ptr(),
seg_syn_count: *tm.seg_syn_count_accessor().device_ptr(),
syn_presyn: *tm.syn_presyn_accessor().device_ptr(),
tm_syn_perm: *tm.syn_perm_accessor().device_ptr(),
cell_seg_count: *tm.cell_seg_count_accessor().device_ptr(),
cell_active_a: *fused.cell_active_bits_a.device_ptr(),
cell_active_b: *fused.cell_active_bits_b.device_ptr(),
cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
inputs: *inputs_flat.device_ptr(),
cols_out: *cols_out.device_ptr(),
anom_out: *anom_out.device_ptr(),
barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
step_scratch: *fused.step_scratch.device_ptr(),
};
let grid_x = fused.grid_dim_x;
let block_x = fused.block_dim_x;
let cu_stream = *sp.dev_ref().cu_stream();
let use_cluster = fused.cluster_info.max_cluster_size > 0;
unsafe {
result::ctx::set_current(*sp.dev_ref().cu_primary_ctx())?;
let mut kernel_params: [*mut std::ffi::c_void; 2] = [
(&ptrs as *const FusedPtrs).cast_mut().cast(),
(&cfg as *const FusedConfig).cast_mut().cast(),
];
if use_cluster {
// T10: Hopper cluster launch with CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION.
// cluster_dim=(16,1,1) maps the entire single-region grid into one cluster.
let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
attr.value.clusterDim.x = 16;
attr.value.clusterDim.y = 1;
attr.value.clusterDim.z = 1;
let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
launch_cfg.gridDimX = grid_x;
launch_cfg.gridDimY = 1;
launch_cfg.gridDimZ = 1;
launch_cfg.blockDimX = block_x;
launch_cfg.blockDimY = 1;
launch_cfg.blockDimZ = 1;
launch_cfg.sharedMemBytes = 0;
launch_cfg.hStream = cu_stream;
launch_cfg.numAttrs = 1;
launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
let ret = sys::lib().cuLaunchKernelEx(
&launch_cfg as *const sys::CUlaunchConfig,
fused.raw_kernel.function,
kernel_params.as_mut_ptr(),
std::ptr::null_mut(),
);
if ret != sys::CUresult::CUDA_SUCCESS {
return Err(DriverError(ret));
}
} else {
// Pre-Hopper: cooperative kernel launch. The fused kernel uses
// grid.sync() for cross-block synchronization which REQUIRES
// cuLaunchCooperativeKernel (normal launch silently crashes on
// the first grid.sync() call).
let ret = sys::lib().cuLaunchCooperativeKernel(
fused.raw_kernel.function,
grid_x, 1, 1,
block_x, 1, 1,
0, // sharedMemBytes
cu_stream,
kernel_params.as_mut_ptr(),
);
if ret != sys::CUresult::CUDA_SUCCESS {
return Err(DriverError(ret));
}
}
}
Ok(())
}
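// Call-shape sketch for the single-region path (not end-to-end runnable here:
// constructing the SP/TM and allocating the device buffers happens elsewhere).
// The buffer-length expectations in the comments are assumptions of this
// sketch, not something `launch_fused` itself checks.
#[allow(dead_code, clippy::too_many_arguments)]
fn launch_fused_sketch(
    sp: &mut SpatialPoolerGpu,
    tm: &mut TemporalMemoryGpu,
    fused: &mut FusedState,
    inputs_flat: &CudaSlice<u8>,   // assumed t * input_bits encoded input bits
    cols_out: &mut CudaSlice<u8>,  // assumed t * n_columns active-column flags
    anom_out: &mut CudaSlice<f32>, // assumed one anomaly score per timestep
    t: usize,
    input_bits: usize,
) -> Result<(), DriverError> {
    // One launch covers all `t` timesteps, with learning enabled.
    launch_fused(sp, tm, fused, inputs_flat, cols_out, anom_out, t, input_bits, true)
}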
/// Single batched launch for B regions. Uses the same kernel body; each block
/// reads its region's `FusedPtrs` from a device-side array indexed by
/// `blockIdx.y`. All regions share the same config (same
/// input_bits/n_columns/etc.), so we pass one `FusedConfig`.
///
/// This sidesteps CUDA's cooperative-kernel device-level serialization:
/// multiple cooperative launches are serialized regardless of stream, but a
/// single launch with grid.y = B processes all regions in one invocation,
/// roughly a B-fold speedup over B sequential launches. On Hopper the launch
/// is non-cooperative with a cluster attribute; pre-Hopper it falls back to
/// one cooperative launch.
///
/// Low-level raw-pointer entry, called by the PyO3 binding which holds the
/// mutable borrows. Safety: each `*mut HTMRegionGpu` must point to a live,
/// uniquely borrowed region, and all regions must be distinct.
#[allow(clippy::too_many_arguments)]
pub(super) fn launch_fused_batched_raw(
region_ptrs: &[*mut super::HTMRegionGpu],
inputs_per_region: &[u64],
cols_per_region: &[u64],
anom_per_region: &[u64],
t: usize,
input_bits: usize,
learn: bool,
) -> Result<(), DriverError> {
let b = region_ptrs.len();
assert_eq!(inputs_per_region.len(), b);
assert_eq!(cols_per_region.len(), b);
assert_eq!(anom_per_region.len(), b);
assert!(b >= 1, "need at least one region");
// Reset per-region step_scratch before each launch.
for &rp in region_ptrs.iter() {
let r = unsafe { &mut *rp };
let dev = r.sp_gpu.dev_ref().clone();
dev.memset_zeros(&mut r.fused_state.step_scratch)?;
r.fused_state.iter_counter = r.fused_state.iter_counter.wrapping_add(1);
}
// Shared config — all regions use identical sp/tm parameters.
let (grid_x, block_x, function_batched, cu_stream, cu_ctx) = {
let r0 = unsafe { &*region_ptrs[0] };
(
r0.fused_state.grid_dim_x,
r0.fused_state.block_dim_x,
r0.fused_state.raw_kernel.function_batched,
*r0.sp_gpu.dev_ref().cu_stream(),
*r0.sp_gpu.dev_ref().cu_primary_ctx(),
)
};
let cfg = {
let r = unsafe { &*region_ptrs[0] };
FusedConfig {
input_bits: input_bits as u32,
n_columns: r.sp_gpu.n_columns_accessor() as u32,
synapses_per_col: r.sp_gpu.synapses_per_col_accessor() as u32,
conn_thr: r.sp_gpu.conn_thr_accessor(),
sp_inc: r.sp_gpu.inc_accessor(),
sp_dec: r.sp_gpu.dec_accessor(),
sparsity_target: r.sp_gpu.sparsity_accessor(),
duty_alpha: 1.0f32 / r.sp_gpu.duty_period_accessor().max(1.0),
thr_adapt_rate: 0.001f32,
cells_per_column: r.tm_gpu.cells_per_column as u32,
n_cells: r.tm_gpu.n_cells as u32,
bits_words: r.tm_gpu.bits_words as u32,
max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
activation_threshold: r.tm_gpu.activation_threshold,
learning_threshold: r.tm_gpu.learning_threshold,
max_new_synapses: r.tm_gpu.max_new_synapse_count,
conn_thr_i16: r.tm_gpu.conn_thr_i16 as i32,
perm_inc_i16: r.tm_gpu.perm_inc_i16 as i32,
perm_dec_i16: r.tm_gpu.perm_dec_i16 as i32,
predicted_seg_dec_i16: r.tm_gpu.predicted_seg_dec_i16 as i32,
initial_perm_i16: r.tm_gpu.initial_perm_i16 as i32,
t: t as u32,
learn: if learn { 1 } else { 0 },
iter_seed: r.fused_state.iter_counter,
cooperative_grid_sync: 1,
}
};
// Build B FusedPtrs per-region.
let ptrs_vec: Vec<FusedPtrs> = (0..b)
.map(|i| {
let r = unsafe { &*region_ptrs[i] };
FusedPtrs {
syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
boost: *r.sp_gpu.boost_accessor().device_ptr(),
active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
inhibition_threshold: *r.fused_state.inhibition_threshold.device_ptr(),
seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
cell_active_a: *r.fused_state.cell_active_bits_a.device_ptr(),
cell_active_b: *r.fused_state.cell_active_bits_b.device_ptr(),
cell_winner_a: *r.fused_state.cell_winner_bits_a.device_ptr(),
cell_winner_b: *r.fused_state.cell_winner_bits_b.device_ptr(),
inputs: inputs_per_region[i],
cols_out: cols_per_region[i],
anom_out: anom_per_region[i],
barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
step_scratch: *r.fused_state.step_scratch.device_ptr(),
}
})
.collect();
// Upload FusedPtrs array to device (B * sizeof(FusedPtrs) bytes).
// FusedPtrs is repr(C) + DeviceRepr so htod_sync_copy handles it.
let dev = unsafe { &*region_ptrs[0] }.sp_gpu.dev_ref().clone();
let ptrs_dev: CudaSlice<FusedPtrs> = dev.htod_sync_copy(&ptrs_vec)?;
let ptrs_dev_ptr: u64 = *ptrs_dev.device_ptr();
// T10: Cluster launch for batched regions.
// Grid = (grid_x, B, 1) with cluster_dim=(16,1,1): each region (Y slice)
// occupies exactly one cluster of 16 blocks. For example, with B = 8 all
// eight clusters fit concurrently on an H200's 132 SMs (8 x 16 = 128 blocks).
let use_cluster = {
let r0 = unsafe { &*region_ptrs[0] };
r0.fused_state.cluster_info.max_cluster_size > 0
};
unsafe {
result::ctx::set_current(cu_ctx)?;
let mut kernel_params: [*mut std::ffi::c_void; 2] = [
(&ptrs_dev_ptr as *const u64).cast_mut().cast(),
(&cfg as *const FusedConfig).cast_mut().cast(),
];
if use_cluster {
let mut attr: sys::CUlaunchAttribute = std::mem::zeroed();
attr.id = sys::CUlaunchAttributeID::CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
attr.value.clusterDim.x = 16;
attr.value.clusterDim.y = 1;
attr.value.clusterDim.z = 1;
let mut launch_cfg: sys::CUlaunchConfig = std::mem::zeroed();
launch_cfg.gridDimX = grid_x;
launch_cfg.gridDimY = b as u32;
launch_cfg.gridDimZ = 1;
launch_cfg.blockDimX = block_x;
launch_cfg.blockDimY = 1;
launch_cfg.blockDimZ = 1;
launch_cfg.sharedMemBytes = 0;
launch_cfg.hStream = cu_stream;
launch_cfg.numAttrs = 1;
launch_cfg.attrs = &mut attr as *mut sys::CUlaunchAttribute;
let ret = sys::lib().cuLaunchKernelEx(
&launch_cfg as *const sys::CUlaunchConfig,
function_batched,
kernel_params.as_mut_ptr(),
std::ptr::null_mut(),
);
if ret != sys::CUresult::CUDA_SUCCESS {
return Err(DriverError(ret));
}
} else {
// Pre-Hopper: cooperative kernel launch (grid.sync() requires it).
let ret = sys::lib().cuLaunchCooperativeKernel(
function_batched,
grid_x, b as u32, 1,
block_x, 1, 1,
0, // sharedMemBytes
cu_stream,
kernel_params.as_mut_ptr(),
);
if ret != sys::CUresult::CUDA_SUCCESS {
return Err(DriverError(ret));
}
}
}
Ok(())
}
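// Caller-side sketch for the batched path (illustrative; assumes the binding
// layer owns the regions and the per-region device buffers, and the variable
// names here are hypothetical; only the pointer-extraction calls are real).
#[allow(dead_code)]
fn launch_fused_batched_sketch(
    regions: &mut [super::HTMRegionGpu],
    inputs: &[CudaSlice<u8>],
    cols: &[CudaSlice<u8>],
    anoms: &[CudaSlice<f32>],
    t: usize,
    input_bits: usize,
) -> Result<(), DriverError> {
    // Each region is reached through a unique `&mut`, so the safety contract of
    // `launch_fused_batched_raw` (live, distinct, uniquely borrowed regions) holds.
    let region_ptrs: Vec<*mut super::HTMRegionGpu> =
        regions.iter_mut().map(|r| r as *mut super::HTMRegionGpu).collect();
    let inputs_ptrs: Vec<u64> = inputs.iter().map(|b| *b.device_ptr()).collect();
    let cols_ptrs: Vec<u64> = cols.iter().map(|b| *b.device_ptr()).collect();
    let anom_ptrs: Vec<u64> = anoms.iter().map(|b| *b.device_ptr()).collect();
    launch_fused_batched_raw(&region_ptrs, &inputs_ptrs, &cols_ptrs, &anom_ptrs, t, input_bits, true)
}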