medmekk commited on
Commit
9567a9b
·
verified ·
1 Parent(s): a754cc8

Delete gemm_kernel_legacy.h

Browse files
Files changed (1) hide show
  1. gemm_kernel_legacy.h +0 -377
gemm_kernel_legacy.h DELETED
@@ -1,377 +0,0 @@
1
- // Legacy version of gemm kernel, support all shape and various value of parameters (BM, BN, BK, etc.)
2
- // It has been replace with faster pipeline version.
3
- #pragma once
4
- #include <cstdio>
5
- #include "../include/gpu_libs.h"
6
- #include "../include/gpu_types.h"
7
- #include "../src/utils/arithmetic.h"
8
- #include "../include/clangd_workaround.h"
9
- #include <cstdlib>
10
- #include <cfloat>
11
-
12
- DEVICE_CODE_BELOW
13
- namespace gemm_kernel_legacy {
14
-
15
-
16
-
17
- template <typename data_type, int BATCH_SIZE>
18
- __device__ inline void read_batch(data_type *dst, const data_type *src) {
19
- if constexpr ((sizeof(data_type) * BATCH_SIZE) == 2 * sizeof(ulong4)) {
20
- *(reinterpret_cast<ulong4 *>(dst) + 0) = *(reinterpret_cast<const ulong4 *>(src) + 0);
21
- *(reinterpret_cast<ulong4 *>(dst) + 1) = *(reinterpret_cast<const ulong4 *>(src) + 1);
22
- } else if constexpr ((sizeof(data_type) * BATCH_SIZE) == sizeof(ulong4)) {
23
- *reinterpret_cast<ulong4 *>(dst) = *reinterpret_cast<const ulong4 *>(src);
24
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong2)) {
25
- *reinterpret_cast<ulong2 *>(dst) = *reinterpret_cast<const ulong2 *>(src);
26
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong1)) {
27
- *reinterpret_cast<ulong1 *>(dst) = *reinterpret_cast<const ulong1 *>(src);
28
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(uint1)) {
29
- *reinterpret_cast<uint1 *>(dst) = *reinterpret_cast<const uint1 *>(src);
30
- } else {
31
- #pragma unroll
32
- for (int b = 0; b < BATCH_SIZE; ++b) {
33
- dst[b] = src[b];
34
- }
35
- }
36
- }
37
-
38
- template <typename data_type, int BATCH_SIZE>
39
- __device__ inline void zero_batch(data_type *dst) {
40
- if constexpr ((sizeof(data_type) * BATCH_SIZE) == sizeof(ulong4)) {
41
- *reinterpret_cast<ulong4 *>(dst) = ulong4{};
42
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong2)) {
43
- *reinterpret_cast<ulong2 *>(dst) = ulong2{};
44
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong1)) {
45
- *reinterpret_cast<ulong1 *>(dst) = ulong1{};
46
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(uint1)) {
47
- *reinterpret_cast<uint *>(dst) = uint{};
48
- } else {
49
- #pragma unroll
50
- for (int b = 0; b < BATCH_SIZE; ++b) {
51
- dst[b] = 0;
52
- }
53
- }
54
- }
55
-
56
- template <typename data_type, int DST_Y, int DST_X, int SRC_Y, int SRC_X, int BLOCK_DIM, int BATCH_SIZE>
57
- __device__ inline void load_input(data_type dst[DST_Y][DST_X], const data_type src[SRC_Y][SRC_X],
58
- const int begin_x, const int begin_y) {
59
- static_assert(BATCH_SIZE > 0);
60
- /**
61
- Consider (SRC_X % DST_X == 0) && (SRC_Y % DST_Y == 0)
62
- Step 1:
63
- [ ][***][ ][ ]
64
- [ ][ ][ ][ ]
65
- [ ][ ][ ][ ]
66
- [ ][ ][ ][ ]
67
- Step 2:
68
- [ ][ ][ ][ ]
69
- [ ][***][ ][ ]
70
- [ ][ ][ ][ ]
71
- [ ][ ][ ][ ]
72
- */
73
- static_assert((SRC_X % BATCH_SIZE == 0) && (SRC_Y % BATCH_SIZE == 0));
74
- static_assert((DST_X % BATCH_SIZE == 0) && (DST_Y % BATCH_SIZE == 0));
75
- static_assert(BATCH_SIZE <= DST_X && DST_X % BATCH_SIZE == 0);
76
- const int begin_idx = threadIdx.x * BATCH_SIZE;
77
- const constexpr int total_elements = DST_X * DST_Y;
78
- const constexpr int elements_per_step = BLOCK_DIM * BATCH_SIZE;
79
- // FIXME: loop unrolling
80
- #pragma unroll
81
- for (int k = begin_idx; k < total_elements; k += elements_per_step) {
82
- int l_kx = k % DST_X;
83
- int l_ky = k / DST_X;
84
- int g_kx = l_kx + begin_x;
85
- int g_ky = l_ky + begin_y;
86
- auto *dst_flatten = &dst[l_ky][l_kx];
87
- // const auto *src_flatten = &src[g_ky][g_kx];
88
- // read_batch<data_type, BATCH_SIZE>(dst_flatten, src_flatten);
89
- if (((SRC_X % DST_X == 0) || (g_kx < SRC_X)) && ((SRC_Y % DST_Y == 0) || (g_ky < SRC_Y))) {
90
- const auto *src_flatten = &src[g_ky][g_kx];
91
- read_batch<data_type, BATCH_SIZE>(dst_flatten, src_flatten);
92
- } else {
93
- zero_batch<data_type, BATCH_SIZE>(dst_flatten);
94
- }
95
- }
96
- }
97
-
98
- template <int PM, int PN, int QM, int QN, int QK, int QUANT_SIZE, int BLOCK_SIZE, int BATCH_SIZE>
99
- __device__ void load_scale(float s_s[PM][PN], const float sa[QK][QM], const float sb[QK][QN],
100
- const int m, const int n, const int k) {
101
- constexpr int total_elements = PM * PN;
102
- constexpr int elements_per_step = BLOCK_SIZE * BATCH_SIZE;
103
- // static_assert(PN % BATCH_SIZE)
104
-
105
- const int begin_idx = threadIdx.x * BATCH_SIZE;
106
- #pragma unroll
107
- for (int idx = begin_idx; idx < total_elements; idx += elements_per_step) {
108
- static_assert(BATCH_SIZE == 1);
109
- int i = idx / PN;
110
- int j = idx % PN;
111
- if (((QM % PM == 0) || (m + i < QM)) && ((QN % PN == 0) || ((n + j) / QUANT_SIZE < QN))) {
112
- s_s[i][j] = sa[k / QUANT_SIZE][(m + i)] * sb[k / QUANT_SIZE][(n) / QUANT_SIZE + j];
113
- } else {
114
- s_s[i][j] = 1.0f;
115
- }
116
- }
117
-
118
- }
119
-
120
- template <typename in_data_type, typename acc_data_type,
121
- typename FragC, typename FragA, typename FragB,
122
- int PM, int PN,
123
- int BM, int BN, int BK,
124
- int FRAG_M, int FRAG_N, int FRAG_K,
125
- int WMMA_M, int WMMA_N, int WMMA_K,
126
- int WARP_M, int WARP_N,
127
- int BLOCK_SIZE, int BATCH_SIZE, int QUANT_SIZE>
128
- __device__ void wmma_compute(
129
- const in_data_type s_a[BK][BM],
130
- const in_data_type s_b[BK][BN],
131
- const float s_s[PM][PN],
132
- FragC frag_r[FRAG_M][FRAG_N],
133
- const int comp_c_frag_m,
134
- const int comp_c_frag_n
135
- ) {
136
- FragA frag_a[FRAG_K][FRAG_M];
137
- FragB frag_b[FRAG_K][FRAG_N];
138
-
139
- // Spilt k over BK
140
- for (int k = 0; k < FRAG_K; ++k) {
141
- #pragma unroll
142
- for (int i = 0; i < FRAG_M; ++i) {
143
- int s_a_row = k * WMMA_K;
144
- int s_a_col = (comp_c_frag_m * FRAG_M + i) * WMMA_M;
145
- wmma::load_matrix_sync(frag_a[k][i], &s_a[s_a_row][s_a_col], BM);
146
- }
147
- #pragma unroll
148
- for (int j = 0; j < FRAG_N; ++j) {
149
- int s_b_row = k * WMMA_K;
150
- int s_b_col = (comp_c_frag_n * FRAG_N + j) * WMMA_N;
151
- wmma::load_matrix_sync(frag_b[k][j], &s_b[s_b_row][s_b_col], BN);
152
- }
153
- }
154
-
155
- #pragma unroll
156
- for (int i = 0; i < FRAG_M; i++) {
157
- #pragma unroll
158
- for (int j = 0; j < FRAG_N; j++) {
159
- FragC frag_c;
160
- wmma::fill_fragment(frag_c, 0.0F);
161
- #pragma unroll
162
- for (int k = 0; k < FRAG_K; ++k) {
163
- wmma::mma_sync(frag_c, frag_a[k][i], frag_b[k][j], frag_c);
164
- }
165
- #pragma unroll
166
- for (int k = 0; k < FragC::num_elements; ++k) {
167
- #ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
168
- int m = ((threadIdx.x & 16) >> 1) | (k & 7) | (comp_c_frag_m * FRAG_M + i) * WMMA_M;
169
- #else // CDNA3, WAVE_SIZE = 64
170
- int m = ((threadIdx.x & 48) >> 2) | (k & 3) | (comp_c_frag_m * FRAG_M + i) * WMMA_M;
171
- #endif
172
- int n = ((threadIdx.x & 15) | (comp_c_frag_n * FRAG_N + j) * WMMA_N) / QUANT_SIZE;
173
- float scale = s_s[m][n];
174
- frag_r[i][j].x[k] += (acc_data_type)scale * (acc_data_type)frag_c.x[k];
175
- }
176
- }
177
- }
178
- }
179
-
180
-
181
- template <typename acc_data_type, typename out_data_type,
182
- typename FragC, typename FragOut, int WMMA_M, int WMMA_N,
183
- int BM, int BN, int M, int N, int FRAG_M, int FRAG_N>
184
- __device__ inline void store_result(
185
- out_data_type c[M][N],
186
- FragC frag_r[FRAG_M][FRAG_N],
187
- const int m,
188
- const int n,
189
- const int comp_c_frag_m,
190
- const int comp_c_frag_n
191
- ) {
192
- #pragma unroll
193
- for (int i = 0; i < FRAG_M; i++) {
194
- #pragma unroll
195
- for (int j = 0; j < FRAG_N; j++) {
196
- int frag_m = comp_c_frag_m * FRAG_M + i;
197
- int frag_n = comp_c_frag_n * FRAG_N + j;
198
- int row = m + frag_m * WMMA_M;
199
- int col = n + frag_n * WMMA_N;
200
- if (((M % BM == 0) || (row < M)) && ((N % BN == 0) || (col < N))) {
201
- out_data_type *c_ptr = &c[row][col];
202
- if constexpr (sizeof(acc_data_type) == sizeof(out_data_type)) {
203
- wmma::store_matrix_sync(reinterpret_cast<out_data_type*>(c_ptr), frag_r[i][j], N, wmma::mem_row_major);
204
- } else if constexpr (sizeof(out_data_type) == sizeof(half)) {
205
- FragOut frag_out;
206
- static_assert(sizeof(half) == sizeof(out_data_type));
207
- static_assert(FragOut::num_elements == FragC::num_elements);
208
- for (int k=0;k<FragOut::num_elements;++k) {
209
- __hip_bfloat16 reg = frag_r[i][j].x[k];
210
- frag_out.x[k] = *reinterpret_cast<half*>(&reg);
211
- }
212
- wmma::store_matrix_sync(reinterpret_cast<half*>(c_ptr), frag_out, N, wmma::mem_row_major);
213
- } else {
214
- static_assert(0, "Unsupported data type for output");
215
- }
216
-
217
- }
218
- }
219
- }
220
- }
221
-
222
- // a dummy template to allow inlcuding this file
223
- template<int Dummy=0>
224
- __global__ void reduce(uint32_t m, uint32_t n, uint32_t splitk, const float *c_splitk, __hip_bfloat16 *c) {
225
- auto tid = blockIdx.x * blockDim.x + threadIdx.x;
226
- if (tid >= m * n) {
227
- return;
228
- }
229
- float sum = 0;
230
- for (auto i = 0; i < splitk; ++i) {
231
- sum += c_splitk[i * (m * n) + tid];
232
- }
233
- c[tid] = sum;
234
- }
235
-
236
-
237
- #ifdef PARAMETERIZE_LIBRARY
238
- template <
239
- typename in_data_type,
240
- typename acc_data_type, // Accumulator type (e.g., float)
241
- typename out_data_type, // Output type (e.g., __hip_bfloat16)
242
- int M, int N, int K, // Matrix dimensions
243
- int BM, int BN, int BK, // Tile dimensions
244
- int QUANT_SIZE, // Quantization block size
245
- int BLOCK_SIZE, // Block size
246
- int WARP_M, int WARP_N // Warp dimensions
247
- >
248
- #else
249
- using in_data_type = __FP8_TYPE;
250
- using out_data_type = __BF16_TYPE;
251
- using acc_data_type = float;
252
- // constexpr int M = 4096, N = 4096, K = 4096;
253
- constexpr int M = 96, N = 1024, K = 1024;
254
- // constexpr int M = 512, N = 512, K = 512;
255
- constexpr int BM = 64, BN = 256, BK = 32;
256
- constexpr int QUANT_SIZE = 128, BLOCK_SIZE = 256;
257
- #ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
258
- constexpr int WARP_M = 4, WARP_N = 2;
259
- #else // CDNA3, WAVE_SIZE = 64
260
- constexpr int WARP_M = 2, WARP_N = 2;
261
- #endif
262
- #endif // End of parameterization
263
- __global__ void gemm_kernel(
264
- const in_data_type a[K][M],
265
- const in_data_type b[K][N],
266
- out_data_type c[M][N],
267
- const float sa[ceil_div(K, QUANT_SIZE)][M / 1 ], // Assuming M is divisible by 1 (always true)
268
- const float sb[ceil_div(K, QUANT_SIZE)][ceil_div(N, QUANT_SIZE)]
269
- ) {
270
- // --- Start: Derived parameters and constants ---
271
- constexpr int WMMA_M = 16; // Fixed WMMA dimension M
272
- constexpr int WMMA_N = 16; // Fixed WMMA dimension N
273
- constexpr int WMMA_K = 32; // Fixed WMMA dimension K (for FP8)
274
-
275
- // WARP_M/N define the 2D arrangement of warps in the block grid.
276
- // These might need adjustment based on BLOCK_DIM_X/Y strategy.
277
- // Using fixed values based on the non-parameterized version for now.
278
- // TODO: Derive WARP_M/N from BLOCK_DIM_X/Y if a flexible strategy is needed.
279
- constexpr int WARP_NUM = WARP_M * WARP_N; // Total warps per block
280
-
281
- // Assertion: Check if the assumed warp layout matches the block size
282
- static_assert(WARP_NUM * WAVE_SIZE == BLOCK_SIZE, "WARP_M * WARP_N * WAVE_SIZE must equal BLOCK_SIZE");
283
-
284
- // Fragments per warp
285
- constexpr int FRAG_M_PER_WARP = BM / WMMA_M / WARP_M;
286
- constexpr int FRAG_N_PER_WARP = BN / WMMA_N / WARP_N;
287
- constexpr int FRAG_K = BK / WMMA_K; // Fragments along K dimension tile
288
-
289
- static_assert(BM % (WMMA_M * WARP_M) == 0, "BM must be divisible by WMMA_M * WARP_M");
290
- static_assert(BN % (WMMA_N * WARP_N) == 0, "BN must be divisible by WMMA_N * WARP_N");
291
- static_assert(BK % WMMA_K == 0, "BK must be divisible by WMMA_K");
292
- static_assert(BK >= 32, "BK must be at least 32");
293
- // --- End: Derived parameters and constants ---
294
-
295
- constexpr int QM = M; // Dimension M for scale A
296
- constexpr int QN = ceil_div(N, QUANT_SIZE); // Dimension N for scale B (quantized)
297
- constexpr int QK = ceil_div(K, QUANT_SIZE); // Dimension K for scales (quantized)
298
- constexpr int PM = BM; // Block size M for scale A * B
299
- constexpr int PN = ceil_div(BN, QUANT_SIZE); // Block size N for scale A * B
300
-
301
- // Ensure derived fragment counts are positive
302
- static_assert(FRAG_M_PER_WARP > 0, "FRAG_M_PER_WARP must be positive");
303
- static_assert(FRAG_N_PER_WARP > 0, "FRAG_N_PER_WARP must be positive");
304
- static_assert(FRAG_K > 0, "FRAG_K must be positive");
305
-
306
- using FragA = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, in_data_type, wmma::col_major>;
307
- using FragB = wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, in_data_type, wmma::row_major>;
308
- using FragC = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, acc_data_type>;
309
- using FragOut = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half>; // Output uses half for storage via bfloat16 reinterpret
310
-
311
- __shared__ in_data_type s_a[BK][BM];
312
- __shared__ in_data_type s_b[BK][BN];
313
- __shared__ acc_data_type s_s[PM][PN]; // Accumulator type for scales
314
- FragC frag_r[FRAG_M_PER_WARP][FRAG_N_PER_WARP]; // Accumulator fragments
315
-
316
- // handle splitk
317
- a += blockIdx.z * K;
318
- b += blockIdx.z * K;
319
- c += blockIdx.z * M;
320
- sa += blockIdx.z * QK;
321
- sb += blockIdx.z * QK;
322
-
323
- int tid = threadIdx.x; // Linear thread ID within the block
324
- int wid = tid / WAVE_SIZE; // Warp ID within the block
325
-
326
- // Initialize the output accumulator fragments to zero
327
- #pragma unroll
328
- for (int i = 0; i < FRAG_M_PER_WARP; i++) {
329
- #pragma unroll
330
- for (int j = 0; j < FRAG_N_PER_WARP; j++) {
331
- wmma::fill_fragment(frag_r[i][j], 0.0f); // Use float literal
332
- }
333
- }
334
-
335
- // Spilt and compute fragments
336
- constexpr int iteration_over_k = ceil_div(K, BK); // Use ceil_div for potentially non-divisible K
337
- constexpr int LOAD_BATCH_SIZE = (2 * sizeof(float4) / sizeof(in_data_type)) > 0 ? (2 * sizeof(float4) / sizeof(in_data_type)) : 1; // Ensure batch size > 0
338
- static_assert(LOAD_BATCH_SIZE > 0, "LOAD_BATCH_SIZE must be positive");
339
-
340
- for (int bk = 0; bk < iteration_over_k; bk++) {
341
- const int m = blockIdx.y * BM;
342
- const int n = blockIdx.x * BN;
343
- const int k = bk * BK;
344
-
345
- // Calculate remaining K for boundary checks if needed (not currently used by load_input)
346
- // const int k_rem = K - k;
347
-
348
- // Load data into shared memory
349
- load_input<in_data_type, BK, BM, K, M, BLOCK_SIZE, LOAD_BATCH_SIZE>(
350
- s_a, a, m, k);
351
- load_input<in_data_type, BK, BN, K, N, BLOCK_SIZE, LOAD_BATCH_SIZE>(
352
- s_b, b, n, k);
353
- // Load scales into shared memory (using acc_data_type for s_s)
354
- load_scale<PM, PN, QM, QN, QK, QUANT_SIZE, BLOCK_SIZE, 1>(
355
- s_s, sa, sb, m, n, k);
356
- __syncthreads();
357
-
358
- // Perform matrix multiplication using WMMA
359
- wmma_compute<in_data_type, acc_data_type, FragC, FragA, FragB,
360
- PM, PN, BM, BN, BK, FRAG_M_PER_WARP, FRAG_N_PER_WARP, FRAG_K,
361
- WMMA_M, WMMA_N, WMMA_K,
362
- WARP_M, WARP_N,
363
- BLOCK_SIZE, LOAD_BATCH_SIZE, QUANT_SIZE>( // Pass calculated BLOCK_SIZE and LOAD_BATCH_SIZE
364
- s_a, s_b, s_s, frag_r, wid / WARP_N, wid % WARP_N);
365
- __syncthreads();
366
- }
367
- // Store results from accumulator fragments to global memory
368
- store_result<acc_data_type, out_data_type, FragC, FragOut,
369
- WMMA_M, WMMA_N, BM, BN, M, N, FRAG_M_PER_WARP, FRAG_N_PER_WARP>(
370
- c, frag_r, blockIdx.y * BM, blockIdx.x * BN,
371
- wid / WARP_N, wid % WARP_N);
372
-
373
-
374
- };
375
-
376
-
377
- }; // namespace gemm_kernel_legacy