medmekk HF Staff commited on
Commit
d5e710f
·
verified ·
1 Parent(s): 9567a9b

Delete gemm_launcher.hip

Browse files
Files changed (1) hide show
  1. gemm_launcher.hip +0 -267
gemm_launcher.hip DELETED
@@ -1,267 +0,0 @@
1
- // Wrapped of gemm kernel launcher.
2
- #include <unistd.h>
3
- #include <chrono>
4
- #define PARAMETERIZE_LIBRARY
5
- #include "gemm_kernel.h"
6
- #include "gemm_kernel_legacy.h"
7
- #include "transpose_kernel.h"
8
- #undef PARAMETERIZE_LIBRARY
9
- #include "../include/gpu_types.h"
10
- #include "../include/timer.h"
11
- #include "../tests/checker/metrics.h"
12
- #include <iostream>
13
HOST_CODE_BELOW

// Timers accumulated by KernelTimerScoped instances created in run().
std::vector<std::shared_ptr<KernelTimer>> timers;

using namespace std; // NOTE(review): file-scope using-directive; every std:: use below is already qualified — candidate for removal.

// Lazily allocated device workspace, shared by all launches (see init_workspace()).
float *c_splitk = nullptr;     // split-K partial sums (float), up to MAX_SPLITK_FACTOR slices
__FP8_TYPE *a_trans = nullptr; // transposed copy of input A
__FP8_TYPE *b_trans = nullptr; // transposed copy of input B
// Workspace sizing limits; launch_gemm() static_asserts every shape against these.
constexpr int MAX_MATRIX_M = 6144;
constexpr int MAX_MATRIX_N = 7168;
constexpr int MAX_MATRIX_K = 7168;
constexpr int MAX_SPLITK_FACTOR = 8;
30
- void init_workspace() {
31
- LIB_CALL(HOST_TYPE(Malloc)(&c_splitk, MAX_MATRIX_M * MAX_MATRIX_N * sizeof(float) * MAX_SPLITK_FACTOR));
32
- LIB_CALL(HOST_TYPE(Malloc)(&a_trans, MAX_MATRIX_M * MAX_MATRIX_K * sizeof(__FP8_TYPE)));
33
- LIB_CALL(HOST_TYPE(Malloc)(&b_trans, MAX_MATRIX_N * MAX_MATRIX_K * sizeof(__FP8_TYPE)));
34
- // LIB_CALL(HOST_TYPE(StreamCreateWithFlags)(&job_stream0, HOST_TYPE(StreamNonBlocking)));
35
- // job_stream0 = 0;
36
- }
37
-
38
-
39
- // Launch pipeline gemm kernels (most performant).
40
- // 1. Transpose input A & B.
41
- // 2. GEMM compute.
42
- // 3. Reduce (if spilt-k is enable)
43
- template <int M, int N, int K, int BM, int BN, int BK, int WARP_M, int WARP_N, int BLOCK_SIZE, int QUANT_BLOCK_SIZE,
44
- int SPLITK_FACTOR, int LOAD_BATCH_SIZE = 16>
45
- void launch_gemm(const __FP8_TYPE *a, const __FP8_TYPE *b, __BF16_TYPE *c, const float *as, const float *bs, HOST_TYPE(Stream_t) job_stream0) {
46
- static_assert(M <= MAX_MATRIX_M, "M exceeds maximum supported size");
47
- static_assert(N <= MAX_MATRIX_N, "N exceeds maximum supported size");
48
- static_assert(K <= MAX_MATRIX_K, "K exceeds maximum supported size");
49
- static_assert(SPLITK_FACTOR <= MAX_SPLITK_FACTOR, "SPLITK_FACTOR exceeds maximum supported size");
50
- if (__builtin_expect(c_splitk == nullptr, 0)) {
51
- init_workspace();
52
- LIB_CALL(hipDeviceSynchronize());
53
- }
54
-
55
- transpose_kernel::transpose_fp8<K, N>(b_trans, b, job_stream0);
56
- transpose_kernel::transpose_fp8<K, M>(a_trans, a, job_stream0);
57
- // transpose_kernel::launch_transpose<__FP8_TYPE, K, N, 64, 512, 4>(b_trans, b, job_stream0);
58
- // transpose_kernel::launch_transpose<__FP8_TYPE, K, M, 64, 512, 4>(a_trans, a, job_stream0);
59
- // Busy wait for 150 microseconds
60
- // auto start = std::chrono::high_resolution_clock::now();
61
- // while (std::chrono::duration_cast<std::chrono::microseconds>(
62
- // std::chrono::high_resolution_clock::now() - start).count() < 150) {
63
- // // Busy wait
64
- // }
65
- // be careful that blocksize < 1024, or there's a silent fault
66
- // gemm_kernel::check_trans<<<dim3(K / 32, M / 32), dim3(32, 32)>>>(a, a_trans, K, M);
67
-
68
- static_assert(K % SPLITK_FACTOR == 0, "K not divisible by SPLITK_FACTOR");
69
- dim3 grid(ceil_div(N, BN) << 1, ceil_div(M, BM) >> 1, SPLITK_FACTOR);
70
- static_assert(BLOCK_SIZE >= 32, "BLOCK_SIZE must be at least 32");
71
- dim3 block(BLOCK_SIZE);
72
- if constexpr (SPLITK_FACTOR == 1) {
73
- hipLaunchKernelGGL(
74
- HIP_KERNEL_NAME(gemm_kernel::gemm_kernel<__FP8_TYPE, float, __BF16_TYPE, M, N, K, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N, K, K, LOAD_BATCH_SIZE>),
75
- grid, block, 0, job_stream0,
76
- reinterpret_cast<const __FP8_TYPE(*)[K]>(a_trans),
77
- reinterpret_cast<const __FP8_TYPE(*)[K]>(b_trans),
78
- reinterpret_cast<__BF16_TYPE(*)[N]>(c), reinterpret_cast<const float(*)[M]>(as),
79
- reinterpret_cast<const float(*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs)
80
- );
81
- } else {
82
- hipLaunchKernelGGL(
83
- HIP_KERNEL_NAME(gemm_kernel::gemm_kernel<__FP8_TYPE, float, float, M, N, K / SPLITK_FACTOR, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N, K, K, LOAD_BATCH_SIZE>),
84
- grid, block, 0, job_stream0,
85
- reinterpret_cast<const __FP8_TYPE(*)[K]>(a_trans),
86
- reinterpret_cast<const __FP8_TYPE(*)[K]>(b_trans),
87
- reinterpret_cast<float(*)[N]>(c_splitk), reinterpret_cast<const float(*)[M]>(as),
88
- reinterpret_cast<const float(*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs));
89
- constexpr uint32_t REDUCE_BLOCK = 256;
90
- hipLaunchKernelGGL(
91
- HIP_KERNEL_NAME(gemm_kernel::reduce_kernel<M, N, SPLITK_FACTOR, REDUCE_BLOCK>),
92
- ceil_div(M * N / 4, REDUCE_BLOCK), REDUCE_BLOCK, 0, job_stream0,
93
- reinterpret_cast<const float(*)[M][N]>(c_splitk),
94
- reinterpret_cast<__BF16_TYPE(*)[N]>(c)
95
- ); }
96
- auto err = HOST_TYPE(GetLastError)();
97
- if (err != HOST_TYPE(Success)) {
98
- std::cerr << "Kernel execution failed.\n" << HOST_TYPE(GetErrorString)(err) << std::endl;
99
- abort();
100
- }
101
- }
102
-
103
-
104
- // Launch legacy gemm kernel. (most compellable)
105
- template <int M, int N, int K, int BM, int BN, int BK, int WARP_M, int WARP_N, int BLOCK_SIZE, int QUANT_BLOCK_SIZE, int SPLITK_FACTOR>
106
- void launch_gemm_legacy(const __FP8_TYPE *a, const __FP8_TYPE *b, __BF16_TYPE *c, const float *as, const float *bs, HOST_TYPE(Stream_t) job_stream0) {
107
- static_assert(K % SPLITK_FACTOR == 0, "K not divisible by SPLITK_FACTOR");
108
- dim3 grid(ceil_div(N, BN), ceil_div(M, BM), SPLITK_FACTOR);
109
- static_assert(BLOCK_SIZE >= 32, "BLOCK_SIZE must be at least 32");
110
- dim3 block(BLOCK_SIZE);
111
- if (__builtin_expect(c_splitk == nullptr, 0)) {
112
- init_workspace();
113
- LIB_CALL(hipDeviceSynchronize());
114
- }
115
-
116
- if constexpr (SPLITK_FACTOR == 1) {
117
- hipLaunchKernelGGL(
118
- HIP_KERNEL_NAME(gemm_kernel_legacy::gemm_kernel<__FP8_TYPE, float, __BF16_TYPE, M, N, K, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N>),
119
- grid, block, 0, job_stream0,
120
- reinterpret_cast<const __FP8_TYPE (*)[M]>(a),
121
- reinterpret_cast<const __FP8_TYPE (*)[N]>(b),
122
- reinterpret_cast<__BF16_TYPE (*)[N]>(c),
123
- reinterpret_cast<const float (*)[M]>(as),
124
- reinterpret_cast<const float (*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs)
125
- );
126
- } else {
127
- hipLaunchKernelGGL(
128
- HIP_KERNEL_NAME(gemm_kernel_legacy::gemm_kernel<__FP8_TYPE, float, float, M, N, K / SPLITK_FACTOR, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N>),
129
- grid, block, 0, job_stream0,
130
- reinterpret_cast<const __FP8_TYPE (*)[M]>(a),
131
- reinterpret_cast<const __FP8_TYPE (*)[N]>(b),
132
- reinterpret_cast<float (*)[N]>(c_splitk),
133
- reinterpret_cast<const float (*)[M]>(as),
134
- reinterpret_cast<const float (*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs)
135
- );
136
- constexpr uint32_t REDUCE_BLOCK = 256;
137
- hipLaunchKernelGGL(
138
- HIP_KERNEL_NAME(gemm_kernel_legacy::reduce<0>),
139
- ceil_div(M * N, REDUCE_BLOCK), REDUCE_BLOCK, 0, job_stream0,
140
- M, N, SPLITK_FACTOR, c_splitk, (__BF16_TYPE *)c
141
- );
142
- }
143
- auto err = HOST_TYPE(GetLastError)();
144
- if (err != HOST_TYPE(Success)) {
145
- std::cerr << "Kernel execution failed.\n" << HOST_TYPE(GetErrorString)(err) << std::endl;
146
- abort();
147
- }
148
- }
149
-
150
// Pack (m, n, k) into a single 32-bit dispatch key.
//
// Each dimension is divided by 32 (all supported shapes are 32-aligned) and
// placed in its own byte-wide field: m -> bits 16 and up, n -> bits 8..15,
// k -> bits 0..7.
//
// Note on the 8192 corner: 8192/32 == 256, which does NOT fit in 8 bits (the
// old comment claimed it did). The key is nevertheless unique over the valid
// domain enforced by pack_shape_checked() (32..8192, 32-aligned): every field
// value is >= 1, so a carry out of the n or k field could only collide with a
// tuple whose next-lower field is 0 — which no valid shape produces.
constexpr inline uint32_t pack_shape(uint32_t m, uint32_t n, uint32_t k) {
    return ((m / 32) << 16) | ((n / 32) << 8) | (k / 32);
}
156
// Dispatch-table entry for the pipelined kernel. Arguments map onto
// launch_gemm<M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, QUANT_BLOCK_SIZE,
// SPLITK_FACTOR, LOAD_BATCH_SIZE> with QUANT_BLOCK_SIZE fixed at 128.
// Expands to a `case` label inside the switch in run(); relies on a_ptr/b_ptr/
// c_ptr/as_ptr/bs_ptr/job_stream0 being in scope at the expansion site.
#define DISPATCH_GEMM(M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR, LOAD_BATCH_SIZE) \
    case pack_shape_checked<M, N, K>(): { \
        launch_gemm<M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, 128, SPLITK_FACTOR, LOAD_BATCH_SIZE>(a_ptr, b_ptr, c_ptr, as_ptr, bs_ptr, job_stream0); \
        break; \
    }

