#pragma once
#include <cuda.h>
#include <torch/torch.h>
#include "../heuristics/sm90.hpp"
#include "../../jit/handle.hpp"
#include "../../utils/math.hpp"
#include "../../utils/system.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
// Map (k, mn) into (inner, outer) dims according to the major mode:
// K-major tensors are contiguous along K, MN-major ones along M or N
static std::pair<int, int> get_inner_outer_dims(const cute::UMMA::Major& major, const int& k, const int& mn) {
    return major == cute::UMMA::Major::K ? std::make_pair(k, mn) : std::make_pair(mn, k);
}

static int get_non_contiguous_dim(const cute::UMMA::Major& major) {
    return major == cute::UMMA::Major::K ? -2 : -1;
}

// Return `dim` if its name appears in `compiled_dims` (the dimension is
// specialized into the compiled kernel), otherwise 0 (it stays a runtime argument)
static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) {
    for (const char& c: compiled_dims) {
        if (name == c)
            return dim;
    }
    return 0;
}
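// Illustrative sketch (not part of this header): with `compiled_dims == "nk"`,
// N and K are specialized into the kernel while M stays dynamic:
//   get_compiled_dim(m, 'm', "nk");  // == 0: M remains a runtime dimension
//   get_compiled_dim(n, 'n', "nk");  // == n: N is baked into the compiled kernel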
static std::string to_string(const cute::UMMA::Major& major) {
    switch (major) {
        case cute::UMMA::Major::K:  return "cute::UMMA::Major::K";
        case cute::UMMA::Major::MN: return "cute::UMMA::Major::MN";
    }
    DG_HOST_UNREACHABLE("Unknown major");
}

static std::string to_string(const GemmType& type) {
    switch (type) {
        case GemmType::Normal:                           return "GemmType::Normal";
        case GemmType::MGroupedContiguous:               return "GemmType::MGroupedContiguous";
        case GemmType::MGroupedMasked:                   return "GemmType::MGroupedMasked";
        case GemmType::MGroupedContiguousWithPsumLayout: return "GemmType::MGroupedContiguousWithPsumLayout";
        case GemmType::KGroupedContiguous:               return "GemmType::KGroupedContiguous";
        case GemmType::Batched:                          return "GemmType::Batched";
    }
    DG_HOST_UNREACHABLE("Unknown GEMM type");
}

static std::string to_string(const at::ScalarType& dtype) {
    switch (dtype) {
        case torch::kInt:           return "int";
        case torch::kFloat:         return "float";
        case torch::kBFloat16:      return "cutlass::bfloat16_t";
        case torch::kFloat8_e4m3fn: return "cutlass::float_e4m3_t";
        case kPackedFP4:            return "cutlass::detail::float_e2m1_unpacksmem_t";
        default: DG_HOST_UNREACHABLE("Unsupported dtype");
    }
}
// Map an ATen dtype to the TMA tensor map dtype; FP8 tensors are transferred
// as opaque bytes, and FP32 may optionally be read as TF32
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype,
                                                          const bool& allow_tf32) {
    if (allow_tf32 and dtype == torch::kFloat)
        return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32;
    switch (dtype) {
        case torch::kInt:           return CU_TENSOR_MAP_DATA_TYPE_INT32;
        case torch::kFloat:         return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
        case torch::kBFloat16:      return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
        case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
#if CUDART_VERSION >= 12080
        // Packed FP4 requires the CUDA 12.8+ tensor map dtype
        case kPackedFP4:            return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;
#endif
        default: DG_HOST_UNREACHABLE("Unsupported dtype");
    }
}
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
#if CUDART_VERSION >= 12080
    // A non-zero base currently means the 128B swizzle with a 32B atom,
    // which requires CUDA 12.8+
    if (base != 0) {
        DG_HOST_ASSERT(base == 32 and mode == 128);
        return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
    }
#endif
    DG_HOST_ASSERT(base == 0);
    switch (mode) {
        case 0:
        case 16:  return CU_TENSOR_MAP_SWIZZLE_NONE;
        case 32:  return CU_TENSOR_MAP_SWIZZLE_32B;
        case 64:  return CU_TENSOR_MAP_SWIZZLE_64B;
        case 128: return CU_TENSOR_MAP_SWIZZLE_128B;
        default:  DG_HOST_UNREACHABLE("Unsupported swizzling mode");
    }
}
// Encode a 2D TMA descriptor over tensor `t`
static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
                                    int gmem_inner_dim, int gmem_outer_dim,
                                    int smem_inner_dim, int smem_outer_dim,
                                    const int& gmem_outer_stride,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false) {
    // With swizzling, the shared memory box covers exactly one swizzle atom
    const auto& elem_size = static_cast<int>(t.element_size());
    if (swizzle_mode != 0)
        smem_inner_dim = swizzle_mode / elem_size;

    // Inner dim must be a multiple of 64B for `.b4x16_p64`
    if (t.scalar_type() == kPackedFP4)
        DG_HOST_ASSERT(gmem_inner_dim % 128 == 0);

    CUtensorMap tensor_map;
    const cuuint64_t gmem_dims[2] = {static_cast<cuuint64_t>(gmem_inner_dim), static_cast<cuuint64_t>(gmem_outer_dim)};
    const cuuint32_t smem_dims[2] = {static_cast<cuuint32_t>(smem_inner_dim), static_cast<cuuint32_t>(smem_outer_dim)};
    const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride * elem_size)};
    const cuuint32_t elem_strides[2] = {1, 1};
    if (get_env<int>("DG_JIT_DEBUG")) {
        printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d\n",
               gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
               gmem_outer_stride, swizzle_mode, swizzle_base, elem_size);
    }
    DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled(
        &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
        2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
        CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensor_map;
}
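// Illustrative sketch (shapes are assumptions, not library defaults): for a
// BF16 tensor with a 128B swizzle, the shared memory box is forced to
// 128 / 2 = 64 elements along the inner dim, regardless of the block size
// passed in. For a K-major [m, k] tensor:
//   const auto& desc = make_tma_2d_desc(t, k, m, BLOCK_K, BLOCK_M,
//                                       static_cast<int>(t.stride(0)), 128);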
// Encode a 3D TMA descriptor; dim 0 is the innermost (contiguous) one
static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
                                    int gmem_dim_0, int gmem_dim_1, int gmem_dim_2,
                                    int smem_dim_0, int smem_dim_1, int smem_dim_2,
                                    const int& gmem_stride_0, const int& gmem_stride_1,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false) {
    // With swizzling, the shared memory box covers exactly one swizzle atom
    const auto& elem_size = static_cast<int>(t.element_size());
    if (swizzle_mode != 0)
        smem_dim_0 = swizzle_mode / elem_size;

    // Inner dim must be a multiple of 64B for `.b4x16_p64`
    if (t.scalar_type() == kPackedFP4)
        DG_HOST_ASSERT(gmem_dim_0 % 128 == 0);

    CUtensorMap tensor_map;
    const cuuint64_t gmem_dims[3] = {static_cast<cuuint64_t>(gmem_dim_0), static_cast<cuuint64_t>(gmem_dim_1), static_cast<cuuint64_t>(gmem_dim_2)};
    const cuuint32_t smem_dims[3] = {static_cast<cuuint32_t>(smem_dim_0), static_cast<cuuint32_t>(smem_dim_1), static_cast<cuuint32_t>(smem_dim_2)};
    const cuuint64_t gmem_strides[2] = {static_cast<cuuint64_t>(gmem_stride_0 * elem_size), static_cast<cuuint64_t>(gmem_stride_1 * elem_size)};
    const cuuint32_t elem_strides[3] = {1, 1, 1};
    if (get_env<int>("DG_JIT_DEBUG")) {
        printf("Making 3D TMA desc: global memory: %d %d %d, shared memory: %d %d %d, outer strides: %d %d, swizzle: %d (base: %d), elem size: %d\n",
               gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2,
               gmem_stride_0, gmem_stride_1, swizzle_mode, swizzle_base, elem_size);
    }
    DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled(
        &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
        3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
        CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensor_map;
}
// Build the TMA load descriptor for operand A. Grouped layouts stack the
// groups along the outer (M) dimension, so they require a K-major A
static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
                                   const torch::Tensor& t,
                                   const int& shape_m, const int& shape_k,
                                   const int& block_m, const int& block_k,
                                   const int& outer_stride,
                                   const int& num_groups,
                                   const int& swizzle_mode, const int& swizzle_base = 0,
                                   const bool& allow_tf32 = false) {
    if (num_groups > 1)
        DG_HOST_ASSERT(major == cute::UMMA::Major::K);
    const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
    const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m);
    return make_tma_2d_desc(t,
                            gmem_inner_dim, gmem_outer_dim,
                            smem_inner_dim, smem_outer_dim,
                            outer_stride,
                            swizzle_mode, swizzle_base,
                            allow_tf32);
}
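// Illustrative sketch (shapes and block sizes are assumptions): a K-major FP8
// A of shape [m, k] with 128B swizzling, for a non-grouped GEMM:
//   const auto& a_desc = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
//                                        BLOCK_M, BLOCK_K, static_cast<int>(a.stride(0)),
//                                        1, 128);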
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
                                   const torch::Tensor& t,
                                   const int& shape_n, const int& shape_k,
                                   const int& block_n, const int& block_k,
                                   const int& outer_stride,
                                   const int& num_groups,
                                   const int& swizzle_mode, const int& swizzle_base = 0,
                                   const bool& allow_tf32 = false) {
    const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n);
    const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n);
    // `num_groups` is always applied to the outer dimension
    return make_tma_2d_desc(t,
                            gmem_inner_dim, gmem_outer_dim * num_groups,
                            smem_inner_dim, smem_outer_dim,
                            outer_stride,
                            swizzle_mode, swizzle_base,
                            allow_tf32);
}
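// For example (an assumed layout, not enforced here): with `num_groups == 4`
// and K-major B tensors of shape [n, k], the descriptor spans a stacked
// [4 * n, k] region, and group `g` is reached by offsetting the outer index by `g * n`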
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
                                    const int& shape_m, const int& shape_n,
                                    const int& block_m, const int& block_n,
                                    const int& outer_stride,
                                    const int& num_groups,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false) {
    // Swizzling requires the inner box dim to be no larger than `kSwizzleCDMode`
    // bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
    return make_tma_2d_desc(t,
                            shape_n, shape_m * num_groups,
                            block_n, block_m,
                            outer_stride,
                            swizzle_mode, swizzle_base,
                            allow_tf32);
}
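// Worked example (an assumed tile configuration): for a BF16 output tile with
// BLOCK_N = 128 and kSwizzleCDMode = 64B, each store covers 64 / 2 = 32
// elements along N, so 128 * 2 / 64 = 4 TMA stores write out one tile row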
// Build the TMA descriptor for the scaling factor (SF) tensor
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
                                    const torch::Tensor& t,
                                    int shape_mn, int shape_k,
                                    const int& block_mn, const int& gran_k,
                                    const int& num_groups,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false) {
    DG_HOST_ASSERT(major == cute::UMMA::Major::MN);

    // TODO: maybe swizzle SF as well
    DG_HOST_ASSERT(swizzle_mode == 0);

    // Pad the MN dim to TMA alignment; one SF covers `gran_k` elements along K,
    // and non-FP32 SF dtypes pack 4 SFs per element
    shape_mn = get_tma_aligned_size(shape_mn, static_cast<int>(t.element_size()));
    return make_tma_2d_desc(t,
                            shape_mn, ceil_div(shape_k, gran_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
                            block_mn, 1,
                            shape_mn,
                            swizzle_mode, swizzle_base,
                            allow_tf32);
}
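// Worked example (an assumed configuration): with shape_k = 4096, gran_k = 128
// and FP32 SFs, the K extent of the SF tensor is ceil_div(4096, 128 * 1) = 32
// per group; with packed non-FP32 SFs it would be ceil_div(4096, 128 * 4) = 8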
} // namespace deep_gemm