Phi2-Fine-Tuning
/
phivenv
/Lib
/site-packages
/torch
/include
/ATen
/native
/cuda
/GroupMMCommon.cuh
| namespace at::cuda::detail { | |
| using Strides = std::array<int64_t, 3>; | |
| template < | |
| typename DtypeA, | |
| typename DtypeB, | |
| typename DtypeOutput, | |
| typename DtypeScale, | |
| typename ProblemShape, | |
| typename StrideA, | |
| typename StrideB, | |
| typename StrideOutput> | |
| __global__ void prepare_grouped_gemm_data( | |
| DtypeA* A, | |
| DtypeB* B, | |
| DtypeOutput* output, | |
| DtypeScale* scale_A, | |
| DtypeScale* scale_B, | |
| DtypeA** A_ptrs, | |
| DtypeB** B_ptrs, | |
| DtypeOutput** output_ptrs, | |
| DtypeScale** inputA_scale_ptrs, | |
| DtypeScale** inputB_scale_ptrs, | |
| ProblemShape* problem_sizes, | |
| // Strides for cutlass, cute::Stride | |
| StrideA* stride_A, | |
| StrideB* stride_B, | |
| StrideOutput* stride_output, | |
| const int32_t* offs, | |
| int32_t M, | |
| int32_t N, | |
| int32_t K, | |
| // Original strides of the input tensors | |
| Strides tensor_StrideA, | |
| Strides tensor_StrideB, | |
| Strides tensor_StrideOutput, | |
| int64_t a_scale_stride, | |
| int64_t b_scale_stride, | |
| bool a_row_major = true, | |
| bool b_row_major = false) { | |
| int32_t tid = threadIdx.x; | |
| int32_t delta = 0; | |
| if (offs != nullptr) { | |
| int32_t start = tid == 0 ? 0 : offs[tid - 1]; | |
| delta = offs[tid] - start; | |
| if (K < 0) { | |
| if (!a_row_major && b_row_major) { | |
| CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n"); | |
| } else { | |
| // CUTLASS cannot handle delta=0 here. | |
| CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n"); | |
| } | |
| } | |
| // TMA transfers require global memory tensor addresses to be | |
| // aligned to 16 bytes. | |
| if (tid < blockDim.x - 1) { | |
| // Check this requirement for input tensors, in case group | |
| // addresses are increased along the dynamic dimension. | |
| if ((K < 0 && a_row_major) || // 2D/2D: check along K dimension | |
| (M < 0 && !a_row_major)) { // 3D/2D: check along N dimension | |
| int align = 128 / cutlass::sizeof_bits<DtypeA>::value; | |
| CUDA_KERNEL_ASSERT( | |
| delta % align == 0 && | |
| "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"); | |
| } | |
| if ((K < 0 && !b_row_major) || // 2D/2D: check along K dimension | |
| (N < 0 && b_row_major)) { // 3D/2D: check along N dimension | |
| int align = 128 / cutlass::sizeof_bits<DtypeB>::value; | |
| CUDA_KERNEL_ASSERT( | |
| delta % align == 0 && | |
| "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"); | |
| } | |
| // Check the same requirement for output tensor (that is always | |
| // contiguous, and in row-major layout). | |
| if (N < 0) { | |
| int align = 128 / cutlass::sizeof_bits<DtypeOutput>::value; | |
| CUDA_KERNEL_ASSERT( | |
| delta % align == 0 && | |
| "expected output tensor dynamic dimension byte size to be non-negative multiple of 16\n"); | |
| } | |
| } | |
| } | |
| int64_t lda, ldb, ldoutput; | |
| if (M < 0) { | |
| // A and output is 2d | |
| M = delta; | |
| lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; | |
| ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; | |
| ldoutput = tensor_StrideOutput[0]; | |
| A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * tensor_StrideA[0]; | |
| if (scale_A != nullptr) { | |
| inputA_scale_ptrs[tid] = tid == 0 ? scale_A : scale_A + offs[tid - 1]; | |
| inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; | |
| } | |
| output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput; | |
| B_ptrs[tid] = B + tid * tensor_StrideB[0]; | |
| } else if (N < 0) { | |
| N = delta; | |
| lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; | |
| ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed | |
| ldoutput = tensor_StrideOutput[0]; | |
| A_ptrs[tid] = A + tid * tensor_StrideA[0]; | |
| output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1]; | |
| B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[1]; | |
| if (scale_A != nullptr) { | |
| inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; | |
| inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1]; | |
| } | |
| } else if (K < 0) { | |
| // A, B is 2d, output is 3d | |
| K = delta; | |
| lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; | |
| ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; | |
| ldoutput = tensor_StrideOutput[1]; | |
| A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * tensor_StrideA[1]; | |
| B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[0]; | |
| output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; | |
| if (scale_A != nullptr) { | |
| inputA_scale_ptrs[tid] = scale_A + tid * M; | |
| inputB_scale_ptrs[tid] = scale_B + tid * N; | |
| } | |
| } else { | |
| // A, B, output are 3D | |
| lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; | |
| ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; | |
| ldoutput = tensor_StrideOutput[1]; | |
| A_ptrs[tid] = A + tid * tensor_StrideA[0]; | |
| B_ptrs[tid] = B + tid * tensor_StrideB[0]; | |
| output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; | |
| if (scale_A != nullptr) { | |
| inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; | |
| inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; | |
| } | |
| } | |
| problem_sizes[tid] = ProblemShape(M, N, K); | |
| // make_cute_packed_stride only replaces one of the stride elements with | |
| // one the provided values in the shape arguments | |
| // the indices of the src/dst depend on whether A/B are row-major | |
| // so constructing shape argument with two similar lda values | |
| // while it looks non-sensical (and it is a nonsensical shape) | |
| // is fine for these stride construction purposes - the one that will be used | |
| // for replacement is correct, the other one is ignored, and we don't have to | |
| // branch on whether A/B are row-major | |
| stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {lda, lda, 1}); | |
| stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {ldb, ldb, 1}); | |
| stride_output[tid] = | |
| cutlass::make_cute_packed_stride(StrideOutput{}, {M, ldoutput, 1}); | |
| } | |
| } // namespace at::cuda::detail | |