/*
* Optimized RMSNorm CUDA Kernel for Qwen3-8B
* Optimized for NVIDIA H100 (sm_90)
*
* RMSNorm formula: output = x * weight / sqrt(mean(x²) + eps)
*
* Qwen3-8B specific:
* - hidden_size: 4096
* - rms_norm_eps: 1e-6
* - 65 RMSNorm modules (32 layers * 2 + 1 final)
*
* H100 Optimizations:
* - Vectorized loads/stores (__nv_bfloat162/__half2) for maximum memory bandwidth
* - Warp shuffle reductions (no shared memory bank conflicts)
* - Coalesced memory access patterns
* - Block size tuned for 132 SMs
*/
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cmath>
#include <cstddef>
#include <cstdint>
// Lanes per warp; the full-warp shuffle mask and warp indexing below assume 32.
constexpr int WARP_SIZE{32};
// Largest block size the launchers in this file will request
// (1024 is also the CUDA per-block thread limit).
constexpr int MAX_THREADS{1024};
// Warp-level reduction using shuffle operations
template <typename T>
__device__ __forceinline__ T warp_reduce_sum(T val) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
val += __shfl_xor_sync(0xffffffff, val, offset);
}
return val;
}
// Two-stage block-wide sum: warp shuffles first, then one warp combines the
// per-warp partials staged in `shared`.
//
// Requirements (not checked here):
//   - blockDim.x is a multiple of WARP_SIZE (warp_reduce_sum uses a full mask);
//   - `shared` points to at least ceil(blockDim.x / WARP_SIZE) elements.
//
// Result: every lane of warp 0 receives the block total; all other threads
// receive T(0).
// There is no barrier after the final read of `shared`. A second call in the
// same kernel needs an intervening __syncthreads() before reusing the buffer.
template <typename T>
__device__ __forceinline__ T block_reduce_sum(T val, T* shared) {
    const int lane_id = threadIdx.x & (WARP_SIZE - 1);
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int warp_count = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;

    // Stage 1: each warp collapses its own values; lane 0 stages the partial.
    const T partial = warp_reduce_sum(val);
    if (lane_id == 0) {
        shared[warp_id] = partial;
    }
    __syncthreads();

    // Stage 2: the first warp_count threads pick up the partials and warp 0
    // folds them together.
    T gathered = T(0);
    if (threadIdx.x < warp_count) {
        gathered = shared[threadIdx.x];
    }
    if (warp_id == 0) {
        gathered = warp_reduce_sum(gathered);
    }
    return gathered;
}
// ---------------------------------------------------------------------------
// Scalar <-> float conversion helpers used by the generic kernel.
// to_float widens a supported storage type to float.
// from_float narrows a float back to a storage type. The target overload is
// chosen by the type of the second argument, a pointer tag that is never
// dereferenced.
// ---------------------------------------------------------------------------
__device__ __forceinline__ float to_float(float value) {
    return value;
}
__device__ __forceinline__ float to_float(__half value) {
    return __half2float(value);
}
__device__ __forceinline__ float to_float(__nv_bfloat16 value) {
    return __bfloat162float(value);
}

