Spaces:
Runtime error
Runtime error
// Fused mean-reduction + boost-update kernel.
//
// Inputs:
//   active_duty[n]  (f32)
//   boost_strength  (f32)
//
// Output:
//   boost[c] (f32) = expf(-boost_strength * (active_duty[c] - mean)),  c in [0, n)
//   where mean = (sum of active_duty[0..n)) / n.
//
// Launch: a single block (intended: 1024 threads; any blockDim.x in [1, 1024]
// is handled, including sizes that are not a multiple of 32). Each thread
// handles ceil(n / blockDim.x) elements (2 at n = 2048 with 1024 threads).
// The reduction scratch is a static 32-float __shared__ array, so no dynamic
// shared memory is required; passing some anyway is harmless.
// If launched with gridDim.x > 1, every block redundantly computes the same
// mean and writes the same values; extra blocks only waste work.
//
// `boost` is a separate output buffer that is fully overwritten; it must not
// alias `active_duty` (both pointers are __restrict__).

// Sums `v` over the first `lanes` lanes of the calling warp; lane 0 receives
// the total. `lanes` is the number of threads that exist in this warp (less
// than 32 only for the trailing warp of a block whose size is not a multiple
// of 32), and every one of those lanes must call this function.
static __device__ __forceinline__ float sp_warp_sum(float v,
                                                    unsigned int lane,
                                                    unsigned int lanes) {
    // Only name lanes that exist: a full mask on a partial warp is undefined.
    const unsigned int mask =
        (lanes >= 32u) ? 0xffffffffu : ((1u << lanes) - 1u);
    for (unsigned int off = 16u; off > 0u; off >>= 1) {
        const float other = __shfl_down_sync(mask, v, off);
        // Values fetched from lanes past `lanes` are undefined; skip them.
        // For a full warp this leaves lane 0's summation tree unchanged.
        if (lane + off < lanes) v += other;
    }
    return v;
}

extern "C" __global__
void sp_boost_from_duty(
    const float * __restrict__ active_duty,  // (n,) read-only input
    float * __restrict__ boost,              // (n,) output, fully overwritten
    float boost_strength,
    unsigned int n
) {
    // One partial sum per warp; a block holds at most 1024 threads = 32 warps.
    __shared__ float warp_sums[32];
    __shared__ float mean_s;

    const unsigned int tid = threadIdx.x;
    const unsigned int bsz = blockDim.x;
    const unsigned int lane = tid & 31u;
    const unsigned int warp = tid >> 5;
    const unsigned int nwarps = (bsz + 31u) >> 5;
    // Threads present in this warp (the last warp is partial if bsz % 32 != 0).
    const unsigned int rem = bsz - (warp << 5);
    const unsigned int warp_lanes = (rem < 32u) ? rem : 32u;

    // Phase 1a: each thread accumulates elements tid, tid + bsz, tid + 2*bsz, ...
    float local_sum = 0.0f;
    for (unsigned int i = tid; i < n; i += bsz) {
        local_sum += active_duty[i];
    }

    // Phase 1b: reduce within each warp; lane 0 publishes the warp's total.
    local_sum = sp_warp_sum(local_sum, lane, warp_lanes);
    if (lane == 0u) warp_sums[warp] = local_sum;
    __syncthreads();

    // Phase 1c: warp 0 folds the per-warp totals into the mean.
    if (warp == 0u) {
        float v = (lane < nwarps) ? warp_sums[lane] : 0.0f;
        v = sp_warp_sum(v, lane, warp_lanes);
        if (lane == 0u) {
            // With n == 0 nothing is written below; still avoid 0/0 = NaN.
            mean_s = (n > 0u) ? v / (float)n : 0.0f;
        }
    }
    __syncthreads();

    // Phase 2: boost[c] = expf(-strength * (active_duty[c] - mean)).
    const float mean = mean_s;
    for (unsigned int i = tid; i < n; i += bsz) {
        boost[i] = expf(-boost_strength * (active_duty[i] - mean));
    }
}