medmekk commited on
Commit
a754cc8
·
verified ·
1 Parent(s): 3370f0d

Delete gemm_kernel.h

Browse files
Files changed (1) hide show
  1. gemm_kernel.h +0 -896
gemm_kernel.h DELETED
@@ -1,896 +0,0 @@
1
- // Pipeline GEMM kernel. This version is rushed written and may not applied to all shape.
2
- // Currently, only selected parameters is tested. (See gemm_launcher )
3
- #ifndef GEMM_KERNEL
4
- #define GEMM_KERNEL
5
-
6
- #include <cstdio>
7
- #include <hip/amd_detail/amd_hip_runtime.h>
8
- #include <hip/amd_detail/amd_warp_functions.h>
9
- #pragma clang diagnostic push
10
- #pragma clang diagnostic ignored "-Wunknown-attributes"
11
- #include "../include/gpu_libs.h"
12
- #include "../include/gpu_types.h"
13
- #include "../src/utils/arithmetic.h"
14
- #include "../include/clangd_workaround.h"
15
- #include <cstdlib>
16
- #include <cfloat>
17
-
18
- namespace gemm_kernel {
19
-
20
- template <typename data_type, int BATCH_SIZE> __device__ inline void read_batch(data_type *dst, const data_type *src) {
21
- if constexpr ((sizeof(data_type) * BATCH_SIZE) == 2 * sizeof(ulong4)) {
22
- *(reinterpret_cast<ulong4 *>(dst) + 0) = *(reinterpret_cast<const ulong4 *>(src) + 0);
23
- *(reinterpret_cast<ulong4 *>(dst) + 1) = *(reinterpret_cast<const ulong4 *>(src) + 1);
24
- } else if constexpr ((sizeof(data_type) * BATCH_SIZE) == sizeof(ulong4)) {
25
- *reinterpret_cast<ulong4 *>(dst) = *reinterpret_cast<const ulong4 *>(src);
26
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong2)) {
27
- *reinterpret_cast<ulong2 *>(dst) = *reinterpret_cast<const ulong2 *>(src);
28
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong1)) {
29
- *reinterpret_cast<ulong1 *>(dst) = *reinterpret_cast<const ulong1 *>(src);
30
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(uint1)) {
31
- *reinterpret_cast<uint1 *>(dst) = *reinterpret_cast<const uint1 *>(src);
32
- } else {
33
- #pragma unroll
34
- for (int b = 0; b < BATCH_SIZE; ++b) {
35
- dst[b] = src[b];
36
- }
37
- }
38
- }
39
-
40
- template <typename data_type, int BATCH_SIZE> __device__ inline void zero_batch(data_type *dst) {
41
- if constexpr ((sizeof(data_type) * BATCH_SIZE) == sizeof(ulong4)) {
42
- *reinterpret_cast<ulong4 *>(dst) = ulong4{};
43
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong2)) {
44
- *reinterpret_cast<ulong2 *>(dst) = ulong2{};
45
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong1)) {
46
- *reinterpret_cast<ulong1 *>(dst) = ulong1{};
47
- } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(uint1)) {
48
- *reinterpret_cast<uint *>(dst) = uint{};
49
- } else {
50
- #pragma unroll
51
- for (int b = 0; b < BATCH_SIZE; ++b) {
52
- dst[b] = 0;
53
- }
54
- }
55
- }
56
-
57
// Cooperatively copy a DST_Y x DST_X tile out of the global array src
// (SRC_Y x SRC_X) into dst, with the tile's top-left corner at
// (begin_y, begin_x). Each of the BLOCK_DIM threads moves BATCH_SIZE
// contiguous elements per step; positions past the source extent are
// zero-filled (the bound check compiles away when the extents divide evenly).
template <typename data_type, int DST_Y, int DST_X, int SRC_Y, int SRC_X, int BLOCK_DIM, int BATCH_SIZE>
__device__ inline void load_input(data_type dst[DST_Y][DST_X], const data_type src[SRC_Y][SRC_X], const int begin_x,
                                  const int begin_y) {
    static_assert(BATCH_SIZE > 0);
    /**
    Consider (SRC_X % DST_X == 0) && (SRC_Y % DST_Y == 0)
    Step 1:
    [ ][***][ ][ ]
    [ ][   ][ ][ ]
    [ ][   ][ ][ ]
    [ ][   ][ ][ ]
    Step 2:
    [ ][   ][ ][ ]
    [ ][***][ ][ ]
    [ ][   ][ ][ ]
    [ ][   ][ ][ ]
    */
    static_assert((SRC_X % BATCH_SIZE == 0) && (SRC_Y % BATCH_SIZE == 0));
    static_assert((DST_X % BATCH_SIZE == 0) && (DST_Y % BATCH_SIZE == 0));
    static_assert(BATCH_SIZE <= DST_X && DST_X % BATCH_SIZE == 0);
    // Each thread starts at its own BATCH_SIZE-aligned offset in the flattened tile.
    const int begin_idx = threadIdx.x * BATCH_SIZE;
    const constexpr int total_elements = DST_X * DST_Y;
    const constexpr int elements_per_step = BLOCK_DIM * BATCH_SIZE;
    // FIXME: loop unrolling
#pragma unroll
    for (int k = begin_idx; k < total_elements; k += elements_per_step) {
        int l_kx = k % DST_X; // local (tile) coordinates
        int l_ky = k / DST_X;
        int g_kx = l_kx + begin_x; // global (source) coordinates
        int g_ky = l_ky + begin_y;
        auto *dst_flatten = &dst[l_ky][l_kx];
        // const auto *src_flatten = &src[g_ky][g_kx];
        // read_batch<data_type, BATCH_SIZE>(dst_flatten, src_flatten);
        // If the source divides the tile evenly, the range check is statically true.
        if (((SRC_X % DST_X == 0) || (g_kx < SRC_X)) && ((SRC_Y % DST_Y == 0) || (g_ky < SRC_Y))) {
            const auto *src_flatten = &src[g_ky][g_kx];
            read_batch<data_type, BATCH_SIZE>(dst_flatten, src_flatten);
        } else {
            // Past the source extent: pad the tile with zeros.
            zero_batch<data_type, BATCH_SIZE>(dst_flatten);
        }
    }
}
98
-
99
// Build the combined dequant-scale table for one K quant-block:
//   s_s[i][j] = sa[k/QUANT_SIZE][m + i] * sb[k/QUANT_SIZE][n/QUANT_SIZE + j],
// with out-of-range entries set to the neutral scale 1.0f.
// NOTE(review): currently unused — the call site in gemm_kernel is commented
// out (the kernel builds s_s in its reg2lds lambda instead), and only
// BATCH_SIZE == 1 is supported (static_assert below).
template <int PM, int PN, int QM, int QN, int QK, int QUANT_SIZE, int BLOCK_SIZE, int BATCH_SIZE>
__device__ void load_scale(float s_s[PM][PN], const float sa[QK][QM], const float sb[QK][QN], const int m, const int n,
                           const int k) {
    constexpr int total_elements = PM * PN;
    constexpr int elements_per_step = BLOCK_SIZE * BATCH_SIZE;
    // static_assert(PN % BATCH_SIZE)

    const int begin_idx = threadIdx.x * BATCH_SIZE;
#pragma unroll
    for (int idx = begin_idx; idx < total_elements; idx += elements_per_step) {
        static_assert(BATCH_SIZE == 1);
        int i = idx / PN; // row within the PM x PN scale tile
        int j = idx % PN; // quantized column within the tile
        if (((QM % PM == 0) || (m + i < QM)) && ((QN % PN == 0) || ((n + j) / QUANT_SIZE < QN))) {
            s_s[i][j] = sa[k / QUANT_SIZE][(m + i)] * sb[k / QUANT_SIZE][(n) / QUANT_SIZE + j];
        } else {
            // Neutral scale for padded entries.
            s_s[i][j] = 1.0f;
        }
    }
}
119
-
120
// Read the GPU's 64-bit free-running timestamp via s_memtime.
// don't use __builtin_readcyclecounter(), which would insert waitcnt
// NOTE(review): s_memtime returns its result in a scalar register pair —
// confirm the "=r" constraint maps to SGPRs on the targeted backend, and that
// s_memtime exists on the target ISA (it is CDNA-era; not present everywhere).
__device__ auto getclock() {
    uint64_t clk;
    asm volatile("s_memtime %0" : "=r"(clk));
    return clk;
}
126
-
127
-
128
- template <typename Elem> __global__ void check_trans(const Elem *origin, const Elem *tranposed, int m, int n) {
129
- auto x = threadIdx.x + blockIdx.x * blockDim.x;
130
- auto y = threadIdx.y + blockIdx.y * blockDim.y;
131
- if (x < m && y < n) {
132
- if (origin[x * n + y] != tranposed[y * m + x]) {
133
- printf("Error: %d %d\n", x, y);
134
- }
135
- }
136
- }
137
-
138
// One BK-deep step of the block GEMM. Loads A/B fragments from the padded
// shared-memory tiles, accumulates their WMMA products into a fresh local
// accumulator frag_c, then folds frag_c into the running accumulator frag_r,
// multiplying each element by the dequant scale looked up from s_s via an
// explicit lane-register -> (m, n) mapping. comp_c_frag_m/comp_c_frag_n select
// which warp-tile of the block tile this warp owns.
template <typename in_data_type, typename acc_data_type, typename FragC, typename FragA, typename FragB, int PM, int PN,
          int BM, int BN, int BK, int FRAG_M, int FRAG_N, int FRAG_K, int WMMA_M, int WMMA_N, int WMMA_K, int WARP_M,
          int WARP_N, int BLOCK_SIZE, int BATCH_SIZE, int QUANT_SIZE>