__device__ __forceinline__ float from_float(float value, float* /*tag*/) {
    return value;
}
__device__ __forceinline__ __half from_float(float value, __half* /*tag*/) {
    return __float2half(value);
}
__device__ __forceinline__ __nv_bfloat16 from_float(float value, __nv_bfloat16* /*tag*/) {
    return __float2bfloat16(value);
}
// ============================================================================
// BF16-specific optimized kernel using __nv_bfloat162 for 2-element vectorization
// Optimized for Qwen3 hidden_size=4096 (even, >= 64)
// ============================================================================
// Normalizes one row per block:
//   output[r, :] = input[r, :] * rsqrt(mean(input[r, :]^2) + eps) * weight[:]
//
// Launch layout:
//   - gridDim.x = number of rows;
//   - blockDim.x is a multiple of WARP_SIZE (block_reduce_sum uses full-warp shuffles);
//   - dynamic shared memory = (blockDim.x / WARP_SIZE) * sizeof(float).
//
// The __nv_bfloat162 reinterpret casts are only aligned when:
//   - hidden_size is even. Otherwise rows >= 1 start on an odd element. The
//     odd-tail branches below only cover the last element of a row.
//   - input, weight and output are 4-byte aligned.
__global__ void rmsnorm_kernel_bf16_vectorized(
    __nv_bfloat16* __restrict__ output,
    const __nv_bfloat16* __restrict__ input,
    const __nv_bfloat16* __restrict__ weight,
    const int hidden_size,
    const float eps
) {
    extern __shared__ char smem[];
    float* shared = reinterpret_cast<float*>(smem);
    const int tid = threadIdx.x;
    const int stride = blockDim.x;
    // Compute the row offset in 64 bits. A 32-bit `row * hidden_size` overflows
    // once it exceeds INT_MAX (row index >= 524288 at hidden_size 4096).
    const size_t row_offset =
        static_cast<size_t>(blockIdx.x) * static_cast<size_t>(hidden_size);
    const __nv_bfloat16* row_input = input + row_offset;
    __nv_bfloat16* row_output = output + row_offset;

    // Phase 1: sum of squares with bf16x2 vectorized loads, accumulated in float.
    float sum_sq = 0.0f;
    const int vec_hidden = hidden_size / 2;
    const __nv_bfloat162* vec_input = reinterpret_cast<const __nv_bfloat162*>(row_input);
#pragma unroll 4
    for (int i = tid; i < vec_hidden; i += stride) {
        const __nv_bfloat162 v = vec_input[i];
        const float v0 = __bfloat162float(v.x);
        const float v1 = __bfloat162float(v.y);
        sum_sq += v0 * v0 + v1 * v1;
    }
    // Trailing element of an odd-length row (not the case for Qwen3; the
    // launchers only select this kernel for even hidden_size).
    if (hidden_size % 2 == 1 && tid == 0) {
        const float v = __bfloat162float(row_input[hidden_size - 1]);
        sum_sq += v * v;
    }

    // Block-wide reduction. The full sum is valid in warp 0 and is read by tid 0 below.
    sum_sq = block_reduce_sum(sum_sq, shared);

    // Thread 0 publishes 1 / sqrt(mean(x^2) + eps) to the block via shared memory.
    __shared__ float rms_inv;
    if (tid == 0) {
        const float mean_sq = sum_sq / static_cast<float>(hidden_size);
        rms_inv = rsqrtf(mean_sq + eps);
    }
    __syncthreads();
    const float factor = rms_inv;

    // Phase 2: scale by factor and weight, with bf16x2 vectorized loads/stores.
    const __nv_bfloat162* vec_weight = reinterpret_cast<const __nv_bfloat162*>(weight);
    __nv_bfloat162* vec_output = reinterpret_cast<__nv_bfloat162*>(row_output);
#pragma unroll 4
    for (int i = tid; i < vec_hidden; i += stride) {
        const __nv_bfloat162 v_in = vec_input[i];
        const __nv_bfloat162 v_w = vec_weight[i];
        const float v0 = __bfloat162float(v_in.x);
        const float v1 = __bfloat162float(v_in.y);
        const float w0 = __bfloat162float(v_w.x);
        const float w1 = __bfloat162float(v_w.y);
        __nv_bfloat162 result;
        result.x = __float2bfloat16(v0 * factor * w0);
        result.y = __float2bfloat16(v1 * factor * w1);
        vec_output[i] = result;
    }
    // Trailing element of an odd-length row.
    if (hidden_size % 2 == 1 && tid == 0) {
        const float v = __bfloat162float(row_input[hidden_size - 1]);
        const float w = __bfloat162float(weight[hidden_size - 1]);
        row_output[hidden_size - 1] = __float2bfloat16(v * factor * w);
    }
}
// ============================================================================
// FP16-specific optimized kernel using __half2 for 2-element vectorization
// ============================================================================
// Normalizes one row per block:
//   output[r, :] = input[r, :] * rsqrt(mean(input[r, :]^2) + eps) * weight[:]
//
// Launch layout:
//   - gridDim.x = number of rows;
//   - blockDim.x is a multiple of WARP_SIZE (block_reduce_sum uses full-warp shuffles);
//   - dynamic shared memory = (blockDim.x / WARP_SIZE) * sizeof(float).
//
// The __half2 reinterpret casts are only aligned when:
//   - hidden_size is even. Otherwise rows >= 1 start on an odd element. The
//     odd-tail branches below only cover the last element of a row.
//   - input, weight and output are 4-byte aligned.
__global__ void rmsnorm_kernel_fp16_vectorized(
    __half* __restrict__ output,
    const __half* __restrict__ input,
    const __half* __restrict__ weight,
    const int hidden_size,
    const float eps
) {
    extern __shared__ char smem[];
    float* shared = reinterpret_cast<float*>(smem);
    const int tid = threadIdx.x;
    const int stride = blockDim.x;
    // Compute the row offset in 64 bits. A 32-bit `row * hidden_size` overflows
    // once it exceeds INT_MAX (row index >= 524288 at hidden_size 4096).
    const size_t row_offset =
        static_cast<size_t>(blockIdx.x) * static_cast<size_t>(hidden_size);
    const __half* row_input = input + row_offset;
    __half* row_output = output + row_offset;

    // Phase 1: sum of squares with half2 vectorized loads, accumulated in float.
    float sum_sq = 0.0f;
    const int vec_hidden = hidden_size / 2;
    const __half2* vec_input = reinterpret_cast<const __half2*>(row_input);
#pragma unroll 4
    for (int i = tid; i < vec_hidden; i += stride) {
        const __half2 v = vec_input[i];
        const float v0 = __half2float(v.x);
        const float v1 = __half2float(v.y);
        sum_sq += v0 * v0 + v1 * v1;
    }
    // Trailing element of an odd-length row (the launchers only select this
    // kernel for even hidden_size).
    if (hidden_size % 2 == 1 && tid == 0) {
        const float v = __half2float(row_input[hidden_size - 1]);
        sum_sq += v * v;
    }

    // Block-wide reduction. The full sum is valid in warp 0 and is read by tid 0 below.
    sum_sq = block_reduce_sum(sum_sq, shared);

    // Thread 0 publishes 1 / sqrt(mean(x^2) + eps) to the block via shared memory.
    __shared__ float rms_inv;
    if (tid == 0) {
        const float mean_sq = sum_sq / static_cast<float>(hidden_size);
        rms_inv = rsqrtf(mean_sq + eps);
    }
    __syncthreads();
    const float factor = rms_inv;

    // Phase 2: scale by factor and weight, with half2 vectorized loads/stores.
    const __half2* vec_weight = reinterpret_cast<const __half2*>(weight);
    __half2* vec_output = reinterpret_cast<__half2*>(row_output);
#pragma unroll 4
    for (int i = tid; i < vec_hidden; i += stride) {
        const __half2 v_in = vec_input[i];
        const __half2 v_w = vec_weight[i];
        const float v0 = __half2float(v_in.x);
        const float v1 = __half2float(v_in.y);
        const float w0 = __half2float(v_w.x);
        const float w1 = __half2float(v_w.y);
        __half2 result;
        result.x = __float2half(v0 * factor * w0);
        result.y = __float2half(v1 * factor * w1);
        vec_output[i] = result;
    }
    // Trailing element of an odd-length row.
    if (hidden_size % 2 == 1 && tid == 0) {
        const float v = __half2float(row_input[hidden_size - 1]);
        const float w = __half2float(weight[hidden_size - 1]);
        row_output[hidden_size - 1] = __float2half(v * factor * w);
    }
}
// ============================================================================
// Generic scalar kernel (fallback)
// ============================================================================
// Normalizes one row per block, one element per thread per loop iteration.
// Handles any hidden_size and has no vector-alignment requirement.
//
// Launch layout:
//   - gridDim.x = number of rows;
//   - blockDim.x is a multiple of WARP_SIZE (block_reduce_sum uses full-warp shuffles);
//   - dynamic shared memory = (blockDim.x / WARP_SIZE) * sizeof(acc_t).
//
// scalar_t must have to_float / from_float overloads (float, __half, __nv_bfloat16).
// NOTE(review): rsqrtf is single precision, so acc_t = double would still
// compute the inverse RMS in float.
template <typename scalar_t, typename acc_t = float>
__global__ void rmsnorm_kernel(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ weight,
    const int hidden_size,
    const float eps
) {
    extern __shared__ char smem[];
    acc_t* shared = reinterpret_cast<acc_t*>(smem);
    const int tid = threadIdx.x;
    const int stride = blockDim.x;
    // Compute the row offset in 64 bits. A 32-bit `row * hidden_size` overflows
    // once the tensor holds more than INT_MAX elements.
    const size_t row_offset =
        static_cast<size_t>(blockIdx.x) * static_cast<size_t>(hidden_size);
    const scalar_t* row_input = input + row_offset;
    scalar_t* row_output = output + row_offset;

    // Phase 1: per-thread partial sum of squares, accumulated in acc_t.
    acc_t sum_sq = acc_t(0);
    for (int i = tid; i < hidden_size; i += stride) {
        const acc_t val = to_float(row_input[i]);
        sum_sq += val * val;
    }

    // Block-wide reduction. The full sum is valid in warp 0 and is read by tid 0 below.
    sum_sq = block_reduce_sum(sum_sq, shared);

    // Thread 0 publishes 1 / sqrt(mean(x^2) + eps) to the block via shared memory.
    __shared__ acc_t rms_inv;
    if (tid == 0) {
        const acc_t mean_sq = sum_sq / static_cast<acc_t>(hidden_size);
        rms_inv = rsqrtf(mean_sq + eps);
    }
    __syncthreads();

    // Phase 2: apply the normalization factor and per-channel weight.
    for (int i = tid; i < hidden_size; i += stride) {
        const acc_t val = to_float(row_input[i]);
        const acc_t w = to_float(weight[i]);
        row_output[i] = from_float(val * rms_inv * w, (scalar_t*)nullptr);
    }
}
// ============================================================================
// Launch functions
// ============================================================================
// Each launcher normalizes batch_size * seq_len rows of hidden_size elements,
// one block per row, on `stream`.
// Launches are asynchronous and errors are not reported here. Callers should
// check cudaGetLastError() / stream synchronization as appropriate.
namespace {

// Round a thread count up to a whole number of warps. block_reduce_sum's
// full-mask shuffles require every warp in the block to be complete.
inline int round_up_to_warp(int n) {
    return ((n + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
}

// Dynamic shared memory block_reduce_sum needs: one float per warp.
// Must be derived from the thread count that is actually launched.
inline size_t reduce_smem_bytes(int threads) {
    return static_cast<size_t>(threads / WARP_SIZE) * sizeof(float);
}

// True when `p` may be reinterpreted as a pointer to a 4-byte vector type
// (__half2 / __nv_bfloat162).
inline bool is_aligned_4(const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 4 == 0;
}

// True when all preconditions of the 2-wide vectorized kernels hold:
// even hidden_size >= 64 and 4-byte-aligned buffers.
inline bool can_use_vectorized(const void* output, const void* input,
                               const void* weight, int hidden_size) {
    return hidden_size % 2 == 0 && hidden_size >= 64 &&
           is_aligned_4(output) && is_aligned_4(input) && is_aligned_4(weight);
}

}  // namespace

extern "C" {
void rmsnorm_forward_fp16(
    __half* output,
    const __half* input,
    const __half* weight,
    const int batch_size,
    const int seq_len,
    const int hidden_size,
    const float eps,
    cudaStream_t stream
) {
    const int num_rows = batch_size * seq_len;
    // A zero-sized grid or block is an invalid launch configuration, and there
    // is nothing to normalize in that case.
    if (num_rows <= 0 || hidden_size <= 0) {
        return;
    }
    if (can_use_vectorized(output, input, weight, hidden_size)) {
        // One __half2 per thread per iteration.
        const int threads = round_up_to_warp(min(hidden_size / 2, MAX_THREADS));
        rmsnorm_kernel_fp16_vectorized<<<num_rows, threads, reduce_smem_bytes(threads), stream>>>(
            output, input, weight, hidden_size, eps
        );
    } else {
        // Shared memory is sized from *this* thread count. Sizing it from the
        // vectorized count under-allocated and caused out-of-bounds shared writes.
        const int threads = round_up_to_warp(min(hidden_size, MAX_THREADS));
        rmsnorm_kernel<__half><<<num_rows, threads, reduce_smem_bytes(threads), stream>>>(
            output, input, weight, hidden_size, eps
        );
    }
}
void rmsnorm_forward_bf16(
    __nv_bfloat16* output,
    const __nv_bfloat16* input,
    const __nv_bfloat16* weight,
    const int batch_size,
    const int seq_len,
    const int hidden_size,
    const float eps,
    cudaStream_t stream
) {
    const int num_rows = batch_size * seq_len;
    // A zero-sized grid or block is an invalid launch configuration, and there
    // is nothing to normalize in that case.
    if (num_rows <= 0 || hidden_size <= 0) {
        return;
    }
    if (can_use_vectorized(output, input, weight, hidden_size)) {
        // One __nv_bfloat162 per thread per iteration.
        const int threads = round_up_to_warp(min(hidden_size / 2, MAX_THREADS));
        rmsnorm_kernel_bf16_vectorized<<<num_rows, threads, reduce_smem_bytes(threads), stream>>>(
            output, input, weight, hidden_size, eps
        );
    } else {
        // Shared memory is sized from *this* thread count (see fp16 launcher).
        const int threads = round_up_to_warp(min(hidden_size, MAX_THREADS));
        rmsnorm_kernel<__nv_bfloat16><<<num_rows, threads, reduce_smem_bytes(threads), stream>>>(
            output, input, weight, hidden_size, eps
        );
    }
}
void rmsnorm_forward_fp32(
    float* output,
    const float* input,
    const float* weight,
    const int batch_size,
    const int seq_len,
    const int hidden_size,
    const float eps,
    cudaStream_t stream
) {
    const int num_rows = batch_size * seq_len;
    // A zero-sized grid or block is an invalid launch configuration.
    if (num_rows <= 0 || hidden_size <= 0) {
        return;
    }
    const int threads = round_up_to_warp(min(hidden_size, MAX_THREADS));
    rmsnorm_kernel<float><<<num_rows, threads, reduce_smem_bytes(threads), stream>>>(
        output, input, weight, hidden_size, eps
    );
}
} // extern "C"