| | #include <torch/torch.h> |
| |
|
| | #include "apis/attention.hpp" |
| | #include "apis/einsum.hpp" |
| | #include "apis/gemm.hpp" |
| | #include "apis/hyperconnection.hpp" |
| | #include "apis/layout.hpp" |
| |
|
| | #if DG_TENSORMAP_COMPATIBLE |
| | #include "jit/compiler.hpp" |
| | #endif |
| | #include "jit/device_runtime.hpp" |
| |
|
| | |
| |
|
// Convert an optional int64 vector (as received from the Python binding)
// into the std::optional<std::tuple<int, int, int>> form the deep_gemm
// APIs expect.
//
// Returns std::nullopt when no vector was supplied.
// Throws std::invalid_argument when a vector is supplied with fewer than
// three elements (previously this indexed out of bounds: undefined behavior).
static std::optional<std::tuple<int, int, int>>
vec_to_tuple3(const std::optional<std::vector<int64_t>>& v) {
    if (!v.has_value()) return std::nullopt;
    const auto& vec = v.value();
    if (vec.size() < 3)
        throw std::invalid_argument("vec_to_tuple3: expected at least 3 elements");
    return std::make_tuple(static_cast<int>(vec[0]),
                           static_cast<int>(vec[1]),
                           static_cast<int>(vec[2]));
}
| |
|
// Required-argument variant of vec_to_tuple3: converts an int64 vector into
// the std::tuple<int, int, int> the deep_gemm APIs expect.
//
// Throws std::invalid_argument when fewer than three elements are supplied
// (previously this indexed out of bounds: undefined behavior).
static std::tuple<int, int, int>
vec_to_tuple3_req(const std::vector<int64_t>& v) {
    if (v.size() < 3)
        throw std::invalid_argument("vec_to_tuple3_req: expected at least 3 elements");
    return std::make_tuple(static_cast<int>(v[0]),
                           static_cast<int>(v[1]),
                           static_cast<int>(v[2]));
}
| |
|
// Convert an optional int64 vector (from the Python binding) into the
// std::optional<std::tuple<int, int>> form the deep_gemm APIs expect.
//
// Returns std::nullopt when no vector was supplied.
// Throws std::invalid_argument when a vector is supplied with fewer than
// two elements (previously this indexed out of bounds: undefined behavior).
static std::optional<std::tuple<int, int>>
vec_to_tuple2(const std::optional<std::vector<int64_t>>& v) {
    if (!v.has_value()) return std::nullopt;
    const auto& vec = v.value();
    if (vec.size() < 2)
        throw std::invalid_argument("vec_to_tuple2: expected at least 2 elements");
    return std::make_tuple(static_cast<int>(vec[0]),
                           static_cast<int>(vec[1]));
}
| |
|
// Narrow a vector of int64 (binding-side integers) to a vector of int,
// element by element, as required by the deep_gemm kernel APIs.
static std::vector<int>
vec64_to_vec32(const std::vector<int64_t>& v) {
    std::vector<int> narrowed;
    narrowed.reserve(v.size());
    for (const auto value : v)
        narrowed.push_back(static_cast<int>(value));
    return narrowed;
}
| |
|
| | |
| |
|
// One-time binding initialization: forwards the library root, CUDA home and
// CUTLASS include paths to the JIT Compiler and KernelRuntime prepare_init
// hooks. Compiles to a no-op when the build is not tensor-map compatible.
void dg_init(const std::string& library_root_path,
             const std::string& cuda_home_path,
             const std::string& cutlass_include_path) {
#if DG_TENSORMAP_COMPATIBLE
    deep_gemm::Compiler::prepare_init(library_root_path, cuda_home_path, cutlass_include_path);
    deep_gemm::KernelRuntime::prepare_init(cuda_home_path);
#endif
}
| |
|
// Set the SM count used by the global device runtime
// (narrowed from the binding's int64_t to int).
void dg_set_num_sms(int64_t new_num_sms) {
    deep_gemm::device_runtime->set_num_sms(static_cast<int>(new_num_sms));
}
| |
|
// Return the device runtime's current SM count, widened to int64_t
// for the binding layer.
int64_t dg_get_num_sms() {
    return static_cast<int64_t>(deep_gemm::device_runtime->get_num_sms());
}
| |
|
// Set the runtime's "tc_util" value (presumably a tensor-core utilization
// knob — confirm semantics in jit/device_runtime.hpp); narrowed to int.
void dg_set_tc_util(int64_t new_tc_util) {
    deep_gemm::device_runtime->set_tc_util(static_cast<int>(new_tc_util));
}
| |
|
// Return the runtime's current "tc_util" value, widened to int64_t.
int64_t dg_get_tc_util() {
    return static_cast<int64_t>(deep_gemm::device_runtime->get_tc_util());
}
| |
|
| | |
| |
|
// Thin forward to the cuBLASLt GEMM, "nt" variant (the suffix presumably
// encodes the A/B transpose layout — confirm in apis/gemm.hpp).
// d receives the result; c is an optional extra input tensor.
void dg_cublaslt_gemm_nt(const at::Tensor& a, const at::Tensor& b,
                         const at::Tensor& d,
                         const std::optional<at::Tensor>& c) {
    deep_gemm::gemm::cublaslt_gemm_nt(a, b, d, c);
}
| |
|
// Thin forward to the cuBLASLt GEMM, "nn" variant (layout suffix — see
// apis/gemm.hpp). d receives the result; c is an optional extra input.
void dg_cublaslt_gemm_nn(const at::Tensor& a, const at::Tensor& b,
                         const at::Tensor& d,
                         const std::optional<at::Tensor>& c) {
    deep_gemm::gemm::cublaslt_gemm_nn(a, b, d, c);
}
| |
|
// Thin forward to the cuBLASLt GEMM, "tn" variant (layout suffix — see
// apis/gemm.hpp). d receives the result; c is an optional extra input.
void dg_cublaslt_gemm_tn(const at::Tensor& a, const at::Tensor& b,
                         const at::Tensor& d,
                         const std::optional<at::Tensor>& c) {
    deep_gemm::gemm::cublaslt_gemm_tn(a, b, d, c);
}
| |
|
// Thin forward to the cuBLASLt GEMM, "tt" variant (layout suffix — see
// apis/gemm.hpp). d receives the result; c is an optional extra input.
void dg_cublaslt_gemm_tt(const at::Tensor& a, const at::Tensor& b,
                         const at::Tensor& d,
                         const std::optional<at::Tensor>& c) {
    deep_gemm::gemm::cublaslt_gemm_tt(a, b, d, c);
}
| |
|
| | |
| |
|
| | #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE |
| |
|
// Binding shim for the FP8/FP4 GEMM, "nt" layout variant.
// Pairs each operand with its scale-factor tensor ({a, sfa}, {b, sfb}) and
// narrows the int64 recipe vectors to the int tuples the kernel API expects;
// compiled_dims and disable_ue8m0_cast are forwarded verbatim.
void dg_fp8_fp4_gemm_nt(const at::Tensor& a, const at::Tensor& sfa,
                        const at::Tensor& b, const at::Tensor& sfb,
                        const at::Tensor& d,
                        const std::optional<at::Tensor>& c,
                        const std::optional<std::vector<int64_t>>& recipe,
                        const std::optional<std::vector<int64_t>>& recipe_a,
                        const std::optional<std::vector<int64_t>>& recipe_b,
                        const std::string& compiled_dims,
                        bool disable_ue8m0_cast) {
    deep_gemm::gemm::fp8_fp4_gemm_nt(
        {a, sfa}, {b, sfb}, d, c,
        vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
        compiled_dims, disable_ue8m0_cast);
}
| |
|
// Binding shim for the FP8/FP4 GEMM, "nn" layout variant.
// Same conversion pattern as dg_fp8_fp4_gemm_nt: operand/scale pairs plus
// recipe-vector-to-tuple narrowing.
void dg_fp8_fp4_gemm_nn(const at::Tensor& a, const at::Tensor& sfa,
                        const at::Tensor& b, const at::Tensor& sfb,
                        const at::Tensor& d,
                        const std::optional<at::Tensor>& c,
                        const std::optional<std::vector<int64_t>>& recipe,
                        const std::optional<std::vector<int64_t>>& recipe_a,
                        const std::optional<std::vector<int64_t>>& recipe_b,
                        const std::string& compiled_dims,
                        bool disable_ue8m0_cast) {
    deep_gemm::gemm::fp8_fp4_gemm_nn(
        {a, sfa}, {b, sfb}, d, c,
        vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
        compiled_dims, disable_ue8m0_cast);
}
| |
|
// Binding shim for the FP8/FP4 GEMM, "tn" layout variant.
// Same conversion pattern as dg_fp8_fp4_gemm_nt: operand/scale pairs plus
// recipe-vector-to-tuple narrowing.
void dg_fp8_fp4_gemm_tn(const at::Tensor& a, const at::Tensor& sfa,
                        const at::Tensor& b, const at::Tensor& sfb,
                        const at::Tensor& d,
                        const std::optional<at::Tensor>& c,
                        const std::optional<std::vector<int64_t>>& recipe,
                        const std::optional<std::vector<int64_t>>& recipe_a,
                        const std::optional<std::vector<int64_t>>& recipe_b,
                        const std::string& compiled_dims,
                        bool disable_ue8m0_cast) {
    deep_gemm::gemm::fp8_fp4_gemm_tn(
        {a, sfa}, {b, sfb}, d, c,
        vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
        compiled_dims, disable_ue8m0_cast);
}
| |
|
// Binding shim for the FP8/FP4 GEMM, "tt" layout variant.
// Same conversion pattern as dg_fp8_fp4_gemm_nt: operand/scale pairs plus
// recipe-vector-to-tuple narrowing.
void dg_fp8_fp4_gemm_tt(const at::Tensor& a, const at::Tensor& sfa,
                        const at::Tensor& b, const at::Tensor& sfb,
                        const at::Tensor& d,
                        const std::optional<at::Tensor>& c,
                        const std::optional<std::vector<int64_t>>& recipe,
                        const std::optional<std::vector<int64_t>>& recipe_a,
                        const std::optional<std::vector<int64_t>>& recipe_b,
                        const std::string& compiled_dims,
                        bool disable_ue8m0_cast) {
    deep_gemm::gemm::fp8_fp4_gemm_tt(
        {a, sfa}, {b, sfb}, d, c,
        vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
        compiled_dims, disable_ue8m0_cast);
}
| |
|
| | void dg_m_grouped_fp8_fp4_gemm_nt_contiguous( |
| | const at::Tensor& a, const at::Tensor& sfa, |
| | const at::Tensor& b, const at::Tensor& sfb, |
| | const at::Tensor& d, const at::Tensor& grouped_layout, |
| | const std::optional<std::vector<int64_t>>& recipe, |
| | const std::optional<std::vector<int64_t>>& recipe_a, |
| | const std::optional<std::vector<int64_t>>& recipe_b, |
| | const std::string& compiled_dims, |
| | bool disable_ue8m0_cast, |
| | bool use_psum_layout, |
| | const std::optional<int64_t>& expected_m_for_psum_layout) { |
| | std::optional<int> expected_m; |
| | if (expected_m_for_psum_layout.has_value()) |
| | expected_m = static_cast<int>(expected_m_for_psum_layout.value()); |
| | deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_contiguous( |
| | {a, sfa}, {b, sfb}, d, grouped_layout, |
| | vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b), |
| | compiled_dims, disable_ue8m0_cast, use_psum_layout, expected_m); |
| | } |
| |
|
// Binding shim for the m-grouped "contiguous" FP8/FP4 GEMM ("nn" layout).
// Pairs operands with their scale factors and narrows the int64 recipe
// vectors to int tuples; grouped_layout and the flags pass through verbatim.
void dg_m_grouped_fp8_fp4_gemm_nn_contiguous(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d, const at::Tensor& grouped_layout,
    const std::optional<std::vector<int64_t>>& recipe,
    const std::optional<std::vector<int64_t>>& recipe_a,
    const std::optional<std::vector<int64_t>>& recipe_b,
    const std::string& compiled_dims,
    bool disable_ue8m0_cast,
    bool use_psum_layout) {
    deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nn_contiguous(
        {a, sfa}, {b, sfb}, d, grouped_layout,
        vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
        compiled_dims, disable_ue8m0_cast, use_psum_layout);
}
| |
|
// Binding shim for the m-grouped "masked" FP8/FP4 GEMM ("nt" layout).
// masked_m and expected_m control the per-group masking (see apis/gemm.hpp);
// expected_m is narrowed from int64_t to int, recipes to int tuples.
void dg_m_grouped_fp8_fp4_gemm_nt_masked(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d, const at::Tensor& masked_m,
    int64_t expected_m,
    const std::optional<std::vector<int64_t>>& recipe,
    const std::optional<std::vector<int64_t>>& recipe_a,
    const std::optional<std::vector<int64_t>>& recipe_b,
    const std::string& compiled_dims,
    bool disable_ue8m0_cast) {
    deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_masked(
        {a, sfa}, {b, sfb}, d, masked_m, static_cast<int>(expected_m),
        vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
        compiled_dims, disable_ue8m0_cast);
}
| |
|
// Binding shim for the k-grouped FP8 GEMM ("nt" layout).
// ks is passed both as a host-side vector (narrowed to int) and as a tensor
// (ks_tensor); the recipe here is required, not optional.
void dg_k_grouped_fp8_gemm_nt_contiguous(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d,
    const std::vector<int64_t>& ks,
    const at::Tensor& ks_tensor,
    const std::optional<at::Tensor>& c,
    const std::vector<int64_t>& recipe,
    const std::string& compiled_dims) {
    deep_gemm::gemm::k_grouped_fp8_gemm_nt_contiguous(
        {a, sfa}, {b, sfb}, d, vec64_to_vec32(ks), ks_tensor, c,
        vec_to_tuple3_req(recipe), compiled_dims);
}
| |
|
// Binding shim for the k-grouped FP8 GEMM ("tn" layout).
// Mirrors dg_k_grouped_fp8_gemm_nt_contiguous: host-side ks narrowed to int,
// required recipe converted to an int tuple.
void dg_k_grouped_fp8_gemm_tn_contiguous(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d,
    const std::vector<int64_t>& ks,
    const at::Tensor& ks_tensor,
    const std::optional<at::Tensor>& c,
    const std::vector<int64_t>& recipe,
    const std::string& compiled_dims) {
    deep_gemm::gemm::k_grouped_fp8_gemm_tn_contiguous(
        {a, sfa}, {b, sfb}, d, vec64_to_vec32(ks), ks_tensor, c,
        vec_to_tuple3_req(recipe), compiled_dims);
}
| |
|
| | #endif |
| |
|
| | |
| |
|
| | #if DG_TENSORMAP_COMPATIBLE |
| |
|
// Thin forward to the BF16 GEMM, "nt" layout variant.
// d receives the result; c is an optional extra input tensor.
void dg_bf16_gemm_nt(const at::Tensor& a, const at::Tensor& b,
                     const at::Tensor& d,
                     const std::optional<at::Tensor>& c,
                     const std::string& compiled_dims) {
    deep_gemm::gemm::bf16_gemm_nt(a, b, d, c, compiled_dims);
}
| |
|
// Thin forward to the BF16 GEMM, "nn" layout variant.
void dg_bf16_gemm_nn(const at::Tensor& a, const at::Tensor& b,
                     const at::Tensor& d,
                     const std::optional<at::Tensor>& c,
                     const std::string& compiled_dims) {
    deep_gemm::gemm::bf16_gemm_nn(a, b, d, c, compiled_dims);
}
| |
|
// Thin forward to the BF16 GEMM, "tn" layout variant.
void dg_bf16_gemm_tn(const at::Tensor& a, const at::Tensor& b,
                     const at::Tensor& d,
                     const std::optional<at::Tensor>& c,
                     const std::string& compiled_dims) {
    deep_gemm::gemm::bf16_gemm_tn(a, b, d, c, compiled_dims);
}
| |
|
// Thin forward to the BF16 GEMM, "tt" layout variant.
void dg_bf16_gemm_tt(const at::Tensor& a, const at::Tensor& b,
                     const at::Tensor& d,
                     const std::optional<at::Tensor>& c,
                     const std::string& compiled_dims) {
    deep_gemm::gemm::bf16_gemm_tt(a, b, d, c, compiled_dims);
}
| |
|
| | void dg_m_grouped_bf16_gemm_nt_contiguous( |
| | const at::Tensor& a, const at::Tensor& b, |
| | const at::Tensor& d, const at::Tensor& grouped_layout, |
| | const std::string& compiled_dims, |
| | bool use_psum_layout, |
| | const std::optional<int64_t>& expected_m_for_psum_layout) { |
| | std::optional<int> expected_m; |
| | if (expected_m_for_psum_layout.has_value()) |
| | expected_m = static_cast<int>(expected_m_for_psum_layout.value()); |
| | deep_gemm::gemm::m_grouped_bf16_gemm_nt_contiguous( |
| | a, b, d, grouped_layout, compiled_dims, use_psum_layout, expected_m); |
| | } |
| |
|
// Binding shim for the m-grouped "contiguous" BF16 GEMM ("nn" layout);
// all arguments pass through verbatim.
void dg_m_grouped_bf16_gemm_nn_contiguous(
    const at::Tensor& a, const at::Tensor& b,
    const at::Tensor& d, const at::Tensor& grouped_layout,
    const std::string& compiled_dims,
    bool use_psum_layout) {
    deep_gemm::gemm::m_grouped_bf16_gemm_nn_contiguous(
        a, b, d, grouped_layout, compiled_dims, use_psum_layout);
}
| |
|
// Binding shim for the m-grouped "masked" BF16 GEMM ("nt" layout);
// expected_m is narrowed from the binding's int64_t to int.
void dg_m_grouped_bf16_gemm_nt_masked(
    const at::Tensor& a, const at::Tensor& b,
    const at::Tensor& d, const at::Tensor& masked_m,
    int64_t expected_m,
    const std::string& compiled_dims) {
    deep_gemm::gemm::m_grouped_bf16_gemm_nt_masked(
        a, b, d, masked_m, static_cast<int>(expected_m), compiled_dims);
}
| |
|
// Binding shim for the k-grouped BF16 GEMM ("tn" layout).
// ks is passed both as a host-side vector (narrowed to int) and as a tensor.
void dg_k_grouped_bf16_gemm_tn_contiguous(
    const at::Tensor& a, const at::Tensor& b,
    const at::Tensor& d,
    const std::vector<int64_t>& ks,
    const at::Tensor& ks_tensor,
    const std::optional<at::Tensor>& c,
    const std::string& compiled_dims) {
    deep_gemm::gemm::k_grouped_bf16_gemm_tn_contiguous(
        a, b, d, vec64_to_vec32(ks), ks_tensor, c, compiled_dims);
}
| |
|
| | #endif |
| |
|
| | |
| |
|
| | #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE |
| |
|
// Thin forward to the einsum API: evaluates the expression `expr` over a and
// b into d; c is an optional extra input and use_cublaslt selects the
// cuBLASLt backend (see apis/einsum.hpp).
void dg_einsum(const std::string& expr,
               const at::Tensor& a, const at::Tensor& b,
               const at::Tensor& d,
               const std::optional<at::Tensor>& c,
               bool use_cublaslt) {
    deep_gemm::einsum::einsum(expr, a, b, d, c, use_cublaslt);
}
| |
|
// Binding shim for the FP8 einsum: pairs each operand with its scale factor
// and converts the required recipe vector to an int tuple.
void dg_fp8_einsum(const std::string& expr,
                   const at::Tensor& a, const at::Tensor& sfa,
                   const at::Tensor& b, const at::Tensor& sfb,
                   const at::Tensor& d,
                   const std::optional<at::Tensor>& c,
                   const std::vector<int64_t>& recipe) {
    deep_gemm::einsum::fp8_einsum(
        expr, {a, sfa}, {b, sfb}, d, c, vec_to_tuple3_req(recipe));
}
| |
|
| | #endif |
| |
|
| | |
| |
|
| | #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE |
| |
|
// Binding shim for the attention FP8 GEMM that skips the middle head range;
// head_splits (exactly three values consumed) becomes an int tuple, and the
// optional recipe vector is narrowed likewise.
void dg_fp8_gemm_nt_skip_head_mid(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d,
    const std::vector<int64_t>& head_splits,
    const std::optional<std::vector<int64_t>>& recipe,
    const std::string& compiled_dims,
    bool disable_ue8m0_cast) {
    deep_gemm::attention::fp8_gemm_nt_skip_head_mid(
        {a, sfa}, {b, sfb}, d, vec_to_tuple3_req(head_splits),
        vec_to_tuple3(recipe), compiled_dims, disable_ue8m0_cast);
}
| |
|
// Binding shim for the FP8 MQA logits kernel: pairs kv with its scale factor
// and narrows max_seqlen_k to int; returns the logits tensor produced by
// deep_gemm::attention::fp8_mqa_logits.
at::Tensor dg_fp8_mqa_logits(
    const at::Tensor& q,
    const at::Tensor& kv, const at::Tensor& kv_sf,
    const at::Tensor& weights,
    const at::Tensor& cu_seq_len_k_start,
    const at::Tensor& cu_seq_len_k_end,
    bool clean_logits,
    int64_t max_seqlen_k) {
    return deep_gemm::attention::fp8_mqa_logits(
        q, {kv, kv_sf}, weights,
        cu_seq_len_k_start, cu_seq_len_k_end,
        clean_logits, static_cast<int>(max_seqlen_k));
}
| |
|
// Binding shim: computes scheduling metadata for the paged MQA logits kernel
// from per-sequence context lengths; block_kv and num_sms are narrowed to int.
at::Tensor dg_get_paged_mqa_logits_metadata(
    const at::Tensor& context_lens,
    int64_t block_kv, int64_t num_sms) {
    return deep_gemm::attention::get_paged_mqa_logits_metadata(
        context_lens, static_cast<int>(block_kv), static_cast<int>(num_sms));
}
| |
|
// Binding shim for the paged FP8 MQA logits kernel; max_context_len is
// narrowed to int and all tensors (KV cache, block table, schedule metadata
// from dg_get_paged_mqa_logits_metadata) pass through verbatim.
at::Tensor dg_fp8_paged_mqa_logits(
    const at::Tensor& q,
    const at::Tensor& fused_kv_cache,
    const at::Tensor& weights,
    const at::Tensor& context_lens,
    const at::Tensor& block_table,
    const at::Tensor& schedule_meta,
    int64_t max_context_len,
    bool clean_logits) {
    return deep_gemm::attention::fp8_paged_mqa_logits(
        q, fused_kv_cache, weights, context_lens,
        block_table, schedule_meta,
        static_cast<int>(max_context_len), clean_logits);
}
| |
|
| | #endif |
| |
|
| | |
| |
|
| | #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE |
| |
|
| | void dg_tf32_hc_prenorm_gemm( |
| | const at::Tensor& a, const at::Tensor& b, |
| | const at::Tensor& d, const at::Tensor& sqr_sum, |
| | const std::optional<int64_t>& num_splits) { |
| | std::optional<int> ns; |
| | if (num_splits.has_value()) |
| | ns = static_cast<int>(num_splits.value()); |
| | deep_gemm::hyperconnection::tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns); |
| | } |
| |
|
| | #endif |
| |
|
| | |
| |
|
| | #if DG_TENSORMAP_COMPATIBLE |
| |
|
| | at::Tensor dg_transform_sf_into_required_layout( |
| | const at::Tensor& sf, |
| | int64_t mn, int64_t k, |
| | const std::optional<std::vector<int64_t>>& recipe, |
| | const std::optional<std::vector<int64_t>>& recipe_ab, |
| | const std::optional<int64_t>& num_groups, |
| | bool is_sfa, bool disable_ue8m0_cast) { |
| | std::optional<int> ng; |
| | if (num_groups.has_value()) |
| | ng = static_cast<int>(num_groups.value()); |
| | return deep_gemm::layout::transform_sf_into_required_layout( |
| | sf, static_cast<int>(mn), static_cast<int>(k), |
| | vec_to_tuple3(recipe), vec_to_tuple2(recipe_ab), |
| | ng, is_sfa, disable_ue8m0_cast); |
| | } |
| |
|
// Binding shim around deep_gemm::get_tma_aligned_size: narrows x and
// element_size to int for the call and widens the result back to int64_t.
int64_t dg_get_tma_aligned_size(int64_t x, int64_t element_size) {
    return static_cast<int64_t>(
        deep_gemm::get_tma_aligned_size(
            static_cast<int>(x), static_cast<int>(element_size)));
}
| |
|
// Thin forward: returns the MN-major, TMA-aligned version of the
// scale-factor tensor (see apis/layout.hpp).
at::Tensor dg_get_mn_major_tma_aligned_tensor(const at::Tensor& sf) {
    return deep_gemm::get_mn_major_tma_aligned_tensor(sf);
}
| |
|
// Thin forward: returns the MN-major, TMA-aligned, UE8M0-packed version of
// the scale-factor tensor (see apis/layout.hpp).
at::Tensor dg_get_mn_major_tma_aligned_packed_ue8m0_tensor(const at::Tensor& sf) {
    return deep_gemm::get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
}
| |
|
// K-grouped variant of the UE8M0 packing helper; the host-side ks vector is
// narrowed to int while ks_tensor passes through verbatim.
at::Tensor dg_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
    const at::Tensor& sf,
    const at::Tensor& ks_tensor,
    const std::vector<int64_t>& ks) {
    return deep_gemm::get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
        sf, ks_tensor, vec64_to_vec32(ks));
}
| |
|
| | #endif |
| |
|
// Thin forward: the MK alignment required by the contiguous grouped layout,
// widened to int64_t for the binding layer.
int64_t dg_get_mk_alignment_for_contiguous_layout() {
    return static_cast<int64_t>(deep_gemm::get_mk_alignment_for_contiguous_layout());
}
| |
|