// Top-K column selection.
//
// Inputs:
//   boosted[n_columns]     : f32 score (kernel parameter `scores`)
// Output:
//   active_mask[n_columns] : u8 0/1, exactly k ones (kernel parameter
//                            `active_out`)
//
// Tie-breaking: when scores are equal, the LOWER column index wins (matches
// CPU reference `select_nth_unstable_by` with secondary index comparator).
//
// Strategy: a single-block implementation. n_columns is typically 2048, which
// fits comfortably in shared memory.
//
// Alternative considered (NOT what the kernel below implements): a bitonic /
// radix-select top-k over the (score, -index) key. At k≈41 of n=2048 the
// simplest correct form of that approach is a thresholding pass:
//
//   1. Radix-like bucket pass to find the k-th largest score.
//   2. Mark winners = strictly-greater-than-threshold AND ties until count
//      hits k.
//
// For strict index-ordered tie-break that approach materialises a 64-bit key:
//   key = ((u64)float_to_sortable_u32(score) << 32) | (0xffffffff - index)
// Larger key = (higher score) OR (same score, smaller index).
//
// It then finds the k-th largest 64-bit key via radix-select and marks all
// columns whose key >= threshold. This is O(n_cols * log k) and well under
// 100 μs for n=2048, k=41 on sm_86.
//
// What is actually implemented: for simplicity and correctness this kernel
// uses a single-block parallel selection-sort variant (find max → mark →
// remove from the candidate set → repeat, k iterations). At k=41 this is 41
// passes over 2048 scores = 2048*41 ≈ 84K comparisons, trivially fast.
// Ordering used everywhere in the selection: a candidate outranks the current
// best when its score is strictly higher, or the scores are equal and its
// column index is lower. NaN scores never outrank anything because both
// comparisons are false for NaN.
static __device__ __forceinline__ bool sp_topk_outranks(
    float cand_s, int cand_i, float best_s, int best_i) {
    return cand_s > best_s || (cand_s == best_s && cand_i < best_i);
}

// Shuffle-based argmax over the `lanes` existing lanes of one warp; on return
// lane 0 holds the winning (score, index) pair. `mask` must name exactly those
// lanes. A value shuffled in from a lane >= `lanes` is undefined, so it is
// discarded by the `lane + off < lanes` guard.
static __device__ __forceinline__ void sp_topk_warp_argmax(
    float &best_s, int &best_i,
    unsigned int mask, unsigned int lane, unsigned int lanes) {
    for (unsigned int off = 16u; off > 0u; off >>= 1) {
        const float other_s = __shfl_down_sync(mask, best_s, off);
        const int   other_i = __shfl_down_sync(mask, best_i, off);
        if (lane + off < lanes &&
            sp_topk_outranks(other_s, other_i, best_s, best_i)) {
            best_s = other_s;
            best_i = other_i;
        }
    }
}

// Marks the k highest-scoring columns: active_out[i] = 1 for the winners and
// 0 for every other column. Equal scores are resolved in favour of the lower
// column index. Selection is iterative: each pass finds the block-wide argmax
// of the remaining candidates, flags it and retires it.
//
// Parameters:
//   scores      (n_columns,) input scores; read once into shared memory.
//   n_columns   number of columns.
//   k           number of winners requested; values above n_columns are
//               clamped.
//   active_out  (n_columns,) output flags; fully overwritten.
//
// Launch requirements:
//   * One block only. The kernel never reads blockIdx, so extra blocks would
//     redo the same work and race on active_out.
//   * 1-D block (only threadIdx.x / blockDim.x are used). Any blockDim.x is
//     accepted, including sizes that are not a multiple of 32.
//   * Dynamic shared memory: n_columns * sizeof(float) bytes.
//
// Edge cases:
//   * -inf is an ordinary (lowest) score and can be selected once per column.
//   * NaN scores are never selected. If fewer than k non-NaN scores exist,
//     every non-NaN column is flagged and the kernel stops early, so fewer
//     than k ones are written.
extern "C" __global__
void sp_topk_select(
    const float * __restrict__ scores,        // (n_columns,)
    unsigned int n_columns,
    unsigned int k,
    unsigned char * __restrict__ active_out   // (n_columns,)
) {
    // Dynamic shared memory holds only the working copy of the scores.
    extern __shared__ float smem[];
    float * work = smem;

    // Per-warp partial results (a block has at most 32 warps) and the
    // block-wide winner of the current pass.
    __shared__ float warp_s[32];
    __shared__ int   warp_i[32];
    __shared__ int   winner_idx;

    const unsigned int tid    = threadIdx.x;
    const unsigned int bsz    = blockDim.x;
    const unsigned int lane   = tid & 31u;
    const unsigned int warp   = tid >> 5;
    const unsigned int nwarps = (bsz + 31u) >> 5;

    // Number of lanes that really exist in this thread's warp (the last warp
    // is partial when blockDim.x is not a multiple of 32) and the matching
    // participation mask for the *_sync shuffles.
    const unsigned int tail       = bsz - (warp << 5);
    const unsigned int warp_lanes = tail < 32u ? tail : 32u;
    const unsigned int warp_mask  =
        (warp_lanes == 32u) ? 0xffffffffu : ((1u << warp_lanes) - 1u);

    // Retired entries are overwritten with a quiet NaN. Unlike -inf, NaN can
    // never satisfy `>` or `==`, so a retired slot cannot be picked again and
    // cannot be confused with a genuine -inf score.
    const float retired  = __int_as_float(0x7fc00000);
    const int   sentinel = (int)n_columns;   // "no candidate" index

    // Stage the scores in shared memory and clear the output flags.
    for (unsigned int i = tid; i < n_columns; i += bsz) {
        work[i] = scores[i];
        active_out[i] = 0;
    }
    __syncthreads();

    const unsigned int picks = k < n_columns ? k : n_columns;
    for (unsigned int iter = 0; iter < picks; ++iter) {
        // Per-thread argmax over a strided slice of the candidates.
        float best_s = -INFINITY;
        int   best_i = sentinel;
        for (unsigned int i = tid; i < n_columns; i += bsz) {
            const float s = work[i];
            if (sp_topk_outranks(s, (int)i, best_s, best_i)) {
                best_s = s;
                best_i = (int)i;
            }
        }

        // Warp-level argmax; lane 0 publishes its warp's result.
        sp_topk_warp_argmax(best_s, best_i, warp_mask, lane, warp_lanes);
        if (lane == 0) {
            warp_s[warp] = best_s;
            warp_i[warp] = best_i;
        }
        __syncthreads();

        // Warp 0 combines the per-warp results. Lanes without a warp to
        // represent contribute the neutral (-inf, sentinel) pair.
        if (warp == 0) {
            float s = (lane < nwarps) ? warp_s[lane] : -INFINITY;
            int   i = (lane < nwarps) ? warp_i[lane] : sentinel;
            sp_topk_warp_argmax(s, i, warp_mask, lane, warp_lanes);
            if (tid == 0) {
                winner_idx = i;
            }
        }
        __syncthreads();

        // Every thread reads the same shared value after the barrier, so this
        // exit is uniform across the block and skips no barrier for a subset
        // of threads.
        const int win = winner_idx;
        if (win >= sentinel) {
            break;   // no selectable candidate left (only NaN / retired)
        }
        if (tid == 0) {
            active_out[win] = 1;
            work[win] = retired;
        }
        // Make the retirement visible before the next scan, and keep the next
        // write of winner_idx behind this pass's reads.
        __syncthreads();
    }
}