burtenshaw HF Staff commited on
Commit
5426fd3
·
verified ·
1 Parent(s): 0286af4

Upload kernel_src/rmsnorm.cu with huggingface_hub

Browse files
Files changed (1) hide show
  1. kernel_src/rmsnorm.cu +369 -0
kernel_src/rmsnorm.cu ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Optimized RMSNorm CUDA Kernel for Qwen3-8B
3
+ * Optimized for NVIDIA H100 (sm_90)
4
+ *
5
+ * RMSNorm formula: output = x * weight / sqrt(mean(x²) + eps)
6
+ *
7
+ * Qwen3-8B specific:
8
+ * - hidden_size: 4096
9
+ * - rms_norm_eps: 1e-6
10
+ * - 65 RMSNorm modules (32 layers * 2 + 1 final)
11
+ *
12
+ * H100 Optimizations:
13
+ * - Vectorized loads/stores (__nv_bfloat162/__half2) for maximum memory bandwidth
14
+ * - Warp shuffle reductions (no shared memory bank conflicts)
15
+ * - Coalesced memory access patterns
16
+ * - Block size tuned for 132 SMs
17
+ */
18
+
19
+ #include <cuda_runtime.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_bf16.h>
22
+ #include <cmath>
23
+
24
// Number of lanes in a warp: sets the shuffle-reduction width and the
// granularity to which the launch functions round their block sizes.
constexpr int WARP_SIZE{32};
// Upper bound on the threads-per-block value picked by the launch functions.
constexpr int MAX_THREADS{1024};
26
+
27
// Sums `val` over all lanes of the calling warp using an XOR butterfly of
// shuffle exchanges. The shuffles use the full 0xffffffff mask, so all 32
// lanes of the warp must call this together. Every lane returns the total.
template <typename T>
__device__ __forceinline__ T warp_reduce_sum(T val) {
    constexpr unsigned kFullWarpMask = 0xffffffffu;

    // Partner distance halves each round: 16, 8, 4, 2, 1.
    #pragma unroll
    for (int lane_mask = WARP_SIZE >> 1; lane_mask >= 1; lane_mask /= 2) {
        const T partner = __shfl_xor_sync(kFullWarpMask, val, lane_mask);
        val = val + partner;
    }
    return val;
}
36
+
37
// Sums `val` over every thread of the block in two stages: a shuffle
// reduction inside each warp, then a second shuffle reduction over the
// per-warp partials staged in `shared`.
//
//   val    - this thread's contribution.
//   shared - scratch buffer with room for one T per warp in the block.
//
// Contains a __syncthreads(), so every thread of the block must call it.
// Only warp 0 runs the final stage, so callers should read the result from
// warp 0 (the kernels in this file read it from thread 0).
template <typename T>
__device__ __forceinline__ T block_reduce_sum(T val, T* shared) {
    const int warp_index = threadIdx.x / WARP_SIZE;
    const int lane_index = threadIdx.x % WARP_SIZE;

    // Stage 1: fold within the warp; lane 0 publishes the warp's partial sum.
    const T warp_total = warp_reduce_sum(val);
    if (lane_index == 0) {
        shared[warp_index] = warp_total;
    }
    __syncthreads();

    // Stage 2: the leading threads pick up one partial each (everyone else
    // contributes zero) and the first warp folds them together.
    const int warp_count = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
    T result = T(0);
    if (threadIdx.x < warp_count) {
        result = shared[threadIdx.x];
    }
    if (warp_index == 0) {
        result = warp_reduce_sum(result);
    }

    return result;
}
61
+
62
// Widening conversions to float, one overload per supported element type, so
// templated code can call to_float(x) without knowing the storage type.

// float: returned unchanged.
__device__ __forceinline__ float to_float(float x) {
    return x;
}

// __half -> float.
__device__ __forceinline__ float to_float(__half x) {
    return __half2float(x);
}

// __nv_bfloat16 -> float.
__device__ __forceinline__ float to_float(__nv_bfloat16 x) {
    return __bfloat162float(x);
}
66
+
67
// Narrowing conversions from float. The unnamed pointer parameter is never
// dereferenced; it is a type tag that selects the overload, so templated code
// can write from_float(v, (scalar_t*)nullptr).

// float: returned unchanged.
__device__ __forceinline__ float from_float(float x, float*) {
    return x;
}

