// Fused mean-reduction + boost-update kernel.
//
// Inputs:
//   active_duty[n]  (f32)
//   boost_strength  (f32)
//
// Output:
//   boost[n] (f32) = expf(-boost_strength * (active_duty[c] - mean)),
//   where mean = sum(active_duty[0..n)) / n.
//
// Launch: intended for a single block (e.g. 1024 threads; at n=2048 each
// thread handles 2 elements). Any blockDim.x in [1, 1024] works, including
// sizes that are not a multiple of 32. Launching more than one block is
// correct but redundant: every block reduces the whole input and writes the
// same values to every boost[i].
//
// Shared memory: only static shared memory is used (one partial sum per warp
// plus the mean), so no dynamic shared-memory launch argument is required;
// passing a non-zero value is harmless.
//
// Preconditions: active_duty and boost must not alias (both are __restrict__).
// n == 0 is a no-op.
extern "C" __global__ void sp_boost_from_duty(
    const float * __restrict__ active_duty,  // (n,) input
    float       * __restrict__ boost,        // (n,) output, fully overwritten
    float                      boost_strength,
    unsigned int               n
) {
    // n is uniform across the block, so returning before any barrier is safe.
    // Avoids computing a 0/0 NaN mean.
    if (n == 0) return;

    // At most 1024 threads per block => at most 32 warps => 32 partial sums.
    __shared__ float warp_sums[32];
    __shared__ float mean_s;

    const unsigned int tid    = threadIdx.x;
    const unsigned int bsz    = blockDim.x;
    const unsigned int lane   = tid & 31u;
    const unsigned int warp   = tid >> 5;
    const unsigned int nwarps = (bsz + 31u) / 32u;

    // Number of lanes that actually exist in this thread's warp. Only the last
    // warp can be partial (when bsz is not a multiple of 32); shuffling with
    // the full 0xffffffff mask there would name non-existent lanes (undefined).
    const unsigned int remaining  = bsz - warp * 32u;
    const unsigned int warp_lanes = (remaining < 32u) ? remaining : 32u;
    const unsigned int warp_mask  =
        (warp_lanes == 32u) ? 0xffffffffu : ((1u << warp_lanes) - 1u);

    // Phase 1a: each thread accumulates a strided slice of active_duty.
    float local_sum = 0.0f;
    for (unsigned int i = tid; i < n; i += bsz) {
        local_sum += active_duty[i];
    }

    // Phase 1b: warp tree reduction into lane 0. Values shuffled from lanes at
    // or beyond warp_lanes are ignored, i.e. missing lanes contribute zero.
    for (unsigned int off = 16u; off > 0u; off >>= 1) {
        const float other = __shfl_down_sync(warp_mask, local_sum, off);
        if (lane + off < warp_lanes) local_sum += other;
    }
    if (lane == 0u) warp_sums[warp] = local_sum;
    __syncthreads();

    // Phase 1c: warp 0 reduces the per-warp sums and publishes the mean.
    // Warp 0 is partial only when bsz < 32, which the same mask/guard handles.
    if (warp == 0u) {
        float v = (lane < nwarps) ? warp_sums[lane] : 0.0f;
        for (unsigned int off = 16u; off > 0u; off >>= 1) {
            const float other = __shfl_down_sync(warp_mask, v, off);
            if (lane + off < warp_lanes) v += other;
        }
        if (lane == 0u) mean_s = v / (float)n;
    }
    __syncthreads();

    // Phase 2: boost[c] = expf(-strength * (active_duty[c] - mean)).
    const float mean = mean_s;
    for (unsigned int i = tid; i < n; i += bsz) {
        const float d = active_duty[i] - mean;
        boost[i] = expf(-boost_strength * d);
    }
}