__device__ void wmma_compute(const in_data_type s_a[BM][BK + 8], const in_data_type s_b[BN][BK + 8],
                             const float s_s[PN][PM], FragC frag_r[FRAG_M][FRAG_N], const int comp_c_frag_m,
                             const int comp_c_frag_n) {
    FragC frag_c[FRAG_M][FRAG_N];

    // Zero the per-call accumulator; frag_r carries the scaled running sum.
#pragma unroll
    for (int i = 0; i < FRAG_M; i++) {
#pragma unroll
        for (int j = 0; j < FRAG_N; j++) {
            wmma::fill_fragment(frag_c[i][j], 0.0F);
        }
    }

#pragma unroll
    for (int k = 0; k < FRAG_K; ++k) {
#pragma unroll
        for (int i = 0; i < FRAG_M; i++) {
            FragA frag_a;
            int s_a_row = k * WMMA_K;
            int s_a_col = (comp_c_frag_m * FRAG_M + i) * WMMA_M;
            // Leading dimension is BK + 8: the shared tiles carry 8 elements of padding.
            wmma::load_matrix_sync(frag_a, &s_a[s_a_col][s_a_row], BK + 8);
#pragma unroll
            for (int j = 0; j < FRAG_N; j++) {
                FragB frag_b;
                int s_b_row = k * WMMA_K;
                int s_b_col = (comp_c_frag_n * FRAG_N + j) * WMMA_N;
                wmma::load_matrix_sync(frag_b, &s_b[s_b_col][s_b_row], BK + 8);

                wmma::mma_sync(frag_c[i][j], frag_a, frag_b, frag_c[i][j]);
            }
        }
    }
    // Apply the per-(row, quant-column) scale to every accumulator register.
#pragma unroll
    for (int i = 0; i < FRAG_M; i++) {
#pragma unroll
        for (int j = 0; j < FRAG_N; j++) {
#pragma unroll
            for (int k = 0; k < FragC::num_elements; ++k) {
#ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
                // NOTE(review): this `int m` clashes with the `int m, n;`
                // declared below when TEST_ON_RDNA4 is defined — stale branch?
                int m = ((threadIdx.x & 16) >> 1) | (k & 7) | (comp_c_frag_m * FRAG_M + i) * WMMA_M;
#else // CDNA3, WAVE_SIZE = 64
                // int m = ((threadIdx.x & 48) >> 2) | (k & 3) | (comp_c_frag_m * FRAG_M + i) * WMMA_M;
#endif
                // int n = ((threadIdx.x & 15) | (comp_c_frag_n * FRAG_N + j) * WMMA_N) / QUANT_SIZE;
                auto lane = threadIdx.x % 64;
                int m, n;
                if constexpr (WMMA_M == 32) {
                    // C or D i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
                    // C or D j: (lane % 32)
                    m = (8 * (k / 4) % 32) + 4 * (lane / 32) + (k % 4);
                    n = lane % 32;
                } else {
                    // C or D i: 4 * floor(lane / 16) + (GPR_num % 4)
                    // C or D j: (lane % 16)
                    m = 4 * (lane / 16) + (k % 4);
                    n = lane % 16;
                }
                m += (comp_c_frag_m * FRAG_M + i) * WMMA_M;
                n += (comp_c_frag_n * FRAG_N + j) * WMMA_N;
                // One scale column serves QUANT_SIZE output columns.
                n = n / QUANT_SIZE;
                // if(threadIdx.x == 192 && blockIdx.x ==0 && blockIdx.y == 0 && blockIdx.z == 0)
                //     printf("m: %d, n: %d\n", m, n);
                float scale = s_s[n][m];
                frag_r[i][j].x[k] += (acc_data_type)scale * (acc_data_type)frag_c[i][j].x[k];
            }
        }
    }
}
209
-
210
// Convert float -> bfloat16 with round-to-nearest-even, done in integer
// arithmetic: add 0x7fff plus the LSB of the would-be-kept mantissa bit, then
// truncate to the high 16 bits.
// NOTE(review): NaN is not special-cased — the rounding add can alter a NaN
// payload, and values at the top of the finite range can round to infinity;
// confirm inputs are finite or that this is acceptable here.
__device__ rocwmma::bfloat16_t fast_f32tob16(float f) {
    union {
        float fp32;
        unsigned int u32;
    } u = {f};
    u.u32 += 0x7fff + ((u.u32 >> 16) & 1);
    auto ret = u.u32 >> 16;
    return reinterpret_cast<rocwmma::bfloat16_t &>(ret);
}
219
-
220
// Store this warp's accumulated fragments frag_r into the global output c.
// (m, n) is the block tile's origin; comp_c_frag_m/n select the warp tile.
// Rows/columns are bounds-checked only when BM/BN do not divide M/N.
// Two compile-time paths:
//   * sizeof(acc) == sizeof(out) (split-K partials): scatter the raw
//     accumulator registers using the same lane -> (m, n) mapping as
//     wmma_compute, writing full-precision values for a later reduce pass.
//   * 16-bit output: round each float to bfloat16 with fast_f32tob16 and bit-
//     reinterpret it as half so store_matrix_sync can write the fragment.
template <typename acc_data_type, typename out_data_type, typename FragC, typename FragOut, int WMMA_M, int WMMA_N,
          int BM, int BN, int M, int N, int FRAG_M, int FRAG_N>
