#pragma once

#include <optional>
#include <torch/python.h>

#include "math.hpp"
#include "exception.hpp"
#include "../jit/device_runtime.hpp"

namespace deep_gemm {

// Major-ness stuff
static void major_check(const torch::Tensor& t) {
    const auto dim = t.dim();
    DG_HOST_ASSERT(dim == 2 or dim == 3);
    if (dim == 3)
        DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1));
    DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1);
}

static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) {
    major_check(t);
    return t.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
}

static void check_major_type_cd(const torch::Tensor& t) {
    // NOTES: the library only supports row-major output layouts
    major_check(t);
    DG_HOST_ASSERT(t.stride(-1) == 1);
}

static bool fp8_requires_k_major() {
    return device_runtime->get_arch_major() == 9;
}

// Tensor utils
template <int N>
static auto get_shape(const torch::Tensor& t) {
    DG_HOST_ASSERT(t.dim() == N);
    return [&t]<size_t ... Is>(std::index_sequence<Is...>) {
        return std::make_tuple(static_cast<int>(t.sizes()[Is])...);
    }(std::make_index_sequence<N>());
}

static std::tuple<int, int> check_ab_fp8_fp4(const torch::Tensor& ab,
                                             const cute::UMMA::Major& major,
                                             const int& arch_major) {
    auto [mn, k] = get_shape<2>(ab);
    if (ab.scalar_type() != torch::kFloat8_e4m3fn) {
        // Packed FP4 stores two elements per byte, so double the packed dimension
        DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10);
        major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2);
    }
    return std::make_tuple(mn, k);
}

static std::tuple<int, int, int> check_grouped_ab_fp8_fp4(const torch::Tensor& ab,
                                                          const cute::UMMA::Major& major,
                                                          const int& arch_major) {
    auto [num_groups, mn, k] = get_shape<3>(ab);
    if (ab.scalar_type() != torch::kFloat8_e4m3fn) {
        DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10);
        major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2);
    }
    return std::make_tuple(num_groups, mn, k);
}

// Recipe
static std::tuple<int, int, int> get_default_recipe(const torch::ScalarType& sfa_dtype,
                                                    const torch::ScalarType& sfb_dtype) {
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat);
        return {1, 128, 128};
    } else if (arch_major == 10) {
        DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt);
        return sfb_dtype == torch::kFloat ?
            std::make_tuple(1, 128, 128) : // Legacy format
            std::make_tuple(1, 1, 128);    // 1D1D kernels
    }
    DG_HOST_UNREACHABLE("Unknown recipe");
}

// SF layouts
static torch::Tensor check_sf_layout(const torch::Tensor& sf,
                                     const int& mn, const int& k,
                                     const int& gran_mn, const int& gran_k,
                                     const std::optional<int>& num_groups,
                                     const bool& tma_stride_check = false,
                                     const bool& sm90_sfb_check = false,
                                     const std::optional<torch::ScalarType>& type_check = std::nullopt) {
    // Type check
    if (type_check.has_value())
        DG_HOST_ASSERT(sf.scalar_type() == type_check.value());

    // Always do shape checks
    const auto sf_dtype = sf.scalar_type();
    DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt);
    DG_HOST_ASSERT(sf.dim() == static_cast<int>(num_groups.has_value()) + 2);
    if (num_groups.has_value())
        DG_HOST_ASSERT(sf.size(-3) == num_groups.value());
    DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn));
    DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4)));

    // TMA stride checks: TMA aligned and MN-major
    if (tma_stride_check) {
        if (num_groups.has_value())
            DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1));
        // Check contiguity in the MN direction
        DG_HOST_ASSERT(sf.stride(-2) == 1 or mn == 1);
        DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()));
    }

    // SM90 SFB must be contiguous, or contiguous after transposing the last two dimensions
    if (sm90_sfb_check) {
        if (num_groups.has_value())
            DG_HOST_ASSERT(sf.stride(-3) == sf.size(-2) * sf.size(-1));
        DG_HOST_ASSERT((sf.stride(-1) == 1 and sf.stride(-2) == sf.size(-1)) or
                       (sf.stride(-1) == sf.size(-2) and sf.stride(-2) == 1));
    }
    return sf;
}

// Value matrix layout
static int get_mk_alignment_for_contiguous_layout() {
    return 128;
}

} // namespace deep_gemm
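
// Usage sketch (illustrative only; `a` and `sfa` below are hypothetical tensors,
// not part of this header). It shows how the helpers above compose when
// validating an FP8 A matrix and its float scaling factors under the default
// SM90 recipe {1, 128, 128} returned by get_default_recipe:
//
//   const auto major_a = get_major_type_ab(a);   // FP8 on SM90 must be K-major
//   if (fp8_requires_k_major())
//       DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
//   auto [m, k] = check_ab_fp8_fp4(a, major_a, device_runtime->get_arch_major());
//   check_sf_layout(sfa, m, k, 1, 128, std::nullopt,
//                   /*tma_stride_check=*/true, /*sm90_sfb_check=*/false, torch::kFloat);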