Delete gemm_launcher.hip
Browse files- gemm_launcher.hip +0 -267
gemm_launcher.hip
DELETED
|
@@ -1,267 +0,0 @@
|
|
| 1 |
-
// Wrapper around the GEMM kernel launchers.
|
| 2 |
-
#include <unistd.h>
|
| 3 |
-
#include <chrono>
|
| 4 |
-
#define PARAMETERIZE_LIBRARY
|
| 5 |
-
#include "gemm_kernel.h"
|
| 6 |
-
#include "gemm_kernel_legacy.h"
|
| 7 |
-
#include "transpose_kernel.h"
|
| 8 |
-
#undef PARAMETERIZE_LIBRARY
|
| 9 |
-
#include "../include/gpu_types.h"
|
| 10 |
-
#include "../include/timer.h"
|
| 11 |
-
#include "../tests/checker/metrics.h"
|
| 12 |
-
#include <iostream>
|
| 13 |
-
|
| 14 |
-
#include <stdio.h>
|
| 15 |
-
|
| 16 |
-
HOST_CODE_BELOW
|
| 17 |
-
|
| 18 |
-
std::vector<std::shared_ptr<KernelTimer>> timers;
|
| 19 |
-
|
| 20 |
-
using namespace std;
|
| 21 |
-
|
| 22 |
-
float *c_splitk = nullptr;
|
| 23 |
-
__FP8_TYPE *a_trans = nullptr;
|
| 24 |
-
__FP8_TYPE *b_trans = nullptr;
|
| 25 |
-
constexpr int MAX_MATRIX_M = 6144;
|
| 26 |
-
constexpr int MAX_MATRIX_N = 7168;
|
| 27 |
-
constexpr int MAX_MATRIX_K = 7168;
|
| 28 |
-
constexpr int MAX_SPLITK_FACTOR = 8;
|
| 29 |
-
|
| 30 |
-
// Allocates the three global workspace buffers (c_splitk, a_trans, b_trans) at the
// maximum sizes permitted by the MAX_* limits. Every allocation is wrapped in LIB_CALL.
// Nothing in this file releases the buffers again.
void init_workspace() {
    // Byte counts, spelled with the same arithmetic the allocations have always used.
    constexpr auto splitk_bytes = MAX_MATRIX_M * MAX_MATRIX_N * sizeof(float) * MAX_SPLITK_FACTOR;
    constexpr auto a_trans_bytes = MAX_MATRIX_M * MAX_MATRIX_K * sizeof(__FP8_TYPE);
    constexpr auto b_trans_bytes = MAX_MATRIX_N * MAX_MATRIX_K * sizeof(__FP8_TYPE);

    LIB_CALL(HOST_TYPE(Malloc)(&c_splitk, splitk_bytes));
    LIB_CALL(HOST_TYPE(Malloc)(&a_trans, a_trans_bytes));
    LIB_CALL(HOST_TYPE(Malloc)(&b_trans, b_trans_bytes));
}
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
// Launch pipeline gemm kernels (most performant).
|
| 40 |
-
// 1. Transpose input A & B.
|
| 41 |
-
// 2. GEMM compute.
|
| 42 |
-
// 3. Reduce (if spilt-k is enable)
|
| 43 |
-
template <int M, int N, int K, int BM, int BN, int BK, int WARP_M, int WARP_N, int BLOCK_SIZE, int QUANT_BLOCK_SIZE,
|
| 44 |
-
int SPLITK_FACTOR, int LOAD_BATCH_SIZE = 16>
|
| 45 |
-
void launch_gemm(const __FP8_TYPE *a, const __FP8_TYPE *b, __BF16_TYPE *c, const float *as, const float *bs, HOST_TYPE(Stream_t) job_stream0) {
|
| 46 |
-
static_assert(M <= MAX_MATRIX_M, "M exceeds maximum supported size");
|
| 47 |
-
static_assert(N <= MAX_MATRIX_N, "N exceeds maximum supported size");
|
| 48 |
-
static_assert(K <= MAX_MATRIX_K, "K exceeds maximum supported size");
|
| 49 |
-
static_assert(SPLITK_FACTOR <= MAX_SPLITK_FACTOR, "SPLITK_FACTOR exceeds maximum supported size");
|
| 50 |
-
if (__builtin_expect(c_splitk == nullptr, 0)) {
|
| 51 |
-
init_workspace();
|
| 52 |
-
LIB_CALL(hipDeviceSynchronize());
|
| 53 |
-
}
|
| 54 |
-
|
| 55 |
-
transpose_kernel::transpose_fp8<K, N>(b_trans, b, job_stream0);
|
| 56 |
-
transpose_kernel::transpose_fp8<K, M>(a_trans, a, job_stream0);
|
| 57 |
-
// transpose_kernel::launch_transpose<__FP8_TYPE, K, N, 64, 512, 4>(b_trans, b, job_stream0);
|
| 58 |
-
// transpose_kernel::launch_transpose<__FP8_TYPE, K, M, 64, 512, 4>(a_trans, a, job_stream0);
|
| 59 |
-
// Busy wait for 150 microseconds
|
| 60 |
-
// auto start = std::chrono::high_resolution_clock::now();
|
| 61 |
-
// while (std::chrono::duration_cast<std::chrono::microseconds>(
|
| 62 |
-
// std::chrono::high_resolution_clock::now() - start).count() < 150) {
|
| 63 |
-
// // Busy wait
|
| 64 |
-
// }
|
| 65 |
-
// be careful that blocksize < 1024, or there's a silent fault
|
| 66 |
-
// gemm_kernel::check_trans<<<dim3(K / 32, M / 32), dim3(32, 32)>>>(a, a_trans, K, M);
|
| 67 |
-
|
| 68 |
-
static_assert(K % SPLITK_FACTOR == 0, "K not divisible by SPLITK_FACTOR");
|
| 69 |
-
dim3 grid(ceil_div(N, BN) << 1, ceil_div(M, BM) >> 1, SPLITK_FACTOR);
|
| 70 |
-
static_assert(BLOCK_SIZE >= 32, "BLOCK_SIZE must be at least 32");
|
| 71 |
-
dim3 block(BLOCK_SIZE);
|
| 72 |
-
if constexpr (SPLITK_FACTOR == 1) {
|
| 73 |
-
hipLaunchKernelGGL(
|
| 74 |
-
HIP_KERNEL_NAME(gemm_kernel::gemm_kernel<__FP8_TYPE, float, __BF16_TYPE, M, N, K, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N, K, K, LOAD_BATCH_SIZE>),
|
| 75 |
-
grid, block, 0, job_stream0,
|
| 76 |
-
reinterpret_cast<const __FP8_TYPE(*)[K]>(a_trans),
|
| 77 |
-
reinterpret_cast<const __FP8_TYPE(*)[K]>(b_trans),
|
| 78 |
-
reinterpret_cast<__BF16_TYPE(*)[N]>(c), reinterpret_cast<const float(*)[M]>(as),
|
| 79 |
-
reinterpret_cast<const float(*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs)
|
| 80 |
-
);
|
| 81 |
-
} else {
|
| 82 |
-
hipLaunchKernelGGL(
|
| 83 |
-
HIP_KERNEL_NAME(gemm_kernel::gemm_kernel<__FP8_TYPE, float, float, M, N, K / SPLITK_FACTOR, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N, K, K, LOAD_BATCH_SIZE>),
|
| 84 |
-
grid, block, 0, job_stream0,
|
| 85 |
-
reinterpret_cast<const __FP8_TYPE(*)[K]>(a_trans),
|
| 86 |
-
reinterpret_cast<const __FP8_TYPE(*)[K]>(b_trans),
|
| 87 |
-
reinterpret_cast<float(*)[N]>(c_splitk), reinterpret_cast<const float(*)[M]>(as),
|
| 88 |
-
reinterpret_cast<const float(*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs));
|
| 89 |
-
constexpr uint32_t REDUCE_BLOCK = 256;
|
| 90 |
-
hipLaunchKernelGGL(
|
| 91 |
-
HIP_KERNEL_NAME(gemm_kernel::reduce_kernel<M, N, SPLITK_FACTOR, REDUCE_BLOCK>),
|
| 92 |
-
ceil_div(M * N / 4, REDUCE_BLOCK), REDUCE_BLOCK, 0, job_stream0,
|
| 93 |
-
reinterpret_cast<const float(*)[M][N]>(c_splitk),
|
| 94 |
-
reinterpret_cast<__BF16_TYPE(*)[N]>(c)
|
| 95 |
-
); }
|
| 96 |
-
auto err = HOST_TYPE(GetLastError)();
|
| 97 |
-
if (err != HOST_TYPE(Success)) {
|
| 98 |
-
std::cerr << "Kernel execution failed.\n" << HOST_TYPE(GetErrorString)(err) << std::endl;
|
| 99 |
-
abort();
|
| 100 |
-
}
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
// Launch legacy gemm kernel. (most compellable)
|
| 105 |
-
template <int M, int N, int K, int BM, int BN, int BK, int WARP_M, int WARP_N, int BLOCK_SIZE, int QUANT_BLOCK_SIZE, int SPLITK_FACTOR>
|
| 106 |
-
void launch_gemm_legacy(const __FP8_TYPE *a, const __FP8_TYPE *b, __BF16_TYPE *c, const float *as, const float *bs, HOST_TYPE(Stream_t) job_stream0) {
|
| 107 |
-
static_assert(K % SPLITK_FACTOR == 0, "K not divisible by SPLITK_FACTOR");
|
| 108 |
-
dim3 grid(ceil_div(N, BN), ceil_div(M, BM), SPLITK_FACTOR);
|
| 109 |
-
static_assert(BLOCK_SIZE >= 32, "BLOCK_SIZE must be at least 32");
|
| 110 |
-
dim3 block(BLOCK_SIZE);
|
| 111 |
-
if (__builtin_expect(c_splitk == nullptr, 0)) {
|
| 112 |
-
init_workspace();
|
| 113 |
-
LIB_CALL(hipDeviceSynchronize());
|
| 114 |
-
}
|
| 115 |
-
|
| 116 |
-
if constexpr (SPLITK_FACTOR == 1) {
|
| 117 |
-
hipLaunchKernelGGL(
|
| 118 |
-
HIP_KERNEL_NAME(gemm_kernel_legacy::gemm_kernel<__FP8_TYPE, float, __BF16_TYPE, M, N, K, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N>),
|
| 119 |
-
grid, block, 0, job_stream0,
|
| 120 |
-
reinterpret_cast<const __FP8_TYPE (*)[M]>(a),
|
| 121 |
-
reinterpret_cast<const __FP8_TYPE (*)[N]>(b),
|
| 122 |
-
reinterpret_cast<__BF16_TYPE (*)[N]>(c),
|
| 123 |
-
reinterpret_cast<const float (*)[M]>(as),
|
| 124 |
-
reinterpret_cast<const float (*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs)
|
| 125 |
-
);
|
| 126 |
-
} else {
|
| 127 |
-
hipLaunchKernelGGL(
|
| 128 |
-
HIP_KERNEL_NAME(gemm_kernel_legacy::gemm_kernel<__FP8_TYPE, float, float, M, N, K / SPLITK_FACTOR, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N>),
|
| 129 |
-
grid, block, 0, job_stream0,
|
| 130 |
-
reinterpret_cast<const __FP8_TYPE (*)[M]>(a),
|
| 131 |
-
reinterpret_cast<const __FP8_TYPE (*)[N]>(b),
|
| 132 |
-
reinterpret_cast<float (*)[N]>(c_splitk),
|
| 133 |
-
reinterpret_cast<const float (*)[M]>(as),
|
| 134 |
-
reinterpret_cast<const float (*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs)
|
| 135 |
-
);
|
| 136 |
-
constexpr uint32_t REDUCE_BLOCK = 256;
|
| 137 |
-
hipLaunchKernelGGL(
|
| 138 |
-
HIP_KERNEL_NAME(gemm_kernel_legacy::reduce<0>),
|
| 139 |
-
ceil_div(M * N, REDUCE_BLOCK), REDUCE_BLOCK, 0, job_stream0,
|
| 140 |
-
M, N, SPLITK_FACTOR, c_splitk, (__BF16_TYPE *)c
|
| 141 |
-
);
|
| 142 |
-
}
|
| 143 |
-
auto err = HOST_TYPE(GetLastError)();
|
| 144 |
-
if (err != HOST_TYPE(Success)) {
|
| 145 |
-
std::cerr << "Kernel execution failed.\n" << HOST_TYPE(GetErrorString)(err) << std::endl;
|
| 146 |
-
abort();
|
| 147 |
-
}
|
| 148 |
-
}
|
| 149 |
-
|
| 150 |
-
// Packs a GEMM shape into a 32-bit dispatch key, 8 bits per dimension:
//   key = (m / 32) << 16 | (n / 32) << 8 | (k / 32)
//
// Each dimension must be a multiple of 32 in the range [32, 8160] (8160 = 255 * 32 is
// the largest value whose quotient still fits in 8 bits). Any other shape returns
// 0xFFFFFFFF, which is never produced for a valid shape. Without this check, integer
// division would silently map e.g. m = 1025 onto the key for m = 1024, and a dimension
// of 8192 would spill into the neighbouring field.
constexpr inline uint32_t pack_shape(uint32_t m, uint32_t n, uint32_t k) {
    constexpr uint32_t kGranule = 32;
    constexpr uint32_t kMaxDim = 255 * kGranule;
    constexpr uint32_t kInvalidKey = 0xFFFFFFFFu;

    const bool representable =
        m >= kGranule && m <= kMaxDim && m % kGranule == 0 &&
        n >= kGranule && n <= kMaxDim && n % kGranule == 0 &&
        k >= kGranule && k <= kMaxDim && k % kGranule == 0;
    if (!representable) {
        return kInvalidKey;
    }
    return ((m / kGranule) << 16) | ((n / kGranule) << 8) | (k / kGranule);
}
|
| 156 |
-
// Emits one `case` of a shape switch: when the switch value equals
// pack_shape_checked<M, N, K>(), launch_gemm is invoked with the given tiling parameters.
// Macro arguments, in order:
//   M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR, LOAD_BATCH_SIZE
// QUANT_BLOCK_SIZE is fixed at 128. The expansion site must have a_ptr, b_ptr, c_ptr,
// as_ptr, bs_ptr and job_stream0 in scope.
#define DISPATCH_GEMM(M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR, LOAD_BATCH_SIZE) \
    case pack_shape_checked<M, N, K>(): { \
        launch_gemm<M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, 128, SPLITK_FACTOR, LOAD_BATCH_SIZE>( \
            a_ptr, b_ptr, c_ptr, as_ptr, bs_ptr, job_stream0); \
        break; \
    }
|
| 163 |
-
|
| 164 |
-
// Emits one `case` of a shape switch that routes the shape to launch_gemm_legacy.
// Macro arguments, in order:
//   M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR
// QUANT_BLOCK_SIZE is fixed at 128. The expansion site must have a_ptr, b_ptr, c_ptr,
// as_ptr, bs_ptr and job_stream0 in scope.
#define DISPATCH_GEMM_LEGACY(M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR) \
    case pack_shape_checked<M, N, K>(): { \
        launch_gemm_legacy<M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, 128, SPLITK_FACTOR>( \
            a_ptr, b_ptr, c_ptr, as_ptr, bs_ptr, job_stream0); \
        break; \
    }
|
| 169 |
-
|
| 170 |
-
template <int M, int N, int K> constexpr inline uint32_t pack_shape_checked() {
|
| 171 |
-
static_assert(M % 32 == 0, "M must be a multiple of 32");
|
| 172 |
-
static_assert(N % 32 == 0, "N must be a multiple of 32");
|
| 173 |
-
static_assert(K % 32 == 0, "K must be a multiple of 32");
|
| 174 |
-
static_assert(M >= 32 && M <= 8192, "M must be between 32 and 8192");
|
| 175 |
-
static_assert(N >= 32 && N <= 8192, "N must be between 32 and 8192");
|
| 176 |
-
static_assert(K >= 32 && K <= 8192, "K must be between 32 and 8192");
|
| 177 |
-
return pack_shape(M, N, K);
|
| 178 |
-
}
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
extern "C" {
|
| 183 |
-
// Basically, it dispatch GEMM to fatest implementations according to inputs' shape.
|
| 184 |
-
void run(void *a, void *b, void *as, void *bs, void *c, int m, int n, int k, PerfMetrics *metrics, hipStream_t job_stream0) {
|
| 185 |
-
// Cast pointers once
|
| 186 |
-
const __FP8_TYPE *a_ptr = static_cast<const __FP8_TYPE *>(a);
|
| 187 |
-
const __FP8_TYPE *b_ptr = static_cast<const __FP8_TYPE *>(b);
|
| 188 |
-
__BF16_TYPE *c_ptr = static_cast<__BF16_TYPE *>(c);
|
| 189 |
-
const float *as_ptr = static_cast<const float *>(as);
|
| 190 |
-
const float *bs_ptr = static_cast<const float *>(bs);
|
| 191 |
-
KernelTimerScoped timer(timers, 2LL * m * n * k,
|
| 192 |
-
metrics ? &metrics->entries[0].time : nullptr,
|
| 193 |
-
metrics ? &metrics->entries[0].gflops : nullptr, job_stream0);
|
| 194 |
-
|
| 195 |
-
switch (pack_shape(m, n, k)) {
|
| 196 |
-
#ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
|
| 197 |
-
// Test: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR, LOAD_BATCH_SIZE
|
| 198 |
-
DISPATCH_GEMM(64, 64, 128, 64, 64, 32, 1, 4, 128, 1, 16);
|
| 199 |
-
DISPATCH_GEMM(64, 1536, 7168, 64, 128, 64, 4, 2, 256, 1, 16);
|
| 200 |
-
DISPATCH_GEMM(64, 3072, 1536, 64, 128, 64, 4, 2, 256, 1, 16);
|
| 201 |
-
DISPATCH_GEMM(64, 576, 7168, 64, 128, 64, 4, 2, 256, 1, 16);
|
| 202 |
-
DISPATCH_GEMM(96, 7168, 256, 96, 256, 64, 2, 4, 256, 1, 16);
|
| 203 |
-
DISPATCH_GEMM(96, 7168, 2048, 96, 256, 64, 2, 4, 256, 1, 16);
|
| 204 |
-
DISPATCH_GEMM(96, 4608, 7168, 96, 256, 64, 2, 4, 256, 1, 16);
|
| 205 |
-
DISPATCH_GEMM(128, 7168, 2304, 128, 128, 64, 4, 2, 256, 1, 16);
|
| 206 |
-
DISPATCH_GEMM(128, 512, 7168, 128, 128, 64, 4, 2, 256, 1, 16);
|
| 207 |
-
DISPATCH_GEMM(512, 4096, 512, 256, 128, 64, 4, 2, 256, 1, 16);
|
| 208 |
-
DISPATCH_GEMM(512, 1536, 7168, 256, 128, 64, 4, 2, 256, 1, 16);
|
| 209 |
-
// Benchmark: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR, LOAD_BATCH_SIZE
|
| 210 |
-
DISPATCH_GEMM(1024, 1536, 7168, 128, 128, 64, 1, 4, 128, 4, 16); // Optimized: 0.49 ms (45.65 TFlops)
|
| 211 |
-
DISPATCH_GEMM(1024, 3072, 1536, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.19 ms (51.32 TFlops)
|
| 212 |
-
DISPATCH_GEMM(1024, 576, 7168, 128, 64, 32, 4, 1, 128, 4, 16); // Optimized: 0.30 ms (28.16 TFlops)
|
| 213 |
-
DISPATCH_GEMM(1024, 7168, 256, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.08 ms (46.49 TFlops)
|
| 214 |
-
DISPATCH_GEMM(1024, 7168, 2048, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.49 ms (61.92 TFlops)
|
| 215 |
-
DISPATCH_GEMM(1024, 4608, 7168, 128, 128, 32, 2, 2, 128, 1, 16); // Optimized: 0.99 ms (68.16 TFlops)
|
| 216 |
-
DISPATCH_GEMM(1024, 7168, 2304, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.51 ms (66.04 TFlops)
|
| 217 |
-
DISPATCH_GEMM(1024, 512, 7168, 64, 128, 32, 2, 2, 128, 4, 16); // Optimized: 0.26 ms (28.97 TFlops)
|
| 218 |
-
DISPATCH_GEMM(1024, 4096, 512, 128, 256, 32, 2, 4, 256, 1, 16); // Optimized: 0.08 ms (54.27 TFlops)
|
| 219 |
-
DISPATCH_GEMM(6144, 1536, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 1.76 ms (76.76 TFlops)
|
| 220 |
-
DISPATCH_GEMM(6144, 3072, 1536, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.88 ms (66.00 TFlops)
|
| 221 |
-
DISPATCH_GEMM(6144, 576, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.84 ms (60.68 TFlops)
|
| 222 |
-
DISPATCH_GEMM(6144, 7168, 256, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.49 ms (45.76 TFlops)
|
| 223 |
-
DISPATCH_GEMM(6144, 7168, 2048, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 2.17 ms (83.11 TFlops)
|
| 224 |
-
DISPATCH_GEMM(6144, 4608, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 4.56 ms (88.99 TFlops)
|
| 225 |
-
DISPATCH_GEMM(6144, 7168, 2304, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 2.41 ms (84.32 TFlops)
|
| 226 |
-
DISPATCH_GEMM(6144, 512, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.67 ms (67.45 TFlops)
|
| 227 |
-
DISPATCH_GEMM(6144, 4096, 512, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.51 ms (50.79 TFlops)
|
| 228 |
-
#else // CDNA3, WAVE_SIZE = 64
|
| 229 |
-
// Benchmark: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SZ, SPLITK_F, LOAD_BS
|
| 230 |
-
DISPATCH_GEMM(1024, 1536, 7168, 256, 128, 128, 4, 2, 512, 4, 16); // #0
|
| 231 |
-
DISPATCH_GEMM(1024, 3072, 1536, 256, 128, 128, 4, 2, 512, 2, 16); // #1
|
| 232 |
-
DISPATCH_GEMM(1024, 576, 7168, 256, 128, 128, 4, 2, 512, 8, 16); // #2
|
| 233 |
-
DISPATCH_GEMM(1024, 7168, 256, 256, 128, 128, 4, 2, 512, 1, 16); // #3
|
| 234 |
-
DISPATCH_GEMM(1024, 7168, 2048, 256, 128, 128, 4, 2, 512, 1, 16); // #4
|
| 235 |
-
DISPATCH_GEMM(1024, 4608, 7168, 256, 128, 128, 4, 2, 512, 2, 16); // #5
|
| 236 |
-
DISPATCH_GEMM(1024, 7168, 2304, 256, 128, 128, 4, 2, 512, 1, 16); // #6
|
| 237 |
-
DISPATCH_GEMM(1024, 512, 7168, 256, 128, 128, 4, 2, 512, 8, 16); // #7
|
| 238 |
-
DISPATCH_GEMM(1024, 4096, 512, 256, 128, 128, 4, 2, 512, 1, 16); // #8
|
| 239 |
-
DISPATCH_GEMM(6144, 1536, 7168, 256, 128, 128, 4, 2, 512, 1, 16); // #9
|
| 240 |
-
DISPATCH_GEMM(6144, 3072, 1536, 256, 128, 128, 4, 2, 512, 1, 16); // #10
|
| 241 |
-
DISPATCH_GEMM(6144, 576, 7168, 256, 128, 128, 4, 2, 512, 2, 16); // #11
|
| 242 |
-
DISPATCH_GEMM(6144, 7168, 256, 256, 128, 128, 4, 2, 512, 1, 16); // #12
|
| 243 |
-
DISPATCH_GEMM(6144, 7168, 2048, 256, 128, 128, 4, 2, 512, 1, 16); // #13
|
| 244 |
-
DISPATCH_GEMM(6144, 4608, 7168, 256, 128, 128, 4, 2, 512, 1, 16); // #14
|
| 245 |
-
DISPATCH_GEMM(6144, 7168, 2304, 256, 128, 128, 4, 2, 512, 1, 16); // #15
|
| 246 |
-
DISPATCH_GEMM(6144, 512, 7168, 256, 128, 128, 4, 2, 512, 2, 16); // #16
|
| 247 |
-
DISPATCH_GEMM(6144, 4096, 512, 256, 128, 128, 4, 2, 512, 1, 16); // #17
|
| 248 |
-
// Test: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SZ, SPLITK_F,
|
| 249 |
-
DISPATCH_GEMM_LEGACY(64, 64, 128, 64, 64, 32, 4, 2, 512, 1);
|
| 250 |
-
DISPATCH_GEMM_LEGACY(64, 1536, 7168, 64, 128, 64, 4, 2, 512, 1);
|
| 251 |
-
DISPATCH_GEMM_LEGACY(64, 3072, 1536, 64, 128, 64, 4, 2, 512, 1);
|
| 252 |
-
DISPATCH_GEMM_LEGACY(64, 576, 7168, 64, 128, 64, 4, 2, 512, 1);
|
| 253 |
-
DISPATCH_GEMM_LEGACY(96, 7168, 256, 96, 256, 64, 2, 4, 512, 1);
|
| 254 |
-
DISPATCH_GEMM_LEGACY(96, 7168, 2048, 96, 256, 64, 2, 4, 512, 1);
|
| 255 |
-
DISPATCH_GEMM_LEGACY(96, 4608, 7168, 96, 256, 64, 2, 4, 512, 1);
|
| 256 |
-
DISPATCH_GEMM_LEGACY(128, 7168, 2304, 128, 128, 64, 4, 2, 512, 1);
|
| 257 |
-
DISPATCH_GEMM_LEGACY(128, 512, 7168, 128, 128, 64, 4, 2, 512, 1);
|
| 258 |
-
DISPATCH_GEMM_LEGACY(512, 4096, 512, 256, 128, 64, 4, 2, 512, 1);
|
| 259 |
-
DISPATCH_GEMM_LEGACY(512, 1536, 7168, 256, 128, 64, 4, 2, 512, 1);
|
| 260 |
-
#endif
|
| 261 |
-
default: {
|
| 262 |
-
printf("Error: Unsupported shape M=%d, K=%d, N=%d\n", m, k, n);
|
| 263 |
-
abort();
|
| 264 |
-
}
|
| 265 |
-
}
|
| 266 |
-
}
|
| 267 |
-
} // extern "C"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|