icarus112's picture
Upload folder using huggingface_hub
1c59946 verified
// Top-K column selection.
//
// Inputs:
// boosted[n_columns] : f32 score
// Output:
// active_mask[n_columns] : u8 0/1, exactly k ones
//
// Tie-breaking: when scores are equal, the LOWER column index wins (matches
// CPU reference `select_nth_unstable_by` with secondary index comparator).
//
// Strategy: a single-block implementation. n_columns is typically 2048, which
// fits comfortably in shared memory. One candidate design (NOT what the kernel
// below implements) is a radix-select over a (score, -index) key, i.e. at
// k≈41 of n=2048 a thresholding pass:
//
// 1. Radix-like bucket pass to find the k-th largest score.
// 2. Mark winners = strictly-greater-than-threshold AND ties until count hits k.
//
// For strict index-ordered tie-break we materialise a 64-bit key:
// key = (float_to_sortable_u32(score) << 32) | (0xffffffff - index)
// Larger key = (higher score) OR (same score, smaller index).
//
// That design would then find the k-th largest 64-bit key via radix-select and
// mark all columns whose key >= threshold. This is O(n_cols * log k) and well
// under 100 μs for n=2048, k=41 on sm_86.
//
// For simplicity and correctness the kernel below instead uses a single-block
// parallel selection-sort variant (find max → mark → remove the winner from the
// working copy → repeat, k iterations). At k=41 this is 41 passes over 2048
// elements = ~2048*41 = 84K ops, trivially fast.
// Warp-level argmax over (score, index) pairs: keeps the larger score and, on
// equal scores, the smaller index. After the call, lane 0 holds the best pair
// among lanes [0, n_lanes) of the calling warp.
//
//   mask    : participation mask naming exactly the lanes that execute the call
//   lane    : calling thread's lane id within its warp
//   n_lanes : number of lanes of this warp that exist (1..32)
//
// Every lane named in `mask` must call this function (the shuffles sit outside
// any data-dependent branch). Values shuffled in from lanes >= n_lanes are
// ignored, because such lanes do not exist in a partial warp and whatever the
// shuffle returns for them is not meaningful.
static __device__ __forceinline__ void sp_topk_warp_argmax(
    unsigned int mask,
    unsigned int lane,
    unsigned int n_lanes,
    float &best_s,
    int &best_i
) {
    for (unsigned int off = 16; off > 0; off >>= 1) {
        const float os = __shfl_down_sync(mask, best_s, off);
        const int oi = __shfl_down_sync(mask, best_i, off);
        const bool src_ok = (lane + off) < n_lanes;
        if (src_ok && (os > best_s || (os == best_s && oi < best_i))) {
            best_s = os;
            best_i = oi;
        }
    }
}

// Marks the k highest-scoring columns in `active_out` (1 = selected, 0 = not).
// Equal scores are resolved in favour of the lower column index.
//
// Parameters:
//   scores     : n_columns input scores (read only)
//   n_columns  : number of columns
//   k          : number of columns to select; values above n_columns are
//                treated as n_columns
//   active_out : n_columns output bytes, fully overwritten by this kernel
//
// Launch contract (follows from the indexing used below):
//   * one block only: the kernel never reads blockIdx, so additional blocks
//     would repeat the same work on the same output buffer;
//   * 1-D block: only threadIdx.x / blockDim.x are used. blockDim.x does not
//     have to be a multiple of 32;
//   * dynamic shared memory: at least n_columns * sizeof(float) bytes.
//
// NaN scores never compare greater than or equal to anything, so they are
// never selected; if fewer than k scores are non-NaN, fewer than k entries are
// set. -inf scores are ordinary candidates and can be selected.
extern "C" __global__
void sp_topk_select(
    const float * __restrict__ scores,          // (n_columns,)
    unsigned int n_columns,
    unsigned int k,
    unsigned char * __restrict__ active_out     // (n_columns,)
) {
    // Dynamic shared memory: working copy of the scores, one float per column.
    extern __shared__ float smem[];
    float * work = smem;

    // Static scratch for the cross-warp step: one (score, index) slot per warp.
    // 32 slots cover the largest block the two-level reduction below can
    // handle (32 warps).
    __shared__ float warp_s[32];
    __shared__ int warp_i[32];

    const unsigned int tid = threadIdx.x;
    const unsigned int bsz = blockDim.x;
    const unsigned int lane = tid & 31u;
    const unsigned int warp = tid >> 5;
    const unsigned int nwarps = (bsz + 31u) / 32u;

    // Lanes that actually exist in this thread's warp, and the matching shuffle
    // mask. Only the last warp of the block can be partial.
    const unsigned int remaining = bsz - (tid - lane);
    const unsigned int lanes_here = remaining < 32u ? remaining : 32u;
    const unsigned int mask =
        (lanes_here == 32u) ? 0xffffffffu : ((1u << lanes_here) - 1u);

    // Entries already selected are overwritten with a quiet NaN. NaN fails both
    // `>` and `==`, so such an entry can never win again -- unlike a -inf
    // marker, which would tie with genuine -inf scores and let a
    // lower-indexed, already-selected column be picked a second time.
    const float taken = __int_as_float(0x7fc00000);

    // Load scores into shared memory and clear the output mask.
    for (unsigned int i = tid; i < n_columns; i += bsz) {
        work[i] = scores[i];
        active_out[i] = 0;
    }
    __syncthreads();

    // At most n_columns columns can be selected; further rounds would find no
    // candidate.
    const unsigned int rounds = (k < n_columns) ? k : n_columns;

    for (unsigned int iter = 0; iter < rounds; ++iter) {
        // Per-thread scan over this thread's strided slice of the columns.
        float best_s = -INFINITY;
        int best_i = (int)n_columns;   // sentinel: larger than any valid index
        for (unsigned int i = tid; i < n_columns; i += bsz) {
            const float s = work[i];
            if (s > best_s || (s == best_s && (int)i < best_i)) {
                best_s = s;
                best_i = (int)i;
            }
        }

        // Level 1: reduce within each warp; lane 0 publishes the warp's result.
        sp_topk_warp_argmax(mask, lane, lanes_here, best_s, best_i);
        if (lane == 0) {
            warp_s[warp] = best_s;
            warp_i[warp] = best_i;
        }
        __syncthreads();

        // Level 2: warp 0 reduces the per-warp results. Lanes without a
        // matching warp contribute the "no candidate" pair.
        if (warp == 0) {
            float s = (lane < nwarps) ? warp_s[lane] : -INFINITY;
            int i = (lane < nwarps) ? warp_i[lane] : (int)n_columns;
            sp_topk_warp_argmax(mask, lane, lanes_here, s, i);

            // Thread 0 now holds the block-wide winner. An index equal to
            // n_columns means no selectable entry was left.
            if (lane == 0 && i < (int)n_columns) {
                active_out[i] = 1;
                work[i] = taken;
            }
        }
        // Makes the updated work[] visible to every thread before the next
        // scan, and keeps the next round's warp_s/warp_i writes from racing
        // with warp 0's reads above.
        __syncthreads();
    }
}