__device__ inline void store_result(out_data_type c[M][N], FragC frag_r[FRAG_M][FRAG_N], const int m, const int n,
                                    const int comp_c_frag_m, const int comp_c_frag_n) {
#pragma unroll
    for (int i = 0; i < FRAG_M; i++) {
#pragma unroll
        for (int j = 0; j < FRAG_N; j++) {
            int frag_m = comp_c_frag_m * FRAG_M + i;
            int frag_n = comp_c_frag_n * FRAG_N + j;
            int row = m + frag_m * WMMA_M;
            int col = n + frag_n * WMMA_N;
            if (((M % BM == 0) || (row < M)) && ((N % BN == 0) || (col < N))) {
                out_data_type *c_ptr = &c[row][col];
                if constexpr (sizeof(acc_data_type) == sizeof(out_data_type)) { // split_k
                    auto lane = threadIdx.x % 64;
#pragma unroll
                    for (int k = 0; k < FragC::num_elements; ++k) {
                        // Same register -> element mapping as in wmma_compute.
                        int m, n;
                        if constexpr (WMMA_M == 32) {
                            // C or D i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
                            // C or D j: (lane % 32)
                            m = (8 * (k / 4) % 32) + 4 * (lane / 32) + (k % 4);
                            n = lane % 32;
                        } else {
                            // C or D i: 4 * floor(lane / 16) + (GPR_num % 4)
                            // C or D j: (lane % 16)
                            m = 4 * (lane / 16) + (k % 4);
                            n = lane % 16;
                        }
                        c_ptr[m * N + n] = frag_r[i][j].x[k];;
                    }

                    // wmma::store_matrix_sync(reinterpret_cast<out_data_type *>(c_ptr), frag_r[i][j], N,
                    //                         wmma::mem_row_major);
                } else if constexpr (sizeof(out_data_type) == sizeof(half)) {
                    FragOut frag_out;
                    static_assert(sizeof(half) == sizeof(out_data_type));
                    static_assert(FragOut::num_elements == FragC::num_elements);
                    for (int k = 0; k < FragOut::num_elements; ++k) {
                        // Round to bf16, then carry the bit pattern in a half slot.
                        auto reg = fast_f32tob16(frag_r[i][j].x[k]);
                        frag_out.x[k] = *reinterpret_cast<half *>(&reg);
                    }
                    wmma::store_matrix_sync(reinterpret_cast<half *>(c_ptr), frag_out, N, wmma::mem_row_major);
                } else {
                    static_assert(0, "Unsupported data type for output");
                }
            }
        }
    }
}
271
-
272
- // a dummy template to allow inlcuding this file
273
- template <int Splitk> __global__ void reduce(uint32_t m, uint32_t n, const float *c_splitk, __hip_bfloat16 *c) {
274
- auto tid = blockIdx.x * blockDim.x + threadIdx.x;
275
- if (tid >= m * n) {
276
- return;
277
- }
278
- float4 sum{};
279
- #pragma unroll
280
- for (auto i = 0; i < Splitk; ++i) {
281
- sum += *(float4 *)&c_splitk[i * (m * n) + tid * 4];
282
- }
283
- auto res =
284
- rocwmma::make_vector(fast_f32tob16(sum.x), fast_f32tob16(sum.y), fast_f32tob16(sum.z), fast_f32tob16(sum.w));
285
- *(decltype(res) *)&c[tid * 4] = res;
286
- }
287
-
288
- template<int M, int N, int SPLITK_FACTOR, int BLOCK_SIZE>
289
- __launch_bounds__(BLOCK_SIZE)
290
- __global__ void reduce_kernel(const float c_splitk[SPLITK_FACTOR][M][N], __hip_bfloat16 c[M][N]) {
291
- auto tid = blockIdx.x * blockDim.x + threadIdx.x;
292
- if (tid >= M * N) {
293
- return;
294
- }
295
- float4 sum{};
296
- #pragma unroll
297
- for (auto i = 0; i < SPLITK_FACTOR; ++i) {
298
- sum += *(float4 *)&reinterpret_cast<const float*>(c_splitk)[i * (M * N) + tid * 4];
299
- }
300
- auto res =
301
- rocwmma::make_vector(fast_f32tob16(sum.x), fast_f32tob16(sum.y), fast_f32tob16(sum.z), fast_f32tob16(sum.w));
302
- *(decltype(res) *)&reinterpret_cast< __BF16_TYPE*>(c)[tid * 4] = res;
303
- }
304
-
305
-
306
#ifdef PARAMETERIZE_LIBRARY
template <typename in_data_type,
          typename acc_data_type, // Accumulator type (e.g., float)
          typename out_data_type, // Output type (e.g., __hip_bfloat16)
          int M, int N, int K,    // Matrix dimensions
          int BM, int BN, int BK, // Tile dimensions
          int QUANT_SIZE,         // Quantization block size
          int BLOCK_SIZE,         // Block size
          int WARP_M, int WARP_N, // Warp dimensions
          int LDA, int LDB,
          int LOAD_BATCH_SIZE> // Load batch size for vectorized memory operations
