#include #include "apis/attention.hpp" #include "apis/einsum.hpp" #include "apis/gemm.hpp" #include "apis/hyperconnection.hpp" #include "apis/layout.hpp" #if DG_TENSORMAP_COMPATIBLE #include "jit/compiler.hpp" #endif #include "jit/device_runtime.hpp" // ---- Type conversion helpers ---- static std::optional> vec_to_tuple3(const std::optional>& v) { if (!v.has_value()) return std::nullopt; const auto& vec = v.value(); return std::make_tuple(static_cast(vec[0]), static_cast(vec[1]), static_cast(vec[2])); } static std::tuple vec_to_tuple3_req(const std::vector& v) { return std::make_tuple(static_cast(v[0]), static_cast(v[1]), static_cast(v[2])); } static std::optional> vec_to_tuple2(const std::optional>& v) { if (!v.has_value()) return std::nullopt; const auto& vec = v.value(); return std::make_tuple(static_cast(vec[0]), static_cast(vec[1])); } static std::vector vec64_to_vec32(const std::vector& v) { return std::vector(v.begin(), v.end()); } // ---- Runtime APIs ---- void dg_init(const std::string& library_root_path, const std::string& cuda_home_path, const std::string& cutlass_include_path) { #if DG_TENSORMAP_COMPATIBLE deep_gemm::Compiler::prepare_init(library_root_path, cuda_home_path, cutlass_include_path); deep_gemm::KernelRuntime::prepare_init(cuda_home_path); #endif } void dg_set_num_sms(int64_t new_num_sms) { deep_gemm::device_runtime->set_num_sms(static_cast(new_num_sms)); } int64_t dg_get_num_sms() { return static_cast(deep_gemm::device_runtime->get_num_sms()); } void dg_set_tc_util(int64_t new_tc_util) { deep_gemm::device_runtime->set_tc_util(static_cast(new_tc_util)); } int64_t dg_get_tc_util() { return static_cast(deep_gemm::device_runtime->get_tc_util()); } // ---- cuBLASLt GEMMs ---- void dg_cublaslt_gemm_nt(const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const std::optional& c) { deep_gemm::gemm::cublaslt_gemm_nt(a, b, d, c); } void dg_cublaslt_gemm_nn(const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const std::optional& c) { deep_gemm::gemm::cublaslt_gemm_nn(a, b, d, c); } void dg_cublaslt_gemm_tn(const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const std::optional& c) { deep_gemm::gemm::cublaslt_gemm_tn(a, b, d, c); } void dg_cublaslt_gemm_tt(const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const std::optional& c) { deep_gemm::gemm::cublaslt_gemm_tt(a, b, d, c); } // ---- FP8/FP4 GEMMs ---- #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE void dg_fp8_fp4_gemm_nt(const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const std::optional& c, const std::optional>& recipe, const std::optional>& recipe_a, const std::optional>& recipe_b, const std::string& compiled_dims, bool disable_ue8m0_cast) { deep_gemm::gemm::fp8_fp4_gemm_nt( {a, sfa}, {b, sfb}, d, c, vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b), compiled_dims, disable_ue8m0_cast); } void dg_fp8_fp4_gemm_nn(const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const std::optional& c, const std::optional>& recipe, const std::optional>& recipe_a, const std::optional>& recipe_b, const std::string& compiled_dims, bool disable_ue8m0_cast) { deep_gemm::gemm::fp8_fp4_gemm_nn( {a, sfa}, {b, sfb}, d, c, vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b), compiled_dims, disable_ue8m0_cast); } void dg_fp8_fp4_gemm_tn(const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const std::optional& c, const std::optional>& recipe, const std::optional>& recipe_a, const std::optional>& recipe_b, const std::string& compiled_dims, bool disable_ue8m0_cast) { deep_gemm::gemm::fp8_fp4_gemm_tn( {a, sfa}, {b, sfb}, d, c, vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b), compiled_dims, disable_ue8m0_cast); } void dg_fp8_fp4_gemm_tt(const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const std::optional& c, const std::optional>& recipe, const std::optional>& recipe_a, const std::optional>& recipe_b, const std::string& compiled_dims, bool disable_ue8m0_cast) { deep_gemm::gemm::fp8_fp4_gemm_tt( {a, sfa}, {b, sfb}, d, c, vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b), compiled_dims, disable_ue8m0_cast); } void dg_m_grouped_fp8_fp4_gemm_nt_contiguous( const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const at::Tensor& grouped_layout, const std::optional>& recipe, const std::optional>& recipe_a, const std::optional>& recipe_b, const std::string& compiled_dims, bool disable_ue8m0_cast, bool use_psum_layout, const std::optional& expected_m_for_psum_layout) { std::optional expected_m; if (expected_m_for_psum_layout.has_value()) expected_m = static_cast(expected_m_for_psum_layout.value()); deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_contiguous( {a, sfa}, {b, sfb}, d, grouped_layout, vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b), compiled_dims, disable_ue8m0_cast, use_psum_layout, expected_m); } void dg_m_grouped_fp8_fp4_gemm_nn_contiguous( const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const at::Tensor& grouped_layout, const std::optional>& recipe, const std::optional>& recipe_a, const std::optional>& recipe_b, const std::string& compiled_dims, bool disable_ue8m0_cast, bool use_psum_layout) { deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nn_contiguous( {a, sfa}, {b, sfb}, d, grouped_layout, vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b), compiled_dims, disable_ue8m0_cast, use_psum_layout); } void dg_m_grouped_fp8_fp4_gemm_nt_masked( const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const at::Tensor& masked_m, int64_t expected_m, const std::optional>& recipe, const std::optional>& recipe_a, const std::optional>& recipe_b, const std::string& compiled_dims, bool disable_ue8m0_cast) { deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_masked( {a, sfa}, {b, sfb}, d, masked_m, static_cast(expected_m), vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b), compiled_dims, disable_ue8m0_cast); } void dg_k_grouped_fp8_gemm_nt_contiguous( const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const std::vector& ks, const at::Tensor& ks_tensor, const std::optional& c, const std::vector& recipe, const std::string& compiled_dims) { deep_gemm::gemm::k_grouped_fp8_gemm_nt_contiguous( {a, sfa}, {b, sfb}, d, vec64_to_vec32(ks), ks_tensor, c, vec_to_tuple3_req(recipe), compiled_dims); } void dg_k_grouped_fp8_gemm_tn_contiguous( const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const std::vector& ks, const at::Tensor& ks_tensor, const std::optional& c, const std::vector& recipe, const std::string& compiled_dims) { deep_gemm::gemm::k_grouped_fp8_gemm_tn_contiguous( {a, sfa}, {b, sfb}, d, vec64_to_vec32(ks), ks_tensor, c, vec_to_tuple3_req(recipe), compiled_dims); } #endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE // ---- BF16 GEMMs ---- #if DG_TENSORMAP_COMPATIBLE void dg_bf16_gemm_nt(const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const std::optional& c, const std::string& compiled_dims) { deep_gemm::gemm::bf16_gemm_nt(a, b, d, c, compiled_dims); } void dg_bf16_gemm_nn(const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const std::optional& c, const std::string& compiled_dims) { deep_gemm::gemm::bf16_gemm_nn(a, b, d, c, compiled_dims); } void dg_bf16_gemm_tn(const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const std::optional& c, const std::string& compiled_dims) { deep_gemm::gemm::bf16_gemm_tn(a, b, d, c, compiled_dims); } void dg_bf16_gemm_tt(const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const std::optional& c, const std::string& compiled_dims) { deep_gemm::gemm::bf16_gemm_tt(a, b, d, c, compiled_dims); } void dg_m_grouped_bf16_gemm_nt_contiguous( const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const at::Tensor& grouped_layout, const std::string& compiled_dims, bool use_psum_layout, const std::optional& expected_m_for_psum_layout) { std::optional expected_m; if (expected_m_for_psum_layout.has_value()) expected_m = static_cast(expected_m_for_psum_layout.value()); deep_gemm::gemm::m_grouped_bf16_gemm_nt_contiguous( a, b, d, grouped_layout, compiled_dims, use_psum_layout, expected_m); } void dg_m_grouped_bf16_gemm_nn_contiguous( const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const at::Tensor& grouped_layout, const std::string& compiled_dims, bool use_psum_layout) { deep_gemm::gemm::m_grouped_bf16_gemm_nn_contiguous( a, b, d, grouped_layout, compiled_dims, use_psum_layout); } void dg_m_grouped_bf16_gemm_nt_masked( const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const at::Tensor& masked_m, int64_t expected_m, const std::string& compiled_dims) { deep_gemm::gemm::m_grouped_bf16_gemm_nt_masked( a, b, d, masked_m, static_cast(expected_m), compiled_dims); } void dg_k_grouped_bf16_gemm_tn_contiguous( const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const std::vector& ks, const at::Tensor& ks_tensor, const std::optional& c, const std::string& compiled_dims) { deep_gemm::gemm::k_grouped_bf16_gemm_tn_contiguous( a, b, d, vec64_to_vec32(ks), ks_tensor, c, compiled_dims); } #endif // DG_TENSORMAP_COMPATIBLE // ---- Einsum ---- #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE void dg_einsum(const std::string& expr, const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const std::optional& c, bool use_cublaslt) { deep_gemm::einsum::einsum(expr, a, b, d, c, use_cublaslt); } void dg_fp8_einsum(const std::string& expr, const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const std::optional& c, const std::vector& recipe) { deep_gemm::einsum::fp8_einsum( expr, {a, sfa}, {b, sfb}, d, c, vec_to_tuple3_req(recipe)); } #endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE // ---- Attention ---- #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE void dg_fp8_gemm_nt_skip_head_mid( const at::Tensor& a, const at::Tensor& sfa, const at::Tensor& b, const at::Tensor& sfb, const at::Tensor& d, const std::vector& head_splits, const std::optional>& recipe, const std::string& compiled_dims, bool disable_ue8m0_cast) { deep_gemm::attention::fp8_gemm_nt_skip_head_mid( {a, sfa}, {b, sfb}, d, vec_to_tuple3_req(head_splits), vec_to_tuple3(recipe), compiled_dims, disable_ue8m0_cast); } at::Tensor dg_fp8_mqa_logits( const at::Tensor& q, const at::Tensor& kv, const at::Tensor& kv_sf, const at::Tensor& weights, const at::Tensor& cu_seq_len_k_start, const at::Tensor& cu_seq_len_k_end, bool clean_logits, int64_t max_seqlen_k) { return deep_gemm::attention::fp8_mqa_logits( q, {kv, kv_sf}, weights, cu_seq_len_k_start, cu_seq_len_k_end, clean_logits, static_cast(max_seqlen_k)); } at::Tensor dg_get_paged_mqa_logits_metadata( const at::Tensor& context_lens, int64_t block_kv, int64_t num_sms) { return deep_gemm::attention::get_paged_mqa_logits_metadata( context_lens, static_cast(block_kv), static_cast(num_sms)); } at::Tensor dg_fp8_paged_mqa_logits( const at::Tensor& q, const at::Tensor& fused_kv_cache, const at::Tensor& weights, const at::Tensor& context_lens, const at::Tensor& block_table, const at::Tensor& schedule_meta, int64_t max_context_len, bool clean_logits) { return deep_gemm::attention::fp8_paged_mqa_logits( q, fused_kv_cache, weights, context_lens, block_table, schedule_meta, static_cast(max_context_len), clean_logits); } #endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE // ---- Hyperconnection ---- #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE void dg_tf32_hc_prenorm_gemm( const at::Tensor& a, const at::Tensor& b, const at::Tensor& d, const at::Tensor& sqr_sum, const std::optional& num_splits) { std::optional ns; if (num_splits.has_value()) ns = static_cast(num_splits.value()); deep_gemm::hyperconnection::tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns); } #endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE // ---- Layout ---- #if DG_TENSORMAP_COMPATIBLE at::Tensor dg_transform_sf_into_required_layout( const at::Tensor& sf, int64_t mn, int64_t k, const std::optional>& recipe, const std::optional>& recipe_ab, const std::optional& num_groups, bool is_sfa, bool disable_ue8m0_cast) { std::optional ng; if (num_groups.has_value()) ng = static_cast(num_groups.value()); return deep_gemm::layout::transform_sf_into_required_layout( sf, static_cast(mn), static_cast(k), vec_to_tuple3(recipe), vec_to_tuple2(recipe_ab), ng, is_sfa, disable_ue8m0_cast); } int64_t dg_get_tma_aligned_size(int64_t x, int64_t element_size) { return static_cast( deep_gemm::get_tma_aligned_size( static_cast(x), static_cast(element_size))); } at::Tensor dg_get_mn_major_tma_aligned_tensor(const at::Tensor& sf) { return deep_gemm::get_mn_major_tma_aligned_tensor(sf); } at::Tensor dg_get_mn_major_tma_aligned_packed_ue8m0_tensor(const at::Tensor& sf) { return deep_gemm::get_mn_major_tma_aligned_packed_ue8m0_tensor(sf); } at::Tensor dg_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( const at::Tensor& sf, const at::Tensor& ks_tensor, const std::vector& ks) { return deep_gemm::get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( sf, ks_tensor, vec64_to_vec32(ks)); } #endif // DG_TENSORMAP_COMPATIBLE int64_t dg_get_mk_alignment_for_contiguous_layout() { return static_cast(deep_gemm::get_mk_alignment_for_contiguous_layout()); }