File size: 1,795 Bytes
1c59946
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// Fused mean-reduction + boost-update kernel.
//
// Inputs:
//   active_duty[n] (f32)
//   boost_strength (f32)
//
// Output:
//   boost[c] (f32) = expf(-boost_strength * (active_duty[c] - mean)),
//   for c in [0, n), where mean = sum(active_duty[0..n)) / n.
//
// Launch: single block (1024 threads). At n=2048 each thread handles 2
// elements. Any blockDim.x in [1, 1024] is supported, including sizes that
// are not a multiple of 32: a partial trailing warp reduces with an explicit
// lane mask instead of the full 0xffffffff mask.
// Shared memory: per-warp partial sums live in a static __shared__ array,
// so no dynamic shared memory is required. Dynamic bytes passed at launch
// are ignored.
// With gridDim.x > 1, every block redundantly recomputes the full mean and
// writes identical values to every element.
// n == 0 is a no-op (nothing is written).
// active_duty and boost are __restrict__, so they must not alias.

// Sums `v` over lanes [0, lanes) of the calling warp (1 <= lanes <= 32) and
// returns the total in lane 0. Only lanes [0, lanes) may call this. Results
// in lanes other than 0 are partial and must be ignored.
static __device__ __forceinline__
float sp_warp_sum(float v, unsigned int lane, unsigned int lanes)
{
    // Name exactly the lanes that exist. A full mask on a partial warp is UB.
    const unsigned int mask = (lanes >= 32u) ? 0xffffffffu : ((1u << lanes) - 1u);
    for (unsigned int off = 16u; off > 0u; off >>= 1) {
        const float other = __shfl_down_sync(mask, v, off);
        // The source lane (lane + off) may be absent. Its value is then
        // undefined, so drop it rather than accumulate it.
        if (lane + off < lanes) v += other;
    }
    return v;
}

extern "C" __global__
void sp_boost_from_duty(
    const float * __restrict__ active_duty,  // (n,)
    float       * __restrict__ boost,        // (n,) output, overwritten
    float         boost_strength,
    unsigned int  n
) {
    __shared__ float warp_sums[32];  // one slot per warp; blockDim.x <= 1024
    __shared__ float mean_s;

    const unsigned int tid    = threadIdx.x;
    const unsigned int bsz    = blockDim.x;
    const unsigned int lane   = tid & 31u;
    const unsigned int warp   = tid >> 5;
    const unsigned int nwarps = (bsz + 31u) / 32u;
    // Lanes that actually exist in this thread's warp. The last warp may be partial.
    const unsigned int warp_lanes = min(32u, bsz - warp * 32u);

    // Phase 1a: each thread sums elements tid, tid + bsz, tid + 2*bsz, ...
    float local_sum = 0.0f;
    for (unsigned int i = tid; i < n; i += bsz) {
        local_sum += active_duty[i];
    }

    // Phase 1b: reduce within each warp; lane 0 publishes the warp total.
    local_sum = sp_warp_sum(local_sum, lane, warp_lanes);
    if (lane == 0u) warp_sums[warp] = local_sum;
    __syncthreads();

    // Phase 1c: warp 0 reduces the per-warp totals and publishes the mean.
    if (warp == 0u) {
        float v = (lane < nwarps) ? warp_sums[lane] : 0.0f;
        v = sp_warp_sum(v, lane, warp_lanes);
        if (lane == 0u) {
            // Guard n == 0. Phase 2 writes nothing then, but avoid 0/0.
            mean_s = (n > 0u) ? v / (float)n : 0.0f;
        }
    }
    __syncthreads();

    // Phase 2: boost[c] = expf(-strength * (active_duty[c] - mean)).
    const float mean = mean_s;
    for (unsigned int i = tid; i < n; i += bsz) {
        boost[i] = expf(-boost_strength * (active_duty[i] - mean));
    }
}