#else
using in_data_type = __FP8_TYPE;
using out_data_type = __BF16_TYPE;
using acc_data_type = float;
// constexpr int M = 4096, N = 4096, K = 4096;
constexpr int M = 6144, N = 4608, K = 7168;
constexpr int LDA = K, LDB = K;
// constexpr int M = 512, N = 512, K = 512;
constexpr int BM = 256, BN = 128, BK = 128;
constexpr int QUANT_SIZE = 128, BLOCK_SIZE = 512;
constexpr int LOAD_BATCH_SIZE = 16;
#ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
constexpr int WARP_M = 4, WARP_N = 2;
#else // CDNA3, WAVE_SIZE = 64
constexpr int WARP_M = 4, WARP_N = 2;
#endif
#endif // End of parameterization
// Pipelined FP8 block-quantized GEMM: C = dequant(A) * dequant(B)^T with A
// stored [M][LDA] and B stored [N][LDB] (both K-contiguous). Per iteration the
// next K-tile is prefetched into registers (global2reg) while the current one
// is consumed from LDS (wmma_compute), then committed (reg2lds). blockIdx.z
// selects the split-K slice; gridDim.x/y cover the swizzled tile grid.
__global__ __launch_bounds__(BLOCK_SIZE) void gemm_kernel(
    const in_data_type a[M][LDA], const in_data_type b[N][LDB], out_data_type c[M][N],
    const float sa[ceil_div(K, QUANT_SIZE)][M / 1], // Assuming M is divisible by 1 (always true)
    const float sb[ceil_div(K, QUANT_SIZE)][ceil_div(N, QUANT_SIZE)]) {
    // --- Start: Derived parameters and constants ---
    constexpr int WMMA_M = 16; // Fixed WMMA dimension M
    constexpr int WMMA_N = 16; // Fixed WMMA dimension N
    constexpr int WMMA_K = 32; // Fixed WMMA dimension K (for FP8)

    // WARP_M/N define the 2D arrangement of warps in the block grid.
    // These might need adjustment based on BLOCK_DIM_X/Y strategy.
    // Using fixed values based on the non-parameterized version for now.
    // TODO: Derive WARP_M/N from BLOCK_DIM_X/Y if a flexible strategy is needed.
    constexpr int WARP_NUM = WARP_M * WARP_N; // Total warps per block

    // Assertion: Check if the assumed warp layout matches the block size
    static_assert(WARP_NUM * WAVE_SIZE == BLOCK_SIZE, "WARP_M * WARP_N * WAVE_SIZE must equal BLOCK_SIZE");

    // Fragments per warp
    constexpr int FRAG_M_PER_WARP = BM / WMMA_M / WARP_M;
    constexpr int FRAG_N_PER_WARP = BN / WMMA_N / WARP_N;
    constexpr int FRAG_K = BK / WMMA_K; // Fragments along K dimension tile

    static_assert(BM % (WMMA_M * WARP_M) == 0, "BM must be divisible by WMMA_M * WARP_M");
    static_assert(BN % (WMMA_N * WARP_N) == 0, "BN must be divisible by WMMA_N * WARP_N");
    static_assert(BK % WMMA_K == 0, "BK must be divisible by WMMA_K");
    static_assert(BK >= 32, "BK must be at least 32");
    // --- End: Derived parameters and constants ---

    constexpr int QM = M;                        // Dimension M for scale A
    constexpr int QN = ceil_div(N, QUANT_SIZE);  // Dimension N for scale B (quantized)
    constexpr int QK = ceil_div(K, QUANT_SIZE);  // Dimension K for scales (quantized)
    constexpr int PM = BM;                       // Block size M for scale A * B
    constexpr int PN = ceil_div(BN, QUANT_SIZE); // Block size N for scale A * B

    // Ensure derived fragment counts are positive
    static_assert(FRAG_M_PER_WARP > 0, "FRAG_M_PER_WARP must be positive");
    static_assert(FRAG_N_PER_WARP > 0, "FRAG_N_PER_WARP must be positive");
    static_assert(FRAG_K > 0, "FRAG_K must be positive");

    using FragA = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, in_data_type, wmma::row_major>;
    using FragB = wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, in_data_type, wmma::col_major>;
    using FragC = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, acc_data_type>;
    using FragOut = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K,
                                   half>; // Output uses half for storage via bfloat16 reinterpret

    // LDS tiles; the +8 padding on the K dimension offsets rows across banks.
    __shared__ in_data_type s_a[BM][BK + 8];
    __shared__ in_data_type s_b[BN][BK + 8];
    __shared__ acc_data_type s_s[PN][PM];           // Accumulator type for scales
    FragC frag_r[FRAG_M_PER_WARP][FRAG_N_PER_WARP]; // Accumulator fragments

    // handle splitk
    // NOTE(review): each split-K slice offsets a/b by K *elements* and c by M
    // rows — this presumes the launcher lays the slices out exactly so; verify
    // against gemm_launcher before reuse.
    a = (decltype(a))((in_data_type *)a + blockIdx.z * K);
    b = (decltype(b))((in_data_type *)b + blockIdx.z * K);
    c += blockIdx.z * M;
    sa += blockIdx.z * QK;
    sb += blockIdx.z * QK;

    int tid = threadIdx.x;     // Linear thread ID within the block
    int wid = tid / WAVE_SIZE; // Warp ID within the block

    // Spilt and compute fragments
    constexpr int iteration_over_k = ceil_div(K, BK); // Use ceil_div for potentially non-divisible K
    static_assert(LOAD_BATCH_SIZE > 0, "LOAD_BATCH_SIZE must be positive");

    constexpr auto PIPELINE = true;
    // using LoadVec = rocwmma::VecT<float, LOAD_BATCH_SIZE / sizeof(float)>;
    // Clang vector type of LOAD_BATCH_SIZE bytes, used for wide global loads.
    using LoadVec = __attribute__((__vector_size__(LOAD_BATCH_SIZE))) float;
    static_assert(((BK * BM) % (BLOCK_SIZE * LOAD_BATCH_SIZE)) == 0,
                  "BK * BM must be divisible by BLOCK_SIZE * LOAD_BATCH_SIZE");
    static_assert(BK % LOAD_BATCH_SIZE == 0, "BK must be divisible by LOAD_BATCH_SIZE");
    // Per-thread staging registers for the next K-tile of A/B and its scales.
    LoadVec reg_a[BK * BM / BLOCK_SIZE / LOAD_BATCH_SIZE];
    LoadVec reg_b[BK * BN / BLOCK_SIZE / LOAD_BATCH_SIZE];
    constexpr auto PK = ceil_div(BK, QUANT_SIZE);
    static_assert(PK == 1, "PK must be 1 for now");
    float reg_sa[ceil_div(PM, BLOCK_SIZE)];
    float reg_sb[ceil_div(PN, BLOCK_SIZE)];

    // threadblock swizzle
    auto log_tile = 1;
    auto block_idx_x = blockIdx.x >> log_tile;
    auto block_idx_y = (blockIdx.y << log_tile) + ((blockIdx.x) & ((1 << (log_tile)) - 1));
    if (block_idx_x >= ceil_div(N, BN) || block_idx_y >= ceil_div(M, BM)) {
        return;
    }

    const int m = block_idx_y * BM; // this block's row origin in C
    const int n = block_idx_x * BN; // this block's column origin in C
    int k = 0;                      // current K-tile offset (advanced in the main loop)

    // Prefetch the K-tile at offset `k` from global memory into registers.
    auto global2reg = [&]() {
#pragma unroll
        for (int reg = 0; reg < sizeof(reg_sa) / sizeof(float); reg++) {
            // NOTE: must iter over reg to make compiler unroll the loop
            // and thus be able to allocate reg_a on register instead of on scratch memroy
            int t = tid + reg * BLOCK_SIZE;
            // NOTE: don't branch here
            // if (t > PM) {
            //     break;
            // }
            int i = t / PM;
            int j = t % PM;
            reg_sa[reg] = sa[k / QUANT_SIZE][(m + j)];
        }
#pragma unroll
        for (int reg = 0; reg < sizeof(reg_sb) / sizeof(float); reg++) {
            // NOTE: must iter over reg to make compiler unroll the loop
            // and thus be able to allocate reg_a on register instead of on scratch memroy
            int t = tid + reg * BLOCK_SIZE;
            // NOTE: don't branch here
            // if (t > PN) {
            //     break;
            // }
            int i = t / PN;
            int j = t % PN;
            reg_sb[reg] = sb[k / QUANT_SIZE][(n) / QUANT_SIZE + j];
        }
#pragma unroll
        for (int reg = 0; reg < sizeof(reg_a) / sizeof(LoadVec); reg++) {
            // NOTE: must iter over reg to make compiler unroll the loop
            // and thus be able to allocate reg_a on register instead of on scratch memroy
            int t = tid * LOAD_BATCH_SIZE + reg * BLOCK_SIZE * LOAD_BATCH_SIZE;
            int i = t / BK;
            int j = t % BK;
            reg_a[reg] = *(LoadVec *)&a[m + i][k + j];
        }
#pragma unroll
        for (int reg = 0; reg < sizeof(reg_b) / sizeof(LoadVec); reg++) {
            // NOTE: must iter over reg to make compiler unroll the loop
            // and thus be able to allocate reg_a on register instead of on scratch memroy
            int t = tid * LOAD_BATCH_SIZE + reg * BLOCK_SIZE * LOAD_BATCH_SIZE;
            int i = t / BK;
            int j = t % BK;
            reg_b[reg] = *(LoadVec *)&b[n + i][k + j];
        }
    };

    // Commit the register-staged tile (and the sa*sb scale products) to LDS.
    auto reg2lds = [&]() {
#pragma unroll
        for (int rega = 0; rega < sizeof(reg_sa) / sizeof(float); rega++) {
            int ta = tid + rega * BLOCK_SIZE;
            int j = ta % PM;
#pragma unroll
            for (int regb = 0; regb < sizeof(reg_sb) / sizeof(float); regb++) {
                int tb = tid + regb * BLOCK_SIZE;
                int i = tb % PN;
                s_s[i][j] = reg_sa[rega] * reg_sb[regb];
            }
        }
#pragma unroll
        for (int reg = 0; reg < sizeof(reg_a) / sizeof(LoadVec); reg++) {
            int t = tid * LOAD_BATCH_SIZE + reg * BLOCK_SIZE * LOAD_BATCH_SIZE;
            int i = t / BK;
            int j = t % BK;
            *(LoadVec *)&s_a[i][j] = reg_a[reg];
        }
#pragma unroll
        for (int reg = 0; reg < sizeof(reg_b) / sizeof(LoadVec); reg++) {
            int t = tid * LOAD_BATCH_SIZE + reg * BLOCK_SIZE * LOAD_BATCH_SIZE;
            int i = t / BK;
            int j = t % BK;
            *(LoadVec *)&s_b[i][j] = reg_b[reg];
        }
    };

    if constexpr (PIPELINE) {
        global2reg();
    }

    // Initialize the output accumulator fragments to zero
#pragma unroll
    for (int i = 0; i < FRAG_M_PER_WARP; i++) {
#pragma unroll
        for (int j = 0; j < FRAG_N_PER_WARP; j++) {
            wmma::fill_fragment(frag_r[i][j], 0.0f); // Use float literal
        }
    }

    if constexpr (!PIPELINE) {
        global2reg();
    }

    reg2lds();

    // Main loop: tile bk is consumed from LDS while tile bk's data for the
    // NEXT iteration is prefetched into registers (when PIPELINE is on).
    for (int bk = 1; bk < iteration_over_k; bk++) {
        k = bk * BK;

        // Calculate remaining K for boundary checks if needed (not currently used by load_input)
        // const int k_rem = K - k;

        // Load data into shared memory
        // load_input<in_data_type, BK, BM, K, M, BLOCK_SIZE, 32>(
        //     s_a, a, m, k);
        // load_input<in_data_type, BK, BN, K, N, BLOCK_SIZE, 32>(
        //     s_b, b, n, k);
        // Load scales into shared memory (using acc_data_type for s_s)
        // load_scale<PM, PN, QM, QN, QK, QUANT_SIZE, BLOCK_SIZE, 1>(
        //     s_s, sa, sb, m, n, k);

        if constexpr (PIPELINE) {
            global2reg();
        }

        __syncthreads();

        // Perform matrix multiplication using WMMA
        wmma_compute<in_data_type, acc_data_type, FragC, FragA, FragB, PM, PN, BM, BN, BK, FRAG_M_PER_WARP,
                     FRAG_N_PER_WARP, FRAG_K, WMMA_M, WMMA_N, WMMA_K, WARP_M, WARP_N, BLOCK_SIZE, LOAD_BATCH_SIZE,
                     QUANT_SIZE>( // Pass calculated BLOCK_SIZE and LOAD_BATCH_SIZE
            s_a, s_b, s_s, frag_r, wid / WARP_N, wid % WARP_N);
        __syncthreads();

        if constexpr (!PIPELINE) {
            global2reg();
        }

        // __builtin_amdgcn_sched_barrier(0);

        reg2lds();
    }
    // Epilogue: consume the last committed tile, then write the result.
    __syncthreads();
    wmma_compute<in_data_type, acc_data_type, FragC, FragA, FragB, PM, PN, BM, BN, BK, FRAG_M_PER_WARP, FRAG_N_PER_WARP,
                 FRAG_K, WMMA_M, WMMA_N, WMMA_K, WARP_M, WARP_N, BLOCK_SIZE, LOAD_BATCH_SIZE,
                 QUANT_SIZE>( // Pass calculated BLOCK_SIZE and LOAD_BATCH_SIZE
        s_a, s_b, s_s, frag_r, wid / WARP_N, wid % WARP_N);
    // Store results from accumulator fragments to global memory
    store_result<acc_data_type, out_data_type, FragC, FragOut, WMMA_M, WMMA_N, BM, BN, M, N, FRAG_M_PER_WARP,
                 FRAG_N_PER_WARP>(c, frag_r, block_idx_y * BM, block_idx_x * BN, wid / WARP_N, wid % WARP_N);
};
563
-
564
- }; // namespace gemm_kernel
565
-
566
- HOST_CODE_BELOW
567
-
568
- #ifndef PARAMETERIZE_LIBRARY
569
- // Define type aliases to match those in the namespace
570
- using fp8_type = gemm_kernel::in_data_type; // __hip_fp8_e4m3
571
- using fp16_type = gemm_kernel::out_data_type; // __hip_bfloat16
572
- using acc_data_type = gemm_kernel::acc_data_type; // float
573
-
574
- // Define constants to match those in the namespace
575
- constexpr int M = gemm_kernel::M; // 4096
576
- constexpr int N = gemm_kernel::N; // 4096
577
- constexpr int K = gemm_kernel::K; // 4096
578
- constexpr int BM = gemm_kernel::BM; // 256
579
- constexpr int BN = gemm_kernel::BN; // 128
580
- constexpr int BK = gemm_kernel::BK; // 32
581
- constexpr int BLOCK_SIZE = gemm_kernel::BLOCK_SIZE;
582
- constexpr int QUANT_SIZE = gemm_kernel::QUANT_SIZE; // 128
583
-
584
- // Define derived constants for the test
585
- constexpr int QK = K / QUANT_SIZE;
586
- constexpr int QM = M;
587
- constexpr int QN = N / QUANT_SIZE;
588
-
589
- // Helper function to check HIP errors
590
- #define CHECK_HIP_ERROR(val) check((val), #val, __FILE__, __LINE__)
591
- template <typename T> void check(T err, const char *const func, const char *const file, const int line) {
592
- if (err != hipSuccess) {
593
- fprintf(stderr, "HIP Runtime Error at: %s:%d\n", file, line);
594
- fprintf(stderr, "%s %s\n", hipGetErrorString(err), func);
595
- exit(1);
596
- }
597
- }
598
-
599
- // Define a macro to check HIP errors
600
- #define HIP_CALL(call) \
601
- do { \
602
- hipError_t err = call; \
603
- if (err != hipSuccess) { \
604
- fprintf(stderr, "HIP Error: %s at %s:%d\n", hipGetErrorString(err), __FILE__, __LINE__); \
605
- exit(EXIT_FAILURE); \
606
- } \
607
- } while (0)
608
-
609
- // CPU matrix multiplication implementation for result verification
610
- void cpu_gemm(const fp8_type a[K][M], const fp8_type b[K][N], fp16_type c[M][N], const float sa[QK][QM],
611
- const float sb[QK][QN]) {
612
- float(*rc)[N] = new float[M][N];
613
- for (int m = 0; m < M; ++m) {
614
- for (int n = 0; n < N; ++n) {
615
- rc[m][n] = 0.0f;
616
- }
617
- }
618
- for (int k = 0; k < K; ++k) {
619
- for (int m = 0; m < M; ++m) {
620
- for (int n = 0; n < N; ++n) {
621
- float scale = sa[k / QUANT_SIZE][m] * sb[k / QUANT_SIZE][n / QUANT_SIZE];
622
- rc[m][n] += (scale * (float)a[k][m] * (float)b[k][n]);
623
- }
624
- }
625
- }
626
- for (int m = 0; m < M; ++m) {
627
- for (int n = 0; n < N; ++n) {
628
- c[m][n] = (fp16_type)rc[m][n];
629
- }
630
- }
631
- delete[] rc;
632
- }
633
-
634
- int main() {
635
- // Allocate host memory
636
- fp8_type(*h_a)[M] = new fp8_type[K][M];
637
- fp8_type(*h_b)[N] = new fp8_type[K][N];
638
- fp16_type(*h_c)[N] = new fp16_type[M][N];
639
- fp16_type(*h_c_ref)[N] = new fp16_type[M][N];
640
-
641
- // Allocate host memory for quantization scale factors
642
- float(*h_sa)[QM] = new float[QK][QM];
643
- float(*h_sb)[QN] = new float[QK][QN];
644
-
645
- // Initialize input data
646
- for (int i = 0; i < K; ++i) {
647
- for (int j = 0; j < M; ++j) {
648
- h_a[i][j] = (fp8_type)((rand() % 10000) / 10000.0f);
649
- }
650
- }
651
- for (int i = 0; i < K; ++i) {
652
- for (int j = 0; j < N; ++j) {
653
- h_b[i][j] = (fp8_type)((rand() % 10000) / 10000.0f);
654
- }
655
- }
656
-
657
- // Initialize quantization scale factors
658
- for (int i = 0; i < QK; ++i) {
659
- for (int j = 0; j < QM; ++j) {
660
- h_sa[i][j] = 1.0f;
661
- }
662
- }
663
- for (int i = 0; i < QK; ++i) {
664
- for (int j = 0; j < QN; ++j) {
665
- h_sb[i][j] = 1.0f;
666
- }
667
- }
668
-
669
- // Allocate device memory
670
- fp8_type(*d_a)[K];
671
- fp8_type(*d_b)[K];
672
- fp16_type(*d_c)[N];
673
- float(*d_sa)[QM];
674
- float(*d_sb)[QN];
675
-
676
- CHECK_HIP_ERROR(hipMalloc(&d_a, K * M * sizeof(fp8_type)));
677
- CHECK_HIP_ERROR(hipMalloc(&d_b, K * N * sizeof(fp8_type)));
678
- CHECK_HIP_ERROR(hipMalloc(&d_c, M * N * sizeof(fp16_type)));
679
- CHECK_HIP_ERROR(hipMalloc(&d_sa, QK * QM * sizeof(float)));
680
- CHECK_HIP_ERROR(hipMalloc(&d_sb, QK * QN * sizeof(float)));
681
-
682
- // Copy data from host memory to device memory
683
- CHECK_HIP_ERROR(hipMemcpy(d_a, h_a, K * M * sizeof(fp8_type), hipMemcpyHostToDevice));
684
- CHECK_HIP_ERROR(hipMemcpy(d_b, h_b, K * N * sizeof(fp8_type), hipMemcpyHostToDevice));
685
- CHECK_HIP_ERROR(hipMemcpy(d_sa, h_sa, QK * QM * sizeof(float), hipMemcpyHostToDevice));
686
- CHECK_HIP_ERROR(hipMemcpy(d_sb, h_sb, QK * QN * sizeof(float), hipMemcpyHostToDevice));
687
-
688
- // Calculate grid and block sizes - ensure coverage of the entire matrix
689
- dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
690
- dim3 block(BLOCK_SIZE);
691
-
692
- // Ensure block size is a multiple of 32, since warp size is 32
693
- if (BLOCK_SIZE % 32 != 0) {
694
- printf("Error: Block size must be a multiple of warp size (32)\n");
695
- return 1;
696
- }
697
-
698
- // Check if device supports required compute capability
699
- int deviceId;
700
- HIP_CALL(hipGetDevice(&deviceId));
701
- hipDeviceProp_t deviceProp;
702
- HIP_CALL(hipGetDeviceProperties(&deviceProp, deviceId));
703
-
704
- if (deviceProp.major < 7) {
705
- printf("Error: This kernel requires a GPU with compute capability 7.0 or higher\n");
706
- return 1;
707
- }
708
-
709
- printf("Running GEMM kernel with grid(%d,%d), block(%d)...\n", grid.x, grid.y, block.x);
710
-
711
- // Query and print kernel and device information
712
- printf("Querying kernel and device information...\n");
713
-
714
- // Get device properties
715
- HIP_CALL(hipGetDeviceProperties(&deviceProp, deviceId));
716
- printf("Device Name: %s\n", deviceProp.name);
717
- printf("Total Global Memory: %lu bytes\n", deviceProp.totalGlobalMem);
718
- printf("Shared Memory per Block: %lu bytes\n", deviceProp.sharedMemPerBlock);
719
- printf("Registers per Block: %d\n", deviceProp.regsPerBlock);
720
- printf("Warp Size: %d\n", deviceProp.warpSize);
721
- printf("Max Threads per Block: %d\n", deviceProp.maxThreadsPerBlock);
722
- printf("Max Threads per Multiprocessor: %d\n", deviceProp.maxThreadsPerMultiProcessor);
723
- printf("Number of Multiprocessors: %d\n", deviceProp.multiProcessorCount);
724
-
725
- // Query kernel attributes
726
- hipFuncAttributes funcAttr;
727
- HIP_CALL(hipFuncGetAttributes(&funcAttr, (const void *)gemm_kernel::gemm_kernel));
728
- printf("Kernel Attributes:\n");
729
- printf(" Shared Memory Size: %lu bytes\n", funcAttr.sharedSizeBytes);
730
- printf(" Number of Registers: %d\n", funcAttr.numRegs);
731
- printf(" Max Threads per Block: %d\n", funcAttr.maxThreadsPerBlock);
732
- printf(" Local Memory Size: %lu bytes\n", funcAttr.localSizeBytes);
733
-
734
- // Zero the C matrix before launching kernel
735
- CHECK_HIP_ERROR(hipMemset(d_c, 0, M * N * sizeof(fp16_type)));
736
-
737
- // Perform warmup runs
738
- printf("Performing warmup runs...\n");
739
- gemm_kernel::gemm_kernel<<<grid, block>>>(d_a, d_b, d_c, d_sa, d_sb);
740
- CHECK_HIP_ERROR(hipDeviceSynchronize());
741
- gemm_kernel::gemm_kernel<<<grid, block>>>(d_a, d_b, d_c, d_sa, d_sb);
742
- CHECK_HIP_ERROR(hipDeviceSynchronize());
743
-
744
- // Declare and create timing events
745
- hipEvent_t start, stop;
746
- HIP_CALL(hipEventCreate(&start));
747
- HIP_CALL(hipEventCreate(&stop));
748
-
749
- // Ensure device synchronization before formal timing
750
- CHECK_HIP_ERROR(hipDeviceSynchronize());
751
- HIP_CALL(hipEventRecord(start));
752
-
753
- // Launch kernel
754
- printf("Launching kernel...\n");
755
- gemm_kernel::gemm_kernel<<<grid, block>>>(d_a, d_b, d_c, d_sa, d_sb);
756
-
757
- // Record end time and calculate execution time
758
- HIP_CALL(hipEventRecord(stop));
759
-
760
- // Record end time and calculate execution time
761
- HIP_CALL(hipEventSynchronize(stop));
762
- float milliseconds = 0;
763
- HIP_CALL(hipEventElapsedTime(&milliseconds, start, stop));
764
- printf("Kernel execution time: %f ms\n", milliseconds);
765
-
766
- // Check HIP errors
767
- CHECK_HIP_ERROR(hipGetLastError());
768
-
769
- // Calculate GPU performance metrics
770
- double operations = 2.0 * M * N * K; // Each multiply-add operation counts as 2 floating-point operations
771
- double seconds = milliseconds / 1000.0;
772
- double tflops = (operations / seconds) / 1e12;
773
- printf("GPU Performance: %.2f TFLOPS\n", tflops);
774
-
775
- return 0;
776
-
777
- // Copy results from device memory back to host memory
778
- CHECK_HIP_ERROR(hipMemcpy(h_c, d_c, M * N * sizeof(fp16_type), hipMemcpyDeviceToHost));
779
-
780
- // Calculate reference results
781
- printf("Computing reference result on CPU...\n");
782
- cpu_gemm(h_a, h_b, h_c_ref, h_sa, h_sb);
783
-
784
- // Print the first 10 values for comparison
785
- printf("First 10 values (GPU vs CPU):\n");
786
- int print_count = 0;
787
- for (int i = 0; i < M && print_count < 10; ++i) {
788
- for (int j = 0; j < N && print_count < 10; ++j) {
789
- printf(" [%d, %d]: GPU=%f, CPU=%f\n", i, j, (float)h_c[i][j], (float)h_c_ref[i][j]);
790
- print_count++;
791
- }
792
- }
793
-
794
- // Verify results
795
- printf("Verifying results...\n");
796
- int errors = 0;
797
- float max_abs_diff = 0.0f;
798
- float max_rel_diff = 0.0f;
799
- struct ErrorInfo {
800
- int row, col;
801
- float gpu_val, cpu_val, abs_diff, rel_diff;
802
- };
803
- ErrorInfo first_10_errors[10];
804
- ErrorInfo max_10_errors[10] = {};
805
-
806
- // Add a configurable variable for the number of errors to output
807
- int max_errors_to_output = 10; // You can modify this value as needed
808
-
809
- for (int i = 0; i < M; ++i) {
810
- for (int j = 0; j < N; ++j) {
811
- float gpu_val = (float)h_c[i][j];
812
- float cpu_val = (float)h_c_ref[i][j];
813
- float abs_diff;
814
- float rel_diff;
815
-
816
- if (std::isnan(gpu_val) || std::isnan(cpu_val)) {
817
- abs_diff = INFINITY;
818
- rel_diff = INFINITY;
819
- } else {
820
- abs_diff = abs(gpu_val - cpu_val);
821
- rel_diff = abs_diff / (abs(cpu_val) + FLT_EPSILON);
822
- }
823
-
824
- // Track max absolute and relative differences
825
- max_abs_diff = fmaxf(max_abs_diff, abs_diff);
826
- max_rel_diff = fmaxf(max_rel_diff, rel_diff);
827
-
828
- // Record first 10 errors
829
- if (errors < max_errors_to_output && (rel_diff > 1e-2 || abs_diff > 1e-3)) {
830
- first_10_errors[errors] = {i, j, gpu_val, cpu_val, abs_diff, rel_diff};
831
- }
832
-
833
- // Track top 10 largest errors
834
- if (rel_diff > 1e-2 || abs_diff > 1e-3) {
835
- errors++;
836
- for (int k = 0; k < max_errors_to_output; ++k) {
837
- if (abs_diff > max_10_errors[k].abs_diff) {
838
- for (int l = max_errors_to_output - 1; l > k; --l) {
839
- max_10_errors[l] = max_10_errors[l - 1];
840
- }
841
- max_10_errors[k] = {i, j, gpu_val, cpu_val, abs_diff, rel_diff};
842
- break;
843
- }
844
- }
845
- }
846
- }
847
- }
848
-
849
- // Print first 10 errors
850
- printf("First %d errors:\n", max_errors_to_output);
851
- for (int i = 0; i < fmin(errors, max_errors_to_output); ++i) {
852
- printf("Error at [%d, %d]: GPU=%f, CPU=%f, AbsDiff=%f, RelDiff=%f\n", first_10_errors[i].row,
853
- first_10_errors[i].col, first_10_errors[i].gpu_val, first_10_errors[i].cpu_val,
854
- first_10_errors[i].abs_diff, first_10_errors[i].rel_diff);
855
- }
856
-
857
- // Print top 10 largest errors
858
- printf("Top %d largest errors:\n", max_errors_to_output);
859
- for (int i = 0; i < max_errors_to_output && max_10_errors[i].abs_diff > 0; ++i) {
860
- printf("Error at [%d, %d]: GPU=%f, CPU=%f, AbsDiff=%f, RelDiff=%f\n", max_10_errors[i].row,
861
- max_10_errors[i].col, max_10_errors[i].gpu_val, max_10_errors[i].cpu_val, max_10_errors[i].abs_diff,
862
- max_10_errors[i].rel_diff);
863
- }
864
-
865
- printf("Max abs_diff: %f, Max rel_diff: %f\n", max_abs_diff, max_rel_diff);
866
- if (errors == 0) {
867
- printf("Test PASSED!\n");
868
- } else {
869
- printf("Test FAILED with %d errors\n", errors);
870
- }
871
-
872
- // Calculate performance
873
- double flops = 2.0 * M * N * K;
874
- double gflops = (flops * 1e-9) / (milliseconds * 1e-3);
875
- printf("Performance: %.2f GFLOPS\n", gflops);
876
-
877
- // Free memory
878
- delete[] h_a;
879
- delete[] h_b;
880
- delete[] h_c;
881
- delete[] h_c_ref;
882
- delete[] h_sa;
883
- delete[] h_sb;
884
- HIP_CALL(hipFree(d_a));
885
- HIP_CALL(hipFree(d_b));
886
- HIP_CALL(hipFree(d_c));
887
- HIP_CALL(hipFree(d_sa));
888
- HIP_CALL(hipFree(d_sb));
889
- HIP_CALL(hipEventDestroy(start));
890
- HIP_CALL(hipEventDestroy(stop));
891
-
892
- return 0;
893
- }
894
- #endif
895
- #pragma clang diagnostic pop
896
- #endif