// Top-K column selection.
//
// Inputs:
//   scores[n_columns]     : f32 score per column
//   n_columns             : number of columns
//   k                     : number of columns to select
// Output:
//   active_out[n_columns] : u8 0/1 mask with exactly min(k, n_columns) ones,
//                           unless NaN scores prevent it (see below)
//
// Tie-breaking: when scores are equal, the LOWER column index wins (intended
// to match the CPU reference `select_nth_unstable_by` with a secondary index
// comparator).
//
// Algorithm: single-block iterative argmax. Each of the min(k, n_columns)
// rounds finds the remaining column with the highest score (lowest index on
// ties) via a block-wide shuffle + shared-memory reduction. It sets that
// column's bit in active_out and removes it from the shared working copy.
// Per-thread cost is O(min(k, n_columns) * ceil(n_columns / blockDim.x))
// comparisons plus O(log2(32)) shuffle steps per round.
//
// Removal overwrites the working score with NaN. NaN never compares greater
// than or equal to anything, so a removed column cannot be picked again, even
// when genuine -inf scores remain. As a consequence, columns whose input score
// is NaN are never selected. If only NaN columns remain, selection stops early
// and fewer than min(k, n_columns) ones are written.
//
// Launch requirements:
//   * grid  : only block (0,0,0) does work; any other block returns at once
//             (otherwise a late block's zero-fill could erase the result).
//   * block : 1-D (threadIdx.y/z are ignored), 1..1024 threads; it need NOT
//             be a multiple of 32.
//   * dynamic shared memory: n_columns * sizeof(float) bytes.

// Merge a candidate (cand_s, cand_i) into the running best (best_s, best_i).
// Higher score wins; on equal scores the lower index wins. NaN candidates
// never win, because both comparisons are false for NaN.
static __device__ __forceinline__ void sp_topk_merge(
    float &best_s, int &best_i, float cand_s, int cand_i)
{
    if (cand_s > best_s || (cand_s == best_s && cand_i < best_i)) {
        best_s = cand_s;
        best_i = cand_i;
    }
}

// Warp-level (score, index) argmax reduction into lane 0.
// n_lanes = number of threads actually present in this warp (1..32). The
// shuffle mask names exactly those lanes. Values shuffled from absent lanes
// (lane + off >= n_lanes) are undefined and are therefore ignored.
static __device__ __forceinline__ void sp_topk_warp_reduce(
    float &best_s, int &best_i, unsigned int lane, unsigned int n_lanes)
{
    const unsigned int mask =
        (n_lanes >= 32u) ? 0xffffffffu : ((1u << n_lanes) - 1u);
    for (unsigned int off = 16u; off > 0u; off >>= 1) {
        const float os = __shfl_down_sync(mask, best_s, off);
        const int   oi = __shfl_down_sync(mask, best_i, off);
        if (lane + off < n_lanes) {
            sp_topk_merge(best_s, best_i, os, oi);
        }
    }
}

extern "C" __global__ void sp_topk_select(
    const float * __restrict__ scores,      // (n_columns,)
    unsigned int n_columns,
    unsigned int k,
    unsigned char * __restrict__ active_out // (n_columns,)
)
{
    // Single-block kernel. Extra blocks leave before touching memory, so they
    // cannot race with block 0 on active_out. This is a whole-block return, so
    // no barrier is skipped divergently.
    if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) {
        return;
    }

    // Dynamic shared memory: working copy of the scores (n_columns floats).
    // Removed entries are overwritten with NaN.
    extern __shared__ float smem[];
    float * work = smem;

    // Static shared scratch: one (score, index) partial per warp (<= 32 warps
    // for blockDim.x <= 1024), plus the round's winner broadcast.
    __shared__ float warp_s[32];
    __shared__ int   warp_i[32];
    __shared__ int   winner_idx;

    const unsigned int tid    = threadIdx.x;
    const unsigned int bsz    = blockDim.x;
    const unsigned int lane   = tid & 31u;
    const unsigned int warp   = tid >> 5;
    const unsigned int nwarps = (bsz + 31u) >> 5;
    // Threads present in this thread's warp (the last warp may be partial).
    const unsigned int warp_base  = warp << 5;
    const unsigned int warp_lanes =
        (bsz - warp_base < 32u) ? (bsz - warp_base) : 32u;
    // Threads present in warp 0, which performs the cross-warp reduction.
    const unsigned int warp0_lanes = (bsz < 32u) ? bsz : 32u;

    // Index sentinel larger than any valid column index.
    const int sentinel = (int)n_columns;
    // Quiet NaN used to mark already-selected columns.
    const float removed = __int_as_float(0x7fc00000);
    // At most n_columns distinct columns can ever be selected.
    const unsigned int n_select = (k < n_columns) ? k : n_columns;

    // Load scores into shared memory and clear the output mask.
    for (unsigned int i = tid; i < n_columns; i += bsz) {
        work[i] = scores[i];
        active_out[i] = 0;
    }
    // Makes both the shared copy and the zero-fill of active_out visible
    // block-wide before any 1 is written.
    __syncthreads();

    for (unsigned int iter = 0; iter < n_select; ++iter) {
        // Per-thread argmax over this thread's strided subset of columns.
        // Starting from (-inf, sentinel) lets genuine -inf scores be chosen.
        float best_s = -INFINITY;
        int   best_i = sentinel;
        for (unsigned int i = tid; i < n_columns; i += bsz) {
            sp_topk_merge(best_s, best_i, work[i], (int)i);
        }

        // Reduce within each warp, then publish one partial per warp.
        sp_topk_warp_reduce(best_s, best_i, lane, warp_lanes);
        if (lane == 0) {
            warp_s[warp] = best_s;
            warp_i[warp] = best_i;
        }
        __syncthreads();

        // Warp 0 reduces the per-warp partials. Lanes beyond nwarps contribute
        // the neutral element (-inf, sentinel).
        if (warp == 0) {
            float s   = (lane < nwarps) ? warp_s[lane] : -INFINITY;
            int   idx = (lane < nwarps) ? warp_i[lane] : sentinel;
            sp_topk_warp_reduce(s, idx, lane, warp0_lanes);
            if (lane == 0) {
                winner_idx = idx;
                if (idx < sentinel) {
                    active_out[idx] = 1;
                    work[idx] = removed;  // never eligible again
                }
            }
        }
        // Publishes winner_idx and the updated work[] entry. It also orders
        // warp 0's reads of warp_s/warp_i before the next round's writes.
        __syncthreads();

        // No eligible column left (only NaN scores remain). winner_idx is
        // block-uniform, so every thread breaks together.
        if (winner_idx >= sentinel) {
            break;
        }
    }
}