// Same, but dispatching to the legacy kernel (no LOAD_BATCH_SIZE parameter).
#define DISPATCH_GEMM_LEGACY(M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR) \
    case pack_shape_checked<M, N, K>(): { \
        launch_gemm_legacy<M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, 128, SPLITK_FACTOR>(a_ptr, b_ptr, c_ptr, as_ptr, bs_ptr, job_stream0); \
        break; \
    }
169
-
170
- template <int M, int N, int K> constexpr inline uint32_t pack_shape_checked() {
171
- static_assert(M % 32 == 0, "M must be a multiple of 32");
172
- static_assert(N % 32 == 0, "N must be a multiple of 32");
173
- static_assert(K % 32 == 0, "K must be a multiple of 32");
174
- static_assert(M >= 32 && M <= 8192, "M must be between 32 and 8192");
175
- static_assert(N >= 32 && N <= 8192, "N must be between 32 and 8192");
176
- static_assert(K >= 32 && K <= 8192, "K must be between 32 and 8192");
177
- return pack_shape(M, N, K);
178
- }
179
-
180
-
181
-
182
extern "C" {
// Entry point: dispatches the GEMM to the fastest known kernel configuration
// for the given (m, n, k) shape, recording time/GFLOPs into `metrics` when
// provided. Unsupported shapes print an error and abort.
void run(void *a, void *b, void *as, void *bs, void *c, int m, int n, int k, PerfMetrics *metrics, hipStream_t job_stream0) {
    // Cast the opaque pointers once up front.
    const __FP8_TYPE *a_ptr = static_cast<const __FP8_TYPE *>(a);
    const __FP8_TYPE *b_ptr = static_cast<const __FP8_TYPE *>(b);
    __BF16_TYPE *c_ptr = static_cast<__BF16_TYPE *>(c);
    const float *as_ptr = static_cast<const float *>(as);
    const float *bs_ptr = static_cast<const float *>(bs);
    // Scoped timer: 2*m*n*k is the FLOP count used for the GFLOPs figure.
    KernelTimerScoped timer(timers, 2LL * m * n * k,
                            metrics ? &metrics->entries[0].time : nullptr,
                            metrics ? &metrics->entries[0].gflops : nullptr, job_stream0);

    // Dispatch on the packed shape key; each case is a hand-tuned configuration.
    switch (pack_shape(m, n, k)) {
#ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
        // Test: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR, LOAD_BATCH_SIZE
        DISPATCH_GEMM(64, 64, 128, 64, 64, 32, 1, 4, 128, 1, 16);
        DISPATCH_GEMM(64, 1536, 7168, 64, 128, 64, 4, 2, 256, 1, 16);
        DISPATCH_GEMM(64, 3072, 1536, 64, 128, 64, 4, 2, 256, 1, 16);
        DISPATCH_GEMM(64, 576, 7168, 64, 128, 64, 4, 2, 256, 1, 16);
        DISPATCH_GEMM(96, 7168, 256, 96, 256, 64, 2, 4, 256, 1, 16);
        DISPATCH_GEMM(96, 7168, 2048, 96, 256, 64, 2, 4, 256, 1, 16);
        DISPATCH_GEMM(96, 4608, 7168, 96, 256, 64, 2, 4, 256, 1, 16);
        DISPATCH_GEMM(128, 7168, 2304, 128, 128, 64, 4, 2, 256, 1, 16);
        DISPATCH_GEMM(128, 512, 7168, 128, 128, 64, 4, 2, 256, 1, 16);
        DISPATCH_GEMM(512, 4096, 512, 256, 128, 64, 4, 2, 256, 1, 16);
        DISPATCH_GEMM(512, 1536, 7168, 256, 128, 64, 4, 2, 256, 1, 16);
        // Benchmark: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR, LOAD_BATCH_SIZE
        DISPATCH_GEMM(1024, 1536, 7168, 128, 128, 64, 1, 4, 128, 4, 16); // Optimized: 0.49 ms (45.65 TFlops)
        DISPATCH_GEMM(1024, 3072, 1536, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.19 ms (51.32 TFlops)
        DISPATCH_GEMM(1024, 576, 7168, 128, 64, 32, 4, 1, 128, 4, 16); // Optimized: 0.30 ms (28.16 TFlops)
        DISPATCH_GEMM(1024, 7168, 256, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.08 ms (46.49 TFlops)
        DISPATCH_GEMM(1024, 7168, 2048, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.49 ms (61.92 TFlops)
        DISPATCH_GEMM(1024, 4608, 7168, 128, 128, 32, 2, 2, 128, 1, 16); // Optimized: 0.99 ms (68.16 TFlops)
        DISPATCH_GEMM(1024, 7168, 2304, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.51 ms (66.04 TFlops)
        DISPATCH_GEMM(1024, 512, 7168, 64, 128, 32, 2, 2, 128, 4, 16); // Optimized: 0.26 ms (28.97 TFlops)
        DISPATCH_GEMM(1024, 4096, 512, 128, 256, 32, 2, 4, 256, 1, 16); // Optimized: 0.08 ms (54.27 TFlops)
        DISPATCH_GEMM(6144, 1536, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 1.76 ms (76.76 TFlops)
        DISPATCH_GEMM(6144, 3072, 1536, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.88 ms (66.00 TFlops)
        DISPATCH_GEMM(6144, 576, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.84 ms (60.68 TFlops)
        DISPATCH_GEMM(6144, 7168, 256, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.49 ms (45.76 TFlops)
        DISPATCH_GEMM(6144, 7168, 2048, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 2.17 ms (83.11 TFlops)
        DISPATCH_GEMM(6144, 4608, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 4.56 ms (88.99 TFlops)
        DISPATCH_GEMM(6144, 7168, 2304, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 2.41 ms (84.32 TFlops)
        DISPATCH_GEMM(6144, 512, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.67 ms (67.45 TFlops)
        DISPATCH_GEMM(6144, 4096, 512, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.51 ms (50.79 TFlops)
#else // CDNA3, WAVE_SIZE = 64
        // Benchmark: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SZ, SPLITK_F, LOAD_BS
        DISPATCH_GEMM(1024, 1536, 7168, 256, 128, 128, 4, 2, 512, 4, 16); // #0
        DISPATCH_GEMM(1024, 3072, 1536, 256, 128, 128, 4, 2, 512, 2, 16); // #1
        DISPATCH_GEMM(1024, 576, 7168, 256, 128, 128, 4, 2, 512, 8, 16); // #2
        DISPATCH_GEMM(1024, 7168, 256, 256, 128, 128, 4, 2, 512, 1, 16); // #3
        DISPATCH_GEMM(1024, 7168, 2048, 256, 128, 128, 4, 2, 512, 1, 16); // #4
        DISPATCH_GEMM(1024, 4608, 7168, 256, 128, 128, 4, 2, 512, 2, 16); // #5
        DISPATCH_GEMM(1024, 7168, 2304, 256, 128, 128, 4, 2, 512, 1, 16); // #6
        DISPATCH_GEMM(1024, 512, 7168, 256, 128, 128, 4, 2, 512, 8, 16); // #7
        DISPATCH_GEMM(1024, 4096, 512, 256, 128, 128, 4, 2, 512, 1, 16); // #8
        DISPATCH_GEMM(6144, 1536, 7168, 256, 128, 128, 4, 2, 512, 1, 16); // #9
        DISPATCH_GEMM(6144, 3072, 1536, 256, 128, 128, 4, 2, 512, 1, 16); // #10
        DISPATCH_GEMM(6144, 576, 7168, 256, 128, 128, 4, 2, 512, 2, 16); // #11
        DISPATCH_GEMM(6144, 7168, 256, 256, 128, 128, 4, 2, 512, 1, 16); // #12
        DISPATCH_GEMM(6144, 7168, 2048, 256, 128, 128, 4, 2, 512, 1, 16); // #13
        DISPATCH_GEMM(6144, 4608, 7168, 256, 128, 128, 4, 2, 512, 1, 16); // #14
        DISPATCH_GEMM(6144, 7168, 2304, 256, 128, 128, 4, 2, 512, 1, 16); // #15
        DISPATCH_GEMM(6144, 512, 7168, 256, 128, 128, 4, 2, 512, 2, 16); // #16
        DISPATCH_GEMM(6144, 4096, 512, 256, 128, 128, 4, 2, 512, 1, 16); // #17
        // Test: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SZ, SPLITK_F,
        DISPATCH_GEMM_LEGACY(64, 64, 128, 64, 64, 32, 4, 2, 512, 1);
        DISPATCH_GEMM_LEGACY(64, 1536, 7168, 64, 128, 64, 4, 2, 512, 1);
        DISPATCH_GEMM_LEGACY(64, 3072, 1536, 64, 128, 64, 4, 2, 512, 1);
        DISPATCH_GEMM_LEGACY(64, 576, 7168, 64, 128, 64, 4, 2, 512, 1);
        DISPATCH_GEMM_LEGACY(96, 7168, 256, 96, 256, 64, 2, 4, 512, 1);
        DISPATCH_GEMM_LEGACY(96, 7168, 2048, 96, 256, 64, 2, 4, 512, 1);
        DISPATCH_GEMM_LEGACY(96, 4608, 7168, 96, 256, 64, 2, 4, 512, 1);
        DISPATCH_GEMM_LEGACY(128, 7168, 2304, 128, 128, 64, 4, 2, 512, 1);
        DISPATCH_GEMM_LEGACY(128, 512, 7168, 128, 128, 64, 4, 2, 512, 1);
        DISPATCH_GEMM_LEGACY(512, 4096, 512, 256, 128, 64, 4, 2, 512, 1);
        DISPATCH_GEMM_LEGACY(512, 1536, 7168, 256, 128, 64, 4, 2, 512, 1);
#endif
        default: {
            printf("Error: Unsupported shape M=%d, K=%d, N=%d\n", m, k, n);
            abort();
        }
    }
}
} // extern "C"