// float -> __half.
__device__ __forceinline__ __half from_float(float x, __half*) {
    return __float2half(x);
}

// float -> __nv_bfloat16.
__device__ __forceinline__ __nv_bfloat16 from_float(float x, __nv_bfloat16*) {
    return __float2bfloat16(x);
}
70
+
71
// ============================================================================
// BF16 RMSNorm kernel using __nv_bfloat162 for 2-element vectorization
//
// Computes, per row: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
// with all arithmetic accumulated in float.
//
// Launch layout : one block per row (blockIdx.x is the row index); the
//                 threads of a block walk the row in steps of blockDim.x.
// Shared memory : dynamic, one float per warp in the block (scratch for
//                 block_reduce_sum), plus one static float for the factor.
// Preconditions : input, output and weight are reinterpreted as
//                 __nv_bfloat162, so every row must start on a
//                 __nv_bfloat162 boundary. With an odd hidden_size every
//                 other row would start on a 2-byte boundary, so the odd-tail
//                 code below is only safe for a single row; the launcher in
//                 this file selects this kernel for even hidden_size only.
// ============================================================================
__global__ void rmsnorm_kernel_bf16_vectorized(
    __nv_bfloat16* __restrict__ output,
    const __nv_bfloat16* __restrict__ input,
    const __nv_bfloat16* __restrict__ weight,
    const int hidden_size,
    const float eps
) {
    extern __shared__ char smem[];
    float* shared = reinterpret_cast<float*>(smem);

    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int stride = blockDim.x;

    // The element offset is computed in size_t: row * hidden_size in int
    // overflows once the tensor holds more than 2^31 - 1 elements.
    const size_t row_offset = static_cast<size_t>(row) * static_cast<size_t>(hidden_size);
    const __nv_bfloat16* row_input = input + row_offset;
    __nv_bfloat16* row_output = output + row_offset;

    // Phase 1: per-thread partial sum of squares using paired bf16 loads.
    float sum_sq = 0.0f;

    const int vec_hidden = hidden_size / 2;
    const __nv_bfloat162* vec_input = reinterpret_cast<const __nv_bfloat162*>(row_input);

    #pragma unroll 4
    for (int i = tid; i < vec_hidden; i += stride) {
        __nv_bfloat162 v = vec_input[i];
        float v0 = __bfloat162float(v.x);
        float v1 = __bfloat162float(v.y);
        sum_sq += v0 * v0 + v1 * v1;
    }

    // Trailing element when hidden_size is odd; thread 0 accounts for it.
    if (hidden_size % 2 == 1 && tid == 0) {
        float v = __bfloat162float(row_input[hidden_size - 1]);
        sum_sq += v * v;
    }

    // Block-wide total; thread 0 holds the complete sum afterwards.
    sum_sq = block_reduce_sum(sum_sq, shared);

    // Thread 0 derives 1/sqrt(mean(x^2) + eps) and shares it with the block.
    __shared__ float rms_inv;
    if (tid == 0) {
        float mean_sq = sum_sq / static_cast<float>(hidden_size);
        rms_inv = rsqrtf(mean_sq + eps);
    }
    __syncthreads();

    const float factor = rms_inv;

    // Phase 2: scale by the factor and the weight, storing pairs of bf16.
    const __nv_bfloat162* vec_weight = reinterpret_cast<const __nv_bfloat162*>(weight);
    __nv_bfloat162* vec_output = reinterpret_cast<__nv_bfloat162*>(row_output);

    #pragma unroll 4
    for (int i = tid; i < vec_hidden; i += stride) {
        __nv_bfloat162 v_in = vec_input[i];
        __nv_bfloat162 v_w = vec_weight[i];

        float v0 = __bfloat162float(v_in.x);
        float v1 = __bfloat162float(v_in.y);
        float w0 = __bfloat162float(v_w.x);
        float w1 = __bfloat162float(v_w.y);

        __nv_bfloat162 result;
        result.x = __float2bfloat16(v0 * factor * w0);
        result.y = __float2bfloat16(v1 * factor * w1);
        vec_output[i] = result;
    }

    // Trailing element when hidden_size is odd.
    if (hidden_size % 2 == 1 && tid == 0) {
        float v = __bfloat162float(row_input[hidden_size - 1]);
        float w = __bfloat162float(weight[hidden_size - 1]);
        row_output[hidden_size - 1] = __float2bfloat16(v * factor * w);
    }
}
153
+
154
// ============================================================================
// FP16 RMSNorm kernel using __half2 for 2-element vectorization
//
// Computes, per row: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
// with all arithmetic accumulated in float.
//
// Launch layout : one block per row (blockIdx.x is the row index); the
//                 threads of a block walk the row in steps of blockDim.x.
// Shared memory : dynamic, one float per warp in the block (scratch for
//                 block_reduce_sum), plus one static float for the factor.
// Preconditions : input, output and weight are reinterpreted as __half2, so
//                 every row must start on a __half2 boundary. With an odd
//                 hidden_size every other row would start on a 2-byte
//                 boundary, so the odd-tail code below is only safe for a
//                 single row; the launcher in this file selects this kernel
//                 for even hidden_size only.
// ============================================================================
__global__ void rmsnorm_kernel_fp16_vectorized(
    __half* __restrict__ output,
    const __half* __restrict__ input,
    const __half* __restrict__ weight,
    const int hidden_size,
    const float eps
) {
    extern __shared__ char smem[];
    float* shared = reinterpret_cast<float*>(smem);

    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int stride = blockDim.x;

    // The element offset is computed in size_t: row * hidden_size in int
    // overflows once the tensor holds more than 2^31 - 1 elements.
    const size_t row_offset = static_cast<size_t>(row) * static_cast<size_t>(hidden_size);
    const __half* row_input = input + row_offset;
    __half* row_output = output + row_offset;

    // Phase 1: per-thread partial sum of squares using paired half loads.
    float sum_sq = 0.0f;

    const int vec_hidden = hidden_size / 2;
    const __half2* vec_input = reinterpret_cast<const __half2*>(row_input);

    #pragma unroll 4
    for (int i = tid; i < vec_hidden; i += stride) {
        __half2 v = vec_input[i];
        float v0 = __half2float(v.x);
        float v1 = __half2float(v.y);
        sum_sq += v0 * v0 + v1 * v1;
    }

    // Trailing element when hidden_size is odd; thread 0 accounts for it.
    if (hidden_size % 2 == 1 && tid == 0) {
        float v = __half2float(row_input[hidden_size - 1]);
        sum_sq += v * v;
    }

    // Block-wide total; thread 0 holds the complete sum afterwards.
    sum_sq = block_reduce_sum(sum_sq, shared);

    // Thread 0 derives 1/sqrt(mean(x^2) + eps) and shares it with the block.
    __shared__ float rms_inv;
    if (tid == 0) {
        float mean_sq = sum_sq / static_cast<float>(hidden_size);
        rms_inv = rsqrtf(mean_sq + eps);
    }
    __syncthreads();

    const float factor = rms_inv;

    // Phase 2: scale by the factor and the weight, storing pairs of halves.
    const __half2* vec_weight = reinterpret_cast<const __half2*>(weight);
    __half2* vec_output = reinterpret_cast<__half2*>(row_output);

    #pragma unroll 4
    for (int i = tid; i < vec_hidden; i += stride) {
        __half2 v_in = vec_input[i];
        __half2 v_w = vec_weight[i];

        float v0 = __half2float(v_in.x);
        float v1 = __half2float(v_in.y);
        float w0 = __half2float(v_w.x);
        float w1 = __half2float(v_w.y);

        __half2 result;
        result.x = __float2half(v0 * factor * w0);
        result.y = __float2half(v1 * factor * w1);
        vec_output[i] = result;
    }

    // Trailing element when hidden_size is odd.
    if (hidden_size % 2 == 1 && tid == 0) {
        float v = __half2float(row_input[hidden_size - 1]);
        float w = __half2float(weight[hidden_size - 1]);
        row_output[hidden_size - 1] = __float2half(v * factor * w);
    }
}
234
+
235
// ============================================================================
// Generic scalar RMSNorm kernel (fallback for any hidden_size / element type)
//
// Computes, per row: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
//
// Template parameters:
//   scalar_t - storage type; must have to_float / from_float overloads.
//   acc_t    - accumulation type (float by default).
//
// Launch layout : one block per row (blockIdx.x is the row index); the
//                 threads of a block walk the row in steps of blockDim.x.
// Shared memory : dynamic, one acc_t per warp in the block (scratch for
//                 block_reduce_sum), plus one static acc_t for the factor.
// ============================================================================
template <typename scalar_t, typename acc_t = float>
__global__ void rmsnorm_kernel(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ weight,
    const int hidden_size,
    const float eps
) {
    extern __shared__ char smem[];
    acc_t* shared = reinterpret_cast<acc_t*>(smem);

    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int stride = blockDim.x;

    // The element offset is computed in size_t: row * hidden_size in int
    // overflows once the tensor holds more than 2^31 - 1 elements.
    const size_t row_offset = static_cast<size_t>(row) * static_cast<size_t>(hidden_size);
    const scalar_t* row_input = input + row_offset;
    scalar_t* row_output = output + row_offset;

    // Per-thread partial sum of squares.
    acc_t sum_sq = 0.0f;
    for (int i = tid; i < hidden_size; i += stride) {
        acc_t val = to_float(row_input[i]);
        sum_sq += val * val;
    }

    // Block-wide total; thread 0 holds the complete sum afterwards.
    sum_sq = block_reduce_sum(sum_sq, shared);

    // Thread 0 derives 1/sqrt(mean(x^2) + eps) and shares it with the block.
    __shared__ acc_t rms_inv;
    if (tid == 0) {
        acc_t mean_sq = sum_sq / static_cast<acc_t>(hidden_size);
        rms_inv = rsqrtf(mean_sq + eps);
    }
    __syncthreads();

    // Read the shared factor once into a register instead of once per element.
    const acc_t factor = rms_inv;

    // Apply normalization and weight.
    for (int i = tid; i < hidden_size; i += stride) {
        acc_t val = to_float(row_input[i]);
        acc_t w = to_float(weight[i]);
        row_output[i] = from_float(val * factor * w, (scalar_t*)nullptr);
    }
}
281
+
282
+ // ============================================================================
283
+ // Launch functions
284
+ // ============================================================================
285
+ extern "C" {
286
+
287
// Launches RMSNorm over a (batch_size * seq_len) x hidden_size FP16 tensor on
// `stream`, one block per row.
//
//   output, input - device buffers of batch_size * seq_len * hidden_size halves.
//   weight        - device buffer of hidden_size halves.
//   eps           - value added to mean(x^2) before the reciprocal square root.
//
// Uses the __half2 kernel when hidden_size is even, at least 64, and all three
// pointers are __half2-aligned; otherwise falls back to the scalar kernel.
// Returns without launching when any dimension is non-positive. The launch is
// asynchronous and its status is not checked here.
void rmsnorm_forward_fp16(
    __half* output,
    const __half* input,
    const __half* weight,
    const int batch_size,
    const int seq_len,
    const int hidden_size,
    const float eps,
    cudaStream_t stream
) {
    // Nothing to do for empty input; a zero-sized grid or block would be
    // rejected as an invalid launch configuration.
    if (batch_size <= 0 || seq_len <= 0 || hidden_size <= 0) {
        return;
    }

    const int num_rows = batch_size * seq_len;

    // The vectorized kernel reinterprets all three buffers as __half2, which
    // is only valid when each base address is a multiple of sizeof(__half2).
    const size_t vec_align_mask = sizeof(__half2) - 1;
    const size_t combined_addr = reinterpret_cast<size_t>(output) |
                                 reinterpret_cast<size_t>(input) |
                                 reinterpret_cast<size_t>(weight);
    const bool use_vectorized = (hidden_size % 2 == 0) &&
                                (hidden_size >= 64) &&
                                ((combined_addr & vec_align_mask) == 0);

    // One thread per element pair (vectorized) or per element (scalar),
    // capped at MAX_THREADS and rounded up to a whole number of warps.
    int threads = use_vectorized ? min(hidden_size / 2, MAX_THREADS)
                                 : min(hidden_size, MAX_THREADS);
    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;

    // One float of reduction scratch per warp. This is derived from the final
    // thread count so the scalar path never gets a smaller buffer than it needs.
    const size_t smem_size = static_cast<size_t>(threads / WARP_SIZE) * sizeof(float);

    if (use_vectorized) {
        rmsnorm_kernel_fp16_vectorized<<<num_rows, threads, smem_size, stream>>>(
            output, input, weight, hidden_size, eps
        );
    } else {
        rmsnorm_kernel<__half><<<num_rows, threads, smem_size, stream>>>(
            output, input, weight, hidden_size, eps
        );
    }
}
316
+
317
// Launches RMSNorm over a (batch_size * seq_len) x hidden_size BF16 tensor on
// `stream`, one block per row.
//
//   output, input - device buffers of batch_size * seq_len * hidden_size
//                   bf16 values.
//   weight        - device buffer of hidden_size bf16 values.
//   eps           - value added to mean(x^2) before the reciprocal square root.
//
// Uses the __nv_bfloat162 kernel when hidden_size is even, at least 64, and
// all three pointers are __nv_bfloat162-aligned; otherwise falls back to the
// scalar kernel. Returns without launching when any dimension is
// non-positive. The launch is asynchronous and its status is not checked here.
void rmsnorm_forward_bf16(
    __nv_bfloat16* output,
    const __nv_bfloat16* input,
    const __nv_bfloat16* weight,
    const int batch_size,
    const int seq_len,
    const int hidden_size,
    const float eps,
    cudaStream_t stream
) {
    // Nothing to do for empty input; a zero-sized grid or block would be
    // rejected as an invalid launch configuration.
    if (batch_size <= 0 || seq_len <= 0 || hidden_size <= 0) {
        return;
    }

    const int num_rows = batch_size * seq_len;

    // The vectorized kernel reinterprets all three buffers as __nv_bfloat162,
    // which is only valid when each base address is a multiple of
    // sizeof(__nv_bfloat162).
    const size_t vec_align_mask = sizeof(__nv_bfloat162) - 1;
    const size_t combined_addr = reinterpret_cast<size_t>(output) |
                                 reinterpret_cast<size_t>(input) |
                                 reinterpret_cast<size_t>(weight);
    const bool use_vectorized = (hidden_size % 2 == 0) &&
                                (hidden_size >= 64) &&
                                ((combined_addr & vec_align_mask) == 0);

    // One thread per element pair (vectorized) or per element (scalar),
    // capped at MAX_THREADS and rounded up to a whole number of warps.
    int threads = use_vectorized ? min(hidden_size / 2, MAX_THREADS)
                                 : min(hidden_size, MAX_THREADS);
    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;

    // One float of reduction scratch per warp. This is derived from the final
    // thread count so the scalar path never gets a smaller buffer than it needs.
    const size_t smem_size = static_cast<size_t>(threads / WARP_SIZE) * sizeof(float);

    if (use_vectorized) {
        rmsnorm_kernel_bf16_vectorized<<<num_rows, threads, smem_size, stream>>>(
            output, input, weight, hidden_size, eps
        );
    } else {
        rmsnorm_kernel<__nv_bfloat16><<<num_rows, threads, smem_size, stream>>>(
            output, input, weight, hidden_size, eps
        );
    }
}
346
+
347
// Launches RMSNorm over a (batch_size * seq_len) x hidden_size FP32 tensor on
// `stream`, one block per row, using the scalar kernel.
//
//   output, input - device buffers of batch_size * seq_len * hidden_size floats.
//   weight        - device buffer of hidden_size floats.
//   eps           - value added to mean(x^2) before the reciprocal square root.
//
// Returns without launching when any dimension is non-positive. The launch is
// asynchronous and its status is not checked here.
void rmsnorm_forward_fp32(
    float* output,
    const float* input,
    const float* weight,
    const int batch_size,
    const int seq_len,
    const int hidden_size,
    const float eps,
    cudaStream_t stream
) {
    // Nothing to do for empty input; a zero-sized grid or block would be
    // rejected as an invalid launch configuration.
    if (batch_size <= 0 || seq_len <= 0 || hidden_size <= 0) {
        return;
    }

    const int num_rows = batch_size * seq_len;

    // One thread per element, capped at MAX_THREADS and rounded up to a whole
    // number of warps.
    int threads = min(hidden_size, MAX_THREADS);
    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;

    // One float of reduction scratch per warp in the block.
    const size_t smem_size = static_cast<size_t>(threads / WARP_SIZE) * sizeof(float);

    rmsnorm_kernel<float><<<num_rows, threads, smem_size, stream>>>(
        output, input, weight, hidden_size, eps
    );
}
367
+
368
+ } // extern "C"
369
+