#pragma once

#include <array>
#include <cstdint>

#include <c10/macros/Macros.h>
#include <cutlass/util/packed_stride.hpp>

namespace at::cuda::detail {

using Strides = std::array<int64_t, 3>;

template <
    typename DtypeA,
    typename DtypeB,
    typename DtypeOutput,
    typename DtypeScale,
    typename ProblemShape,
    typename StrideA,
    typename StrideB,
    typename StrideOutput>
__global__ void prepare_grouped_gemm_data(
    DtypeA* A,
    DtypeB* B,
    DtypeOutput* output,
    DtypeScale* scale_A,
    DtypeScale* scale_B,
    DtypeA** A_ptrs,
    DtypeB** B_ptrs,
    DtypeOutput** output_ptrs,
    DtypeScale** inputA_scale_ptrs,
    DtypeScale** inputB_scale_ptrs,
    ProblemShape* problem_sizes,
    // Strides for CUTLASS, as cute::Stride
    StrideA* stride_A,
    StrideB* stride_B,
    StrideOutput* stride_output,
    const int32_t* offs,
    int32_t M,
    int32_t N,
    int32_t K,
    // Original strides of the input tensors
    Strides tensor_StrideA,
    Strides tensor_StrideB,
    Strides tensor_StrideOutput,
    int64_t a_scale_stride,
    int64_t b_scale_stride,
    bool a_row_major = true,
    bool b_row_major = false) {
  int32_t tid = threadIdx.x;
  int32_t delta = 0;
  if (offs != nullptr) {
    int32_t start = tid == 0 ? 0 : offs[tid - 1];
    delta = offs[tid] - start;
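    // `offs` holds cumulative group boundaries along the dynamic dimension,
    // e.g. offs = {32, 80, 96} describes three groups of sizes 32, 48, and 16;
    // each thread (one per group) recovers its group's extent as `delta`.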
    if (K < 0) {
      if (!a_row_major && b_row_major) {
        CUDA_KERNEL_ASSERT(
            delta >= 0 && "expected offsets to be greater than or equal to 0\n");
      } else {
        // CUTLASS cannot handle delta=0 here.
        CUDA_KERNEL_ASSERT(delta > 0 && "expected offsets to be greater than 0\n");
      }
    }
    // TMA transfers require global memory tensor addresses to be
    // aligned to 16 bytes.
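    // The checks below express those 16 bytes as an element count,
    // 128 / sizeof_bits<Dtype> (e.g. 8 elements for a 16-bit dtype, 16 for an
    // 8-bit dtype), and require each group's offset along the dynamic
    // dimension to be a multiple of that count.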
    if (tid < blockDim.x - 1) {
      // Check this requirement for the input tensors, in case group
      // addresses advance along the dynamic dimension.
      if ((K < 0 && a_row_major) || // 2D/2D: check along K dimension
          (M < 0 && !a_row_major)) { // 2D/3D: check along M dimension
        int align = 128 / cutlass::sizeof_bits<DtypeA>::value;
        CUDA_KERNEL_ASSERT(
            delta % align == 0 &&
            "expected input tensor dynamic dimension byte size to be a non-negative multiple of 16\n");
      }
      if ((K < 0 && !b_row_major) || // 2D/2D: check along K dimension
          (N < 0 && b_row_major)) { // 3D/2D: check along N dimension
        int align = 128 / cutlass::sizeof_bits<DtypeB>::value;
        CUDA_KERNEL_ASSERT(
            delta % align == 0 &&
            "expected input tensor dynamic dimension byte size to be a non-negative multiple of 16\n");
      }
      // Check the same requirement for the output tensor (which is always
      // contiguous and in row-major layout).
      if (N < 0) {
        int align = 128 / cutlass::sizeof_bits<DtypeOutput>::value;
        CUDA_KERNEL_ASSERT(
            delta % align == 0 &&
            "expected output tensor dynamic dimension byte size to be a non-negative multiple of 16\n");
      }
    }
  }
  int64_t lda, ldb, ldoutput;
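  // A negative M, N, or K marks that dimension as dynamic: its per-group
  // extent comes from `offs` (as `delta` above), the tensors spanning it are
  // 2D with groups stacked along that dimension, and the remaining tensor is
  // 3D with one batch entry per group. If none is negative, all are 3D.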
  if (M < 0) {
    // M is dynamic: A and output are 2D, B is 3D
    M = delta;
    lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
    ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2];
    ldoutput = tensor_StrideOutput[0];
    A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * tensor_StrideA[0];
    if (scale_A != nullptr) {
      inputA_scale_ptrs[tid] = tid == 0 ? scale_A : scale_A + offs[tid - 1];
      inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride;
    }
    output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput;
    B_ptrs[tid] = B + tid * tensor_StrideB[0];
  } else if (N < 0) {
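    // N is dynamic: B and output are 2D, A is 3D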
    N = delta;
    lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2];
    ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed
    ldoutput = tensor_StrideOutput[0];
    A_ptrs[tid] = A + tid * tensor_StrideA[0];
    output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1];
    B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[1];
    if (scale_A != nullptr) {
      inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride;
      inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1];
    }
  } else if (K < 0) {
    // K is dynamic: A and B are 2D, output is 3D
    K = delta;
    lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
    ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1];
    ldoutput = tensor_StrideOutput[1];
    A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * tensor_StrideA[1];
    B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[0];
    output_ptrs[tid] = output + tid * tensor_StrideOutput[0];
    if (scale_A != nullptr) {
      inputA_scale_ptrs[tid] = scale_A + tid * M;
      inputB_scale_ptrs[tid] = scale_B + tid * N;
    }
  } else {
    // A, B, and output are all 3D
    lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2];
    ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2];
    ldoutput = tensor_StrideOutput[1];
    A_ptrs[tid] = A + tid * tensor_StrideA[0];
    B_ptrs[tid] = B + tid * tensor_StrideB[0];
    output_ptrs[tid] = output + tid * tensor_StrideOutput[0];
    if (scale_A != nullptr) {
      inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride;
      inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride;
    }
  }
  problem_sizes[tid] = ProblemShape(M, N, K);
  // make_cute_packed_stride replaces only one element of the given stride with
  // a value taken from the shape argument; which shape index is read and which
  // stride index is written depends on whether A/B are row-major. So while
  // {lda, lda, 1} is nonsensical as a shape, it is fine for constructing the
  // strides: whichever element ends up being used for the replacement is
  // correct, the other one is ignored, and we avoid branching on the layout
  // of A and B.
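  // For example (assuming CUTLASS's canonical 3-mode strides): for a row-major
  // A, StrideA is cute::Stride<int64_t, cute::Int<1>, int64_t>, and
  // make_cute_packed_stride fills mode 0 from shape element 1; for a
  // column-major A it is cute::Stride<cute::Int<1>, int64_t, int64_t>, and
  // mode 1 is filled from shape element 0. With {lda, lda, 1}, both reads
  // yield lda.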
  stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {lda, lda, 1});
  stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {ldb, ldb, 1});
  stride_output[tid] =
      cutlass::make_cute_packed_stride(StrideOutput{}, {M, ldoutput, 1});
}

} // namespace at::cuda::detail
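
// Usage sketch (hypothetical, not part of the original header): the kernel is
// written to be launched with a single block and one thread per group, e.g.
//
//   prepare_grouped_gemm_data<<<1, num_groups, 0, stream>>>(
//       A, B, out, scale_A, scale_B,
//       A_ptrs, B_ptrs, out_ptrs, scale_A_ptrs, scale_B_ptrs,
//       problem_sizes, stride_A, stride_B, stride_out,
//       offs, M, N, /*K=*/-1,
//       strideA, strideB, strideOut, a_scale_stride, b_scale_stride);
//
// Here num_groups, stream, and all device buffers are assumed to be supplied
// by the caller, and the negative K marks K as the dynamic dimension.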