#include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"
// Registers every custom operator of this extension with PyTorch's
// dispatcher. Each op is declared with a schema string (ops.def) and bound
// to a C++ implementation (ops.impl). Impls registered with torch::kCUDA
// are dispatched only for CUDA tensors; impls registered without a dispatch
// key (init, set/get_num_sms, set/get_tc_util, get_tma_aligned_size,
// get_mk_alignment_for_contiguous_layout) take no tensor arguments and are
// catch-all registrations.
//
// NOTE(review): the schemas write results into `d` (and read optional bias
// `c`) but carry no alias annotations such as `Tensor(a!)`, so the
// dispatcher sees them as non-mutating — confirm this is intentional
// (it matters for autograd/functionalization passes).
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// ---- Runtime ----
// Host-side runtime configuration; no tensors involved, hence no kCUDA key.
ops.def("init(str library_root_path, str cuda_home_path, str cutlass_include_path) -> ()");
ops.impl("init", &dg_init_wrap);
ops.def("set_num_sms(int new_num_sms) -> ()");
ops.impl("set_num_sms", &dg_set_num_sms_wrap);
ops.def("get_num_sms() -> int");
ops.impl("get_num_sms", &dg_get_num_sms_wrap);
ops.def("set_tc_util(int new_tc_util) -> ()");
ops.impl("set_tc_util", &dg_set_tc_util_wrap);
ops.def("get_tc_util() -> int");
ops.impl("get_tc_util", &dg_get_tc_util_wrap);
// ---- cuBLASLt GEMMs ----
// Plain GEMMs via cuBLASLt; the nt/nn/tn/tt suffix names the (a, b)
// transpose layout combination. `d` is the output, `c` an optional addend.
ops.def("cublaslt_gemm_nt(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()");
ops.impl("cublaslt_gemm_nt", torch::kCUDA, &dg_cublaslt_gemm_nt);
ops.def("cublaslt_gemm_nn(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()");
ops.impl("cublaslt_gemm_nn", torch::kCUDA, &dg_cublaslt_gemm_nn);
ops.def("cublaslt_gemm_tn(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()");
ops.impl("cublaslt_gemm_tn", torch::kCUDA, &dg_cublaslt_gemm_tn);
ops.def("cublaslt_gemm_tt(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()");
ops.impl("cublaslt_gemm_tt", torch::kCUDA, &dg_cublaslt_gemm_tt);
// ---- FP8/FP4 GEMMs ----
// Low-precision GEMMs; `sfa`/`sfb` are the per-operand scaling factors and
// the optional `recipe`/`recipe_a`/`recipe_b` lists select the scaling
// granularity. `compiled_dims` and `disable_ue8m0_cast` tune JIT behavior.
ops.def("fp8_fp4_gemm_nt(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()");
ops.impl("fp8_fp4_gemm_nt", torch::kCUDA, &dg_fp8_fp4_gemm_nt);
ops.def("fp8_fp4_gemm_nn(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()");
ops.impl("fp8_fp4_gemm_nn", torch::kCUDA, &dg_fp8_fp4_gemm_nn);
ops.def("fp8_fp4_gemm_tn(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()");
ops.impl("fp8_fp4_gemm_tn", torch::kCUDA, &dg_fp8_fp4_gemm_tn);
ops.def("fp8_fp4_gemm_tt(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()");
ops.impl("fp8_fp4_gemm_tt", torch::kCUDA, &dg_fp8_fp4_gemm_tt);
// Grouped variants: `grouped_layout`/`masked_m`/`ks(_tensor)` describe how
// groups are packed along the M or K dimension (contiguous vs. masked).
// NOTE(review): only the nt_contiguous variant takes
// `expected_m_for_psum_layout`; nn_contiguous does not — confirm the
// asymmetry is deliberate.
ops.def("m_grouped_fp8_fp4_gemm_nt_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor grouped_layout, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast, bool use_psum_layout, int? expected_m_for_psum_layout) -> ()");
ops.impl("m_grouped_fp8_fp4_gemm_nt_contiguous", torch::kCUDA, &dg_m_grouped_fp8_fp4_gemm_nt_contiguous);
ops.def("m_grouped_fp8_fp4_gemm_nn_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor grouped_layout, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast, bool use_psum_layout) -> ()");
ops.impl("m_grouped_fp8_fp4_gemm_nn_contiguous", torch::kCUDA, &dg_m_grouped_fp8_fp4_gemm_nn_contiguous);
ops.def("m_grouped_fp8_fp4_gemm_nt_masked(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor masked_m, int expected_m, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()");
ops.impl("m_grouped_fp8_fp4_gemm_nt_masked", torch::kCUDA, &dg_m_grouped_fp8_fp4_gemm_nt_masked);
ops.def("k_grouped_fp8_gemm_nt_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c, int[] recipe, str compiled_dims) -> ()");
ops.impl("k_grouped_fp8_gemm_nt_contiguous", torch::kCUDA, &dg_k_grouped_fp8_gemm_nt_contiguous);
ops.def("k_grouped_fp8_gemm_tn_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c, int[] recipe, str compiled_dims) -> ()");
ops.impl("k_grouped_fp8_gemm_tn_contiguous", torch::kCUDA, &dg_k_grouped_fp8_gemm_tn_contiguous);
// ---- BF16 GEMMs ----
// Same layout/grouping matrix as above, without scaling factors.
ops.def("bf16_gemm_nt(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()");
ops.impl("bf16_gemm_nt", torch::kCUDA, &dg_bf16_gemm_nt);
ops.def("bf16_gemm_nn(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()");
ops.impl("bf16_gemm_nn", torch::kCUDA, &dg_bf16_gemm_nn);
ops.def("bf16_gemm_tn(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()");
ops.impl("bf16_gemm_tn", torch::kCUDA, &dg_bf16_gemm_tn);
ops.def("bf16_gemm_tt(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()");
ops.impl("bf16_gemm_tt", torch::kCUDA, &dg_bf16_gemm_tt);
ops.def("m_grouped_bf16_gemm_nt_contiguous(Tensor a, Tensor b, Tensor d, Tensor grouped_layout, str compiled_dims, bool use_psum_layout, int? expected_m_for_psum_layout) -> ()");
ops.impl("m_grouped_bf16_gemm_nt_contiguous", torch::kCUDA, &dg_m_grouped_bf16_gemm_nt_contiguous);
ops.def("m_grouped_bf16_gemm_nn_contiguous(Tensor a, Tensor b, Tensor d, Tensor grouped_layout, str compiled_dims, bool use_psum_layout) -> ()");
ops.impl("m_grouped_bf16_gemm_nn_contiguous", torch::kCUDA, &dg_m_grouped_bf16_gemm_nn_contiguous);
ops.def("m_grouped_bf16_gemm_nt_masked(Tensor a, Tensor b, Tensor d, Tensor masked_m, int expected_m, str compiled_dims) -> ()");
ops.impl("m_grouped_bf16_gemm_nt_masked", torch::kCUDA, &dg_m_grouped_bf16_gemm_nt_masked);
ops.def("k_grouped_bf16_gemm_tn_contiguous(Tensor a, Tensor b, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c, str compiled_dims) -> ()");
ops.impl("k_grouped_bf16_gemm_tn_contiguous", torch::kCUDA, &dg_k_grouped_bf16_gemm_tn_contiguous);
// ---- Einsum ----
// Einsum-style contractions driven by an expression string `expr`.
ops.def("einsum(str expr, Tensor a, Tensor b, Tensor d, Tensor? c, bool use_cublaslt) -> ()");
ops.impl("einsum", torch::kCUDA, &dg_einsum);
ops.def("fp8_einsum(str expr, Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[] recipe) -> ()");
ops.impl("fp8_einsum", torch::kCUDA, &dg_fp8_einsum);
// ---- Attention ----
// MQA logits kernels; the fp8_mqa_logits / fp8_paged_mqa_logits ops are the
// only ops here that return a new Tensor instead of writing into `d`.
ops.def("fp8_gemm_nt_skip_head_mid(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, int[] head_splits, int[]? recipe, str compiled_dims, bool disable_ue8m0_cast) -> ()");
ops.impl("fp8_gemm_nt_skip_head_mid", torch::kCUDA, &dg_fp8_gemm_nt_skip_head_mid);
ops.def("fp8_mqa_logits(Tensor q, Tensor kv, Tensor kv_sf, Tensor weights, Tensor cu_seq_len_k_start, Tensor cu_seq_len_k_end, bool clean_logits, int max_seqlen_k) -> Tensor");
ops.impl("fp8_mqa_logits", torch::kCUDA, &dg_fp8_mqa_logits);
// Produces the scheduling metadata consumed by fp8_paged_mqa_logits below.
ops.def("get_paged_mqa_logits_metadata(Tensor context_lens, int block_kv, int num_sms) -> Tensor");
ops.impl("get_paged_mqa_logits_metadata", torch::kCUDA, &dg_get_paged_mqa_logits_metadata);
ops.def("fp8_paged_mqa_logits(Tensor q, Tensor fused_kv_cache, Tensor weights, Tensor context_lens, Tensor block_table, Tensor schedule_meta, int max_context_len, bool clean_logits) -> Tensor");
ops.impl("fp8_paged_mqa_logits", torch::kCUDA, &dg_fp8_paged_mqa_logits);
// ---- Hyperconnection ----
ops.def("tf32_hc_prenorm_gemm(Tensor a, Tensor b, Tensor d, Tensor sqr_sum, int? num_splits) -> ()");
ops.impl("tf32_hc_prenorm_gemm", torch::kCUDA, &dg_tf32_hc_prenorm_gemm);
// ---- Layout ----
// Helpers that convert scaling-factor tensors into the TMA-aligned layouts
// the GEMM kernels expect, plus pure-int alignment queries (catch-all).
ops.def("transform_sf_into_required_layout(Tensor sf, int mn, int k, int[]? recipe, int[]? recipe_ab, int? num_groups, bool is_sfa, bool disable_ue8m0_cast) -> Tensor");
ops.impl("transform_sf_into_required_layout", torch::kCUDA, &dg_transform_sf_into_required_layout);
ops.def("get_tma_aligned_size(int x, int element_size) -> int");
ops.impl("get_tma_aligned_size", &dg_get_tma_aligned_size);
ops.def("get_mn_major_tma_aligned_tensor(Tensor sf) -> Tensor");
ops.impl("get_mn_major_tma_aligned_tensor", torch::kCUDA, &dg_get_mn_major_tma_aligned_tensor);
ops.def("get_mn_major_tma_aligned_packed_ue8m0_tensor(Tensor sf) -> Tensor");
ops.impl("get_mn_major_tma_aligned_packed_ue8m0_tensor", torch::kCUDA, &dg_get_mn_major_tma_aligned_packed_ue8m0_tensor);
ops.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(Tensor sf, Tensor ks_tensor, int[] ks) -> Tensor");
ops.impl("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", torch::kCUDA, &dg_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
ops.def("get_mk_alignment_for_contiguous_layout() -> int");
ops.impl("get_mk_alignment_for_contiguous_layout", &dg_get_mk_alignment_for_contiguous_layout_wrap);
}
// Emits the extension's module entry point — presumably the PyInit_<name>
// boilerplate; the macro is defined in registration.h (not visible here).
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)