// deep-gemm/csrc/deep_gemm_impl.cpp
// Torch-facing wrappers around the DeepGEMM core APIs.
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include <torch/torch.h>

#include "apis/attention.hpp"
#include "apis/einsum.hpp"
#include "apis/gemm.hpp"
#include "apis/hyperconnection.hpp"
#include "apis/layout.hpp"
#if DG_TENSORMAP_COMPATIBLE
#include "jit/compiler.hpp"
#endif
#include "jit/device_runtime.hpp"
// ---- Type conversion helpers ----
// Converts an optional vector of int64 values to an optional (int, int, int)
// tuple, narrowing each element.  Returns std::nullopt when no vector is
// supplied.  Throws std::invalid_argument if a vector is present but holds
// fewer than three elements (previously an out-of-bounds read / UB).
static std::optional<std::tuple<int, int, int>>
vec_to_tuple3(const std::optional<std::vector<int64_t>>& v) {
    if (!v.has_value())
        return std::nullopt;
    const auto& vec = v.value();
    if (vec.size() < 3)
        throw std::invalid_argument("vec_to_tuple3: expected at least 3 elements");
    return std::make_tuple(static_cast<int>(vec[0]),
                           static_cast<int>(vec[1]),
                           static_cast<int>(vec[2]));
}
// Converts a required vector of int64 values to an (int, int, int) tuple,
// narrowing each element.  Throws std::invalid_argument if the vector holds
// fewer than three elements (previously an out-of-bounds read / UB).
static std::tuple<int, int, int>
vec_to_tuple3_req(const std::vector<int64_t>& v) {
    if (v.size() < 3)
        throw std::invalid_argument("vec_to_tuple3_req: expected at least 3 elements");
    return std::make_tuple(static_cast<int>(v[0]),
                           static_cast<int>(v[1]),
                           static_cast<int>(v[2]));
}
// Converts an optional vector of int64 values to an optional (int, int)
// tuple, narrowing each element.  Returns std::nullopt when no vector is
// supplied.  Throws std::invalid_argument if a vector is present but holds
// fewer than two elements (previously an out-of-bounds read / UB).
static std::optional<std::tuple<int, int>>
vec_to_tuple2(const std::optional<std::vector<int64_t>>& v) {
    if (!v.has_value())
        return std::nullopt;
    const auto& vec = v.value();
    if (vec.size() < 2)
        throw std::invalid_argument("vec_to_tuple2: expected at least 2 elements");
    return std::make_tuple(static_cast<int>(vec[0]),
                           static_cast<int>(vec[1]));
}
// Narrows a vector of int64 values to a vector of ints, element by element.
// The cast is explicit so the (intentional) narrowing is visible; a single
// reserve avoids reallocation during the copy.
static std::vector<int>
vec64_to_vec32(const std::vector<int64_t>& v) {
    std::vector<int> out;
    out.reserve(v.size());
    for (const int64_t x : v)
        out.push_back(static_cast<int>(x));
    return out;
}
// ---- Runtime APIs ----
// Prepares the JIT compiler and kernel runtime with the given library root,
// CUDA home, and CUTLASS include paths.  Compiled to a no-op when built
// without tensor-map support.
void dg_init(const std::string& library_root_path,
const std::string& cuda_home_path,
const std::string& cutlass_include_path) {
#if DG_TENSORMAP_COMPATIBLE
deep_gemm::Compiler::prepare_init(library_root_path, cuda_home_path, cutlass_include_path);
deep_gemm::KernelRuntime::prepare_init(cuda_home_path);
#endif
}
// Sets the SM count the device runtime uses, narrowing int64 -> int.
void dg_set_num_sms(int64_t new_num_sms) {
deep_gemm::device_runtime->set_num_sms(static_cast<int>(new_num_sms));
}
// Returns the device runtime's current SM count, widened back to int64.
int64_t dg_get_num_sms() {
return static_cast<int64_t>(deep_gemm::device_runtime->get_num_sms());
}
// Sets the device runtime's tc_util value (presumably tensor-core
// utilization -- see the device runtime for exact semantics).
void dg_set_tc_util(int64_t new_tc_util) {
deep_gemm::device_runtime->set_tc_util(static_cast<int>(new_tc_util));
}
// Returns the device runtime's current tc_util value.
int64_t dg_get_tc_util() {
return static_cast<int64_t>(deep_gemm::device_runtime->get_tc_util());
}
// ---- cuBLASLt GEMMs ----
// cuBLASLt-backed GEMM entry points.  Each function is a thin forwarder to
// the matching deep_gemm::gemm::cublaslt_gemm_* overload; the suffix
// (nt/nn/tn/tt) presumably selects the A/B transpose combination -- see the
// core gemm API for the exact layout contract.
//   a, b : input operands
//   d    : output tensor written by the core implementation
//   c    : optional extra input tensor, forwarded unchanged
void dg_cublaslt_gemm_nt(const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d,
const std::optional<at::Tensor>& c) {
deep_gemm::gemm::cublaslt_gemm_nt(a, b, d, c);
}
// NN variant -- same contract as above.
void dg_cublaslt_gemm_nn(const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d,
const std::optional<at::Tensor>& c) {
deep_gemm::gemm::cublaslt_gemm_nn(a, b, d, c);
}
// TN variant -- same contract as above.
void dg_cublaslt_gemm_tn(const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d,
const std::optional<at::Tensor>& c) {
deep_gemm::gemm::cublaslt_gemm_tn(a, b, d, c);
}
// TT variant -- same contract as above.
void dg_cublaslt_gemm_tt(const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d,
const std::optional<at::Tensor>& c) {
deep_gemm::gemm::cublaslt_gemm_tt(a, b, d, c);
}
// ---- FP8/FP4 GEMMs ----
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// FP8/FP4 GEMM entry points.  Each function pairs every data tensor with its
// scale-factor tensor ({a, sfa}, {b, sfb}) and forwards to the matching
// deep_gemm::gemm::fp8_fp4_gemm_* overload.  The suffix (nt/nn/tn/tt)
// presumably selects the A/B transpose combination -- see the core API.
//   a, sfa / b, sfb    : operands with their scale-factor tensors
//   d                  : output tensor
//   c                  : optional extra input, forwarded unchanged
//   recipe             : optional 3-element recipe (int64 -> int tuple)
//   recipe_a, recipe_b : optional 2-element per-operand recipes
//   compiled_dims      : forwarded unchanged to the core API
//   disable_ue8m0_cast : forwarded flag controlling the UE8M0 cast path
void dg_fp8_fp4_gemm_nt(const at::Tensor& a, const at::Tensor& sfa,
const at::Tensor& b, const at::Tensor& sfb,
const at::Tensor& d,
const std::optional<at::Tensor>& c,
const std::optional<std::vector<int64_t>>& recipe,
const std::optional<std::vector<int64_t>>& recipe_a,
const std::optional<std::vector<int64_t>>& recipe_b,
const std::string& compiled_dims,
bool disable_ue8m0_cast) {
deep_gemm::gemm::fp8_fp4_gemm_nt(
{a, sfa}, {b, sfb}, d, c,
vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
compiled_dims, disable_ue8m0_cast);
}
// NN variant -- same contract as above.
void dg_fp8_fp4_gemm_nn(const at::Tensor& a, const at::Tensor& sfa,
const at::Tensor& b, const at::Tensor& sfb,
const at::Tensor& d,
const std::optional<at::Tensor>& c,
const std::optional<std::vector<int64_t>>& recipe,
const std::optional<std::vector<int64_t>>& recipe_a,
const std::optional<std::vector<int64_t>>& recipe_b,
const std::string& compiled_dims,
bool disable_ue8m0_cast) {
deep_gemm::gemm::fp8_fp4_gemm_nn(
{a, sfa}, {b, sfb}, d, c,
vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
compiled_dims, disable_ue8m0_cast);
}
// TN variant -- same contract as above.
void dg_fp8_fp4_gemm_tn(const at::Tensor& a, const at::Tensor& sfa,
const at::Tensor& b, const at::Tensor& sfb,
const at::Tensor& d,
const std::optional<at::Tensor>& c,
const std::optional<std::vector<int64_t>>& recipe,
const std::optional<std::vector<int64_t>>& recipe_a,
const std::optional<std::vector<int64_t>>& recipe_b,
const std::string& compiled_dims,
bool disable_ue8m0_cast) {
deep_gemm::gemm::fp8_fp4_gemm_tn(
{a, sfa}, {b, sfb}, d, c,
vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
compiled_dims, disable_ue8m0_cast);
}
// TT variant -- same contract as above.
void dg_fp8_fp4_gemm_tt(const at::Tensor& a, const at::Tensor& sfa,
const at::Tensor& b, const at::Tensor& sfb,
const at::Tensor& d,
const std::optional<at::Tensor>& c,
const std::optional<std::vector<int64_t>>& recipe,
const std::optional<std::vector<int64_t>>& recipe_a,
const std::optional<std::vector<int64_t>>& recipe_b,
const std::string& compiled_dims,
bool disable_ue8m0_cast) {
deep_gemm::gemm::fp8_fp4_gemm_tt(
{a, sfa}, {b, sfb}, d, c,
vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
compiled_dims, disable_ue8m0_cast);
}
// M-grouped FP8/FP4 GEMM (NT variant) over contiguously packed groups.
// Narrows the optional 64-bit expected-M hint for the psum layout and
// forwards everything, together with the converted recipe tuples, to the
// core implementation.
void dg_m_grouped_fp8_fp4_gemm_nt_contiguous(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d, const at::Tensor& grouped_layout,
    const std::optional<std::vector<int64_t>>& recipe,
    const std::optional<std::vector<int64_t>>& recipe_a,
    const std::optional<std::vector<int64_t>>& recipe_b,
    const std::string& compiled_dims,
    bool disable_ue8m0_cast,
    bool use_psum_layout,
    const std::optional<int64_t>& expected_m_for_psum_layout) {
    std::optional<int> m_hint;
    if (expected_m_for_psum_layout)
        m_hint = static_cast<int>(*expected_m_for_psum_layout);
    deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_contiguous(
        {a, sfa}, {b, sfb}, d, grouped_layout,
        vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
        compiled_dims, disable_ue8m0_cast, use_psum_layout, m_hint);
}
// M-grouped FP8/FP4 GEMM (NN variant) over contiguously packed groups.
// Forwards (tensor, scale-factor) pairs, the grouped layout tensor, and the
// optional recipes (converted from int64 vectors) to the core implementation.
void dg_m_grouped_fp8_fp4_gemm_nn_contiguous(
const at::Tensor& a, const at::Tensor& sfa,
const at::Tensor& b, const at::Tensor& sfb,
const at::Tensor& d, const at::Tensor& grouped_layout,
const std::optional<std::vector<int64_t>>& recipe,
const std::optional<std::vector<int64_t>>& recipe_a,
const std::optional<std::vector<int64_t>>& recipe_b,
const std::string& compiled_dims,
bool disable_ue8m0_cast,
bool use_psum_layout) {
deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nn_contiguous(
{a, sfa}, {b, sfb}, d, grouped_layout,
vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
compiled_dims, disable_ue8m0_cast, use_psum_layout);
}
// M-grouped FP8/FP4 GEMM (NT variant) with a per-group mask tensor.
// expected_m is narrowed from int64 to the int the core API takes; the
// optional recipes are converted from int64 vectors to int tuples.
void dg_m_grouped_fp8_fp4_gemm_nt_masked(
const at::Tensor& a, const at::Tensor& sfa,
const at::Tensor& b, const at::Tensor& sfb,
const at::Tensor& d, const at::Tensor& masked_m,
int64_t expected_m,
const std::optional<std::vector<int64_t>>& recipe,
const std::optional<std::vector<int64_t>>& recipe_a,
const std::optional<std::vector<int64_t>>& recipe_b,
const std::string& compiled_dims,
bool disable_ue8m0_cast) {
deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_masked(
{a, sfa}, {b, sfb}, d, masked_m, static_cast<int>(expected_m),
vec_to_tuple3(recipe), vec_to_tuple2(recipe_a), vec_to_tuple2(recipe_b),
compiled_dims, disable_ue8m0_cast);
}
// K-grouped FP8 GEMM (NT variant) over contiguously packed groups.
//   ks        : per-group K sizes, narrowed to 32-bit ints for the core API
//   ks_tensor : device-side copy of the group sizes, forwarded unchanged
//   recipe    : required 3-element recipe (vec_to_tuple3_req)
void dg_k_grouped_fp8_gemm_nt_contiguous(
const at::Tensor& a, const at::Tensor& sfa,
const at::Tensor& b, const at::Tensor& sfb,
const at::Tensor& d,
const std::vector<int64_t>& ks,
const at::Tensor& ks_tensor,
const std::optional<at::Tensor>& c,
const std::vector<int64_t>& recipe,
const std::string& compiled_dims) {
deep_gemm::gemm::k_grouped_fp8_gemm_nt_contiguous(
{a, sfa}, {b, sfb}, d, vec64_to_vec32(ks), ks_tensor, c,
vec_to_tuple3_req(recipe), compiled_dims);
}
// TN variant -- same contract as above.
void dg_k_grouped_fp8_gemm_tn_contiguous(
const at::Tensor& a, const at::Tensor& sfa,
const at::Tensor& b, const at::Tensor& sfb,
const at::Tensor& d,
const std::vector<int64_t>& ks,
const at::Tensor& ks_tensor,
const std::optional<at::Tensor>& c,
const std::vector<int64_t>& recipe,
const std::string& compiled_dims) {
deep_gemm::gemm::k_grouped_fp8_gemm_tn_contiguous(
{a, sfa}, {b, sfb}, d, vec64_to_vec32(ks), ks_tensor, c,
vec_to_tuple3_req(recipe), compiled_dims);
}
#endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// ---- BF16 GEMMs ----
#if DG_TENSORMAP_COMPATIBLE
// BF16 GEMM entry points.  Thin forwarders to the matching
// deep_gemm::gemm::bf16_gemm_* overload; the suffix (nt/nn/tn/tt)
// presumably selects the A/B transpose combination -- see the core API.
//   a, b : input operands; d : output; c : optional extra input, unchanged
void dg_bf16_gemm_nt(const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d,
const std::optional<at::Tensor>& c,
const std::string& compiled_dims) {
deep_gemm::gemm::bf16_gemm_nt(a, b, d, c, compiled_dims);
}
// NN variant -- same contract as above.
void dg_bf16_gemm_nn(const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d,
const std::optional<at::Tensor>& c,
const std::string& compiled_dims) {
deep_gemm::gemm::bf16_gemm_nn(a, b, d, c, compiled_dims);
}
// TN variant -- same contract as above.
void dg_bf16_gemm_tn(const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d,
const std::optional<at::Tensor>& c,
const std::string& compiled_dims) {
deep_gemm::gemm::bf16_gemm_tn(a, b, d, c, compiled_dims);
}
// TT variant -- same contract as above.
void dg_bf16_gemm_tt(const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d,
const std::optional<at::Tensor>& c,
const std::string& compiled_dims) {
deep_gemm::gemm::bf16_gemm_tt(a, b, d, c, compiled_dims);
}
// M-grouped BF16 GEMM (NT variant) over contiguously packed groups.
// Narrows the optional 64-bit expected-M hint for the psum layout before
// forwarding to the core implementation.
void dg_m_grouped_bf16_gemm_nt_contiguous(
    const at::Tensor& a, const at::Tensor& b,
    const at::Tensor& d, const at::Tensor& grouped_layout,
    const std::string& compiled_dims,
    bool use_psum_layout,
    const std::optional<int64_t>& expected_m_for_psum_layout) {
    std::optional<int> m_hint;
    if (expected_m_for_psum_layout)
        m_hint = static_cast<int>(*expected_m_for_psum_layout);
    deep_gemm::gemm::m_grouped_bf16_gemm_nt_contiguous(
        a, b, d, grouped_layout, compiled_dims, use_psum_layout, m_hint);
}
// M-grouped BF16 GEMM (NN variant) over contiguously packed groups; thin
// forwarder to the core implementation.
void dg_m_grouped_bf16_gemm_nn_contiguous(
const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d, const at::Tensor& grouped_layout,
const std::string& compiled_dims,
bool use_psum_layout) {
deep_gemm::gemm::m_grouped_bf16_gemm_nn_contiguous(
a, b, d, grouped_layout, compiled_dims, use_psum_layout);
}
// M-grouped BF16 GEMM (NT variant) with a per-group mask tensor;
// expected_m is narrowed from int64 to the int the core API takes.
void dg_m_grouped_bf16_gemm_nt_masked(
const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d, const at::Tensor& masked_m,
int64_t expected_m,
const std::string& compiled_dims) {
deep_gemm::gemm::m_grouped_bf16_gemm_nt_masked(
a, b, d, masked_m, static_cast<int>(expected_m), compiled_dims);
}
// K-grouped BF16 GEMM (TN variant); per-group K sizes `ks` are narrowed to
// 32-bit ints, with `ks_tensor` forwarded unchanged alongside them.
void dg_k_grouped_bf16_gemm_tn_contiguous(
const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d,
const std::vector<int64_t>& ks,
const at::Tensor& ks_tensor,
const std::optional<at::Tensor>& c,
const std::string& compiled_dims) {
deep_gemm::gemm::k_grouped_bf16_gemm_tn_contiguous(
a, b, d, vec64_to_vec32(ks), ks_tensor, c, compiled_dims);
}
#endif // DG_TENSORMAP_COMPATIBLE
// ---- Einsum ----
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// Einsum over two tensors described by the expression string `expr`;
// thin forwarder to deep_gemm::einsum::einsum.  `use_cublaslt` selects the
// cuBLASLt-backed path in the core implementation.
void dg_einsum(const std::string& expr,
const at::Tensor& a, const at::Tensor& b,
const at::Tensor& d,
const std::optional<at::Tensor>& c,
bool use_cublaslt) {
deep_gemm::einsum::einsum(expr, a, b, d, c, use_cublaslt);
}
// FP8 einsum: pairs each operand with its scale-factor tensor and converts
// the required 3-element recipe before forwarding to the core API.
void dg_fp8_einsum(const std::string& expr,
const at::Tensor& a, const at::Tensor& sfa,
const at::Tensor& b, const at::Tensor& sfb,
const at::Tensor& d,
const std::optional<at::Tensor>& c,
const std::vector<int64_t>& recipe) {
deep_gemm::einsum::fp8_einsum(
expr, {a, sfa}, {b, sfb}, d, c, vec_to_tuple3_req(recipe));
}
#endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// ---- Attention ----
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// FP8 GEMM (NT variant) with head-splitting for attention; `head_splits`
// must carry 3 entries (converted via vec_to_tuple3_req).  Forwards the
// (tensor, scale-factor) pairs and optional recipe to the core API.
void dg_fp8_gemm_nt_skip_head_mid(
const at::Tensor& a, const at::Tensor& sfa,
const at::Tensor& b, const at::Tensor& sfb,
const at::Tensor& d,
const std::vector<int64_t>& head_splits,
const std::optional<std::vector<int64_t>>& recipe,
const std::string& compiled_dims,
bool disable_ue8m0_cast) {
deep_gemm::attention::fp8_gemm_nt_skip_head_mid(
{a, sfa}, {b, sfb}, d, vec_to_tuple3_req(head_splits),
vec_to_tuple3(recipe), compiled_dims, disable_ue8m0_cast);
}
// Computes FP8 MQA logits; pairs kv with its scale-factor tensor and narrows
// max_seqlen_k to int before forwarding to the core attention API.
at::Tensor dg_fp8_mqa_logits(
const at::Tensor& q,
const at::Tensor& kv, const at::Tensor& kv_sf,
const at::Tensor& weights,
const at::Tensor& cu_seq_len_k_start,
const at::Tensor& cu_seq_len_k_end,
bool clean_logits,
int64_t max_seqlen_k) {
return deep_gemm::attention::fp8_mqa_logits(
q, {kv, kv_sf}, weights,
cu_seq_len_k_start, cu_seq_len_k_end,
clean_logits, static_cast<int>(max_seqlen_k));
}
// Builds the scheduling metadata tensor for paged MQA logits from the
// per-sequence context lengths; block_kv and num_sms narrowed to int.
at::Tensor dg_get_paged_mqa_logits_metadata(
const at::Tensor& context_lens,
int64_t block_kv, int64_t num_sms) {
return deep_gemm::attention::get_paged_mqa_logits_metadata(
context_lens, static_cast<int>(block_kv), static_cast<int>(num_sms));
}
// Computes FP8 paged MQA logits over a paged KV cache using the block table
// and schedule metadata; max_context_len narrowed to int.
at::Tensor dg_fp8_paged_mqa_logits(
const at::Tensor& q,
const at::Tensor& fused_kv_cache,
const at::Tensor& weights,
const at::Tensor& context_lens,
const at::Tensor& block_table,
const at::Tensor& schedule_meta,
int64_t max_context_len,
bool clean_logits) {
return deep_gemm::attention::fp8_paged_mqa_logits(
q, fused_kv_cache, weights, context_lens,
block_table, schedule_meta,
static_cast<int>(max_context_len), clean_logits);
}
#endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// ---- Hyperconnection ----
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// TF32 hyperconnection pre-norm GEMM; narrows the optional split count from
// int64 to int and forwards to the core implementation, which also writes
// the `sqr_sum` tensor.
void dg_tf32_hc_prenorm_gemm(
    const at::Tensor& a, const at::Tensor& b,
    const at::Tensor& d, const at::Tensor& sqr_sum,
    const std::optional<int64_t>& num_splits) {
    std::optional<int> split_count;
    if (num_splits)
        split_count = static_cast<int>(*num_splits);
    deep_gemm::hyperconnection::tf32_hc_prenorm_gemm(a, b, d, sqr_sum, split_count);
}
#endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// ---- Layout ----
#if DG_TENSORMAP_COMPATIBLE
// Transforms a scale-factor tensor into the layout the kernels require,
// narrowing every int64 scalar (mn, k, num_groups) and converting the
// optional recipe vectors to int tuples before calling the core layout API.
at::Tensor dg_transform_sf_into_required_layout(
    const at::Tensor& sf,
    int64_t mn, int64_t k,
    const std::optional<std::vector<int64_t>>& recipe,
    const std::optional<std::vector<int64_t>>& recipe_ab,
    const std::optional<int64_t>& num_groups,
    bool is_sfa, bool disable_ue8m0_cast) {
    std::optional<int> group_count;
    if (num_groups)
        group_count = static_cast<int>(*num_groups);
    return deep_gemm::layout::transform_sf_into_required_layout(
        sf, static_cast<int>(mn), static_cast<int>(k),
        vec_to_tuple3(recipe), vec_to_tuple2(recipe_ab),
        group_count, is_sfa, disable_ue8m0_cast);
}
// Returns the TMA-aligned size for `x` elements of `element_size` bytes,
// widening the core API's int result back to int64.
int64_t dg_get_tma_aligned_size(int64_t x, int64_t element_size) {
return static_cast<int64_t>(
deep_gemm::get_tma_aligned_size(
static_cast<int>(x), static_cast<int>(element_size)));
}
// Returns an MN-major, TMA-aligned copy/view of `sf` (see the core layout API).
at::Tensor dg_get_mn_major_tma_aligned_tensor(const at::Tensor& sf) {
return deep_gemm::get_mn_major_tma_aligned_tensor(sf);
}
// Returns an MN-major, TMA-aligned, UE8M0-packed tensor built from `sf`.
at::Tensor dg_get_mn_major_tma_aligned_packed_ue8m0_tensor(const at::Tensor& sf) {
return deep_gemm::get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
}
// K-grouped variant of the above; `ks` is narrowed to 32-bit ints while the
// device-side `ks_tensor` is forwarded unchanged.
at::Tensor dg_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
const at::Tensor& sf,
const at::Tensor& ks_tensor,
const std::vector<int64_t>& ks) {
return deep_gemm::get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
sf, ks_tensor, vec64_to_vec32(ks));
}
#endif // DG_TENSORMAP_COMPATIBLE
// Returns the M/K alignment required by the contiguous grouped layout,
// widening the core API's int result back to int64.
int64_t dg_get_mk_alignment_for_contiguous_layout() {
return static_cast<int64_t>(deep_gemm::get_mk_alignment_for_contiguous_layout());
}