File size: 4,302 Bytes
1c59946
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Top-K column selection.
//
// Inputs:
//   scores[n_columns] : f32 score
//   n_columns         : number of columns
//   k                 : number of columns to select
// Output:
//   active_out[n_columns] : u8 0/1
//
// Guarantee: exactly min(k, n_columns) ones are written, provided no score is
// NaN. NaN scores are never selected (every comparison involving NaN is
// false), so if fewer than k non-NaN scores exist, fewer ones are written.
//
// Tie-breaking: when scores are equal, the LOWER column index wins (intended
// to match the CPU reference `select_nth_unstable_by` with a secondary index
// comparator). -INFINITY is an ordinary score and takes part in ties like any
// other value.
//
// Algorithm: single-block parallel selection. Each of min(k, n_columns)
// rounds performs a block-wide argmax over the columns not yet selected:
//   1. each thread scans a strided slice of the shared score copy,
//   2. each warp reduces its lanes' (score, index) pairs with shuffles,
//   3. warp 0 combines the per-warp results, and its lane 0 marks the winner
//      in active_out and removes it from further consideration.
// Each round costs ceil(n_columns / blockDim.x) shared-memory reads per thread
// plus two block barriers.
//
// Launch contract:
//   * gridDim.x == 1 (any additional blocks return immediately).
//   * blockDim.x in [1, 1024]; partial final warps are handled, but a
//     multiple of 32 avoids idle lanes.
//   * Dynamic shared memory: n_columns * sizeof(float) bytes.

extern "C" __global__
void sp_topk_select(
    const float * __restrict__ scores,    // (n_columns,)
    unsigned int  n_columns,
    unsigned int  k,
    unsigned char * __restrict__ active_out  // (n_columns,)
) {
    // Only block 0 does the work; extra blocks would race on active_out.
    // The condition is uniform within a block, so no barrier is skipped
    // divergently.
    if (blockIdx.x != 0) {
        return;
    }

    // Dynamic shared memory: work[0..n_columns) holds the scores still in
    // contention. Selected columns are overwritten with `taken`.
    extern __shared__ float smem[];
    float * work = smem;

    // Per-warp argmax results, combined by warp 0 each round (at most 32
    // warps per block).
    __shared__ float warp_s[32];
    __shared__ int   warp_i[32];

    // Quiet NaN marking an already-selected column. NaN fails every `>` and
    // `==` test below, so a taken column can never win again. The previous
    // -INFINITY marker could tie with a genuine -inf score (and with the
    // scan's -INFINITY start value), letting the lowest-index taken column be
    // re-selected and leaving fewer than k columns marked.
    const float taken = __int_as_float(0x7fc00000);

    const unsigned int tid    = threadIdx.x;
    const unsigned int bsz    = blockDim.x;
    const unsigned int lane   = tid & 31u;
    const unsigned int warp   = tid >> 5;
    const unsigned int nwarps = (bsz + 31u) / 32u;

    // Lanes that actually exist in this thread's warp (< 32 only for a
    // partial last warp) and the matching participation mask. A full-warp
    // mask on a partial warp would be undefined behaviour.
    const unsigned int warp_lanes = min(32u, bsz - (warp << 5));
    const unsigned int warp_mask  =
        (warp_lanes == 32u) ? 0xffffffffu : ((1u << warp_lanes) - 1u);

    // Sentinel index larger than any real column index ("no winner").
    const int no_idx = (int)n_columns;

    // Load scores into shared memory and clear the output mask.
    for (unsigned int c = tid; c < n_columns; c += bsz) {
        work[c] = scores[c];
        active_out[c] = 0;
    }
    __syncthreads();

    // Once every column is selected, further rounds could not select anything.
    const unsigned int rounds = min(k, n_columns);

    for (unsigned int round = 0; round < rounds; ++round) {
        // 1. Thread-local argmax: max score, lowest index on ties.
        float best_s = -INFINITY;
        int   best_i = no_idx;
        for (unsigned int c = tid; c < n_columns; c += bsz) {
            const float s = work[c];
            if (s > best_s || (s == best_s && (int)c < best_i)) {
                best_s = s;
                best_i = (int)c;
            }
        }

        // 2. Warp-level reduction of (score, index) pairs. A source lane
        //    outside this warp's real lanes yields an undefined value, so it
        //    is ignored. For lane + off >= 32 the shuffle returns our own pair,
        //    which never compares as better.
        for (unsigned int off = 16u; off > 0u; off >>= 1) {
            const float os = __shfl_down_sync(warp_mask, best_s, off);
            const int   oi = __shfl_down_sync(warp_mask, best_i, off);
            if (lane + off < warp_lanes &&
                (os > best_s || (os == best_s && oi < best_i))) {
                best_s = os;
                best_i = oi;
            }
        }
        if (lane == 0u) {
            warp_s[warp] = best_s;
            warp_i[warp] = best_i;
        }
        __syncthreads();

        // 3. Warp 0 combines the per-warp winners. Its lane 0 marks the
        //    global winner directly, so no shared hand-off or extra barrier
        //    is needed.
        if (warp == 0u) {
            float ws = (lane < nwarps) ? warp_s[lane] : -INFINITY;
            int   wi = (lane < nwarps) ? warp_i[lane] : no_idx;
            for (unsigned int off = 16u; off > 0u; off >>= 1) {
                const float os = __shfl_down_sync(warp_mask, ws, off);
                const int   oi = __shfl_down_sync(warp_mask, wi, off);
                if (lane + off < warp_lanes &&
                    (os > ws || (os == ws && oi < wi))) {
                    ws = os;
                    wi = oi;
                }
            }
            // wi == no_idx only when every remaining score is NaN.
            if (lane == 0u && wi < no_idx) {
                active_out[wi] = 1;
                work[wi] = taken;
            }
        }
        // Makes the `taken` mark visible to the next round's scan. It also
        // keeps the next round's warp_s/warp_i writes after warp 0's reads.
        __syncthreads();
    }
}