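// NOTE: the following three lines are hosting-page residue (uploader, commit
// message, commit id), kept as comments so the file remains valid CUDA.
// icarus112's picture
// Upload folder using huggingface_hub
// 1c59946 verified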
// Fused mean-reduction + boost-update kernel.
//
// Inputs:
//   active_duty[n]  (f32)  values to average; read-only
//   boost_strength  (f32)  scale applied inside the exponent
//   n               (u32)  number of elements in active_duty and boost
//
// Output:
//   boost[i] (f32) = expf(-boost_strength * (active_duty[i] - mean))
//   where mean = (active_duty[0] + ... + active_duty[n-1]) / n.
//
// Launch contract:
//   * A single 1-D block does all the work (e.g. 1024 threads; at n = 2048
//     each thread then handles 2 elements). Blocks with blockIdx.x != 0 exit
//     immediately, so an oversized grid only wastes a launch slot.
//   * Dynamic shared memory: at least ceil(blockDim.x / 32) floats, one
//     partial sum per warp.
//   * blockDim.x does not have to be a multiple of 32; lanes past the end of
//     the block are excluded from the warp shuffles.
//   * n == 0 is a no-op (nothing is read or written).
extern "C" __global__
void sp_boost_from_duty(
    const float * __restrict__ active_duty,  // (n,) input
    float       * __restrict__ boost,        // (n,) output, fully overwritten
    float                      boost_strength,
    unsigned int               n
) {
    extern __shared__ float warp_sums[];  // one slot per warp of the block
    __shared__ float mean_s;              // broadcast slot for the final mean

    // Both conditions are uniform across the block, so returning here cannot
    // leave part of the block waiting at a later __syncthreads().
    if (n == 0u || blockIdx.x != 0u) {
        return;
    }

    const unsigned int tid  = threadIdx.x;
    const unsigned int bsz  = blockDim.x;
    const unsigned int lane = tid & 31u;
    const unsigned int warp = tid >> 5;

    // Number of lanes of this warp that actually exist in the block. Only the
    // last warp can be partial, and only when bsz is not a multiple of 32.
    const unsigned int remaining = bsz - (warp << 5);
    const unsigned int live      = remaining < 32u ? remaining : 32u;
    // Shuffle mask naming exactly the live lanes (warp-uniform).
    const unsigned int mask = (live == 32u) ? 0xffffffffu : ((1u << live) - 1u);

    // Phase 1a: each thread accumulates a block-strided slice of active_duty.
    float local_sum = 0.0f;
    for (unsigned int i = tid; i < n; i += bsz) {
        local_sum += active_duty[i];
    }

    // Phase 1b: reduce within the warp. A value shuffled from a lane outside
    // the block is undefined, so it is dropped rather than added.
    for (unsigned int off = 16u; off > 0u; off >>= 1) {
        const float other = __shfl_down_sync(mask, local_sum, off);
        if (lane + off < live) {
            local_sum += other;
        }
    }
    if (lane == 0u) {
        warp_sums[warp] = local_sum;
    }
    __syncthreads();

    // Phase 1c: warp 0 folds the per-warp partial sums and publishes the mean.
    if (warp == 0u) {
        const unsigned int nwarps = (bsz + 31u) >> 5;
        float v = (lane < nwarps) ? warp_sums[lane] : 0.0f;
        for (unsigned int off = 16u; off > 0u; off >>= 1) {
            const float other = __shfl_down_sync(mask, v, off);
            if (lane + off < live) {
                v += other;
            }
        }
        if (lane == 0u) {
            mean_s = v / (float)n;
        }
    }
    __syncthreads();

    // Phase 2: boost[i] = expf(-boost_strength * (active_duty[i] - mean)).
    const float mean = mean_s;
    for (unsigned int i = tid; i < n; i += bsz) {
        const float d = active_duty[i] - mean;
        boost[i] = expf(-boost_strength * d);
    }
}