#include #include "registration.h" #include "torch_binding.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // ---- Runtime ---- ops.def("init(str library_root_path, str cuda_home_path, str cutlass_include_path) -> ()"); ops.impl("init", &dg_init_wrap); ops.def("set_num_sms(int new_num_sms) -> ()"); ops.impl("set_num_sms", &dg_set_num_sms_wrap); ops.def("get_num_sms() -> int"); ops.impl("get_num_sms", &dg_get_num_sms_wrap); ops.def("set_tc_util(int new_tc_util) -> ()"); ops.impl("set_tc_util", &dg_set_tc_util_wrap); ops.def("get_tc_util() -> int"); ops.impl("get_tc_util", &dg_get_tc_util_wrap); // ---- cuBLASLt GEMMs ---- ops.def("cublaslt_gemm_nt(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()"); ops.impl("cublaslt_gemm_nt", torch::kCUDA, &dg_cublaslt_gemm_nt); ops.def("cublaslt_gemm_nn(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()"); ops.impl("cublaslt_gemm_nn", torch::kCUDA, &dg_cublaslt_gemm_nn); ops.def("cublaslt_gemm_tn(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()"); ops.impl("cublaslt_gemm_tn", torch::kCUDA, &dg_cublaslt_gemm_tn); ops.def("cublaslt_gemm_tt(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()"); ops.impl("cublaslt_gemm_tt", torch::kCUDA, &dg_cublaslt_gemm_tt); // ---- FP8/FP4 GEMMs ---- ops.def("fp8_fp4_gemm_nt(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()"); ops.impl("fp8_fp4_gemm_nt", torch::kCUDA, &dg_fp8_fp4_gemm_nt); ops.def("fp8_fp4_gemm_nn(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()"); ops.impl("fp8_fp4_gemm_nn", torch::kCUDA, &dg_fp8_fp4_gemm_nn); ops.def("fp8_fp4_gemm_tn(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()"); ops.impl("fp8_fp4_gemm_tn", torch::kCUDA, &dg_fp8_fp4_gemm_tn); ops.def("fp8_fp4_gemm_tt(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()"); ops.impl("fp8_fp4_gemm_tt", torch::kCUDA, &dg_fp8_fp4_gemm_tt); ops.def("m_grouped_fp8_fp4_gemm_nt_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor grouped_layout, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast, bool use_psum_layout, int? expected_m_for_psum_layout) -> ()"); ops.impl("m_grouped_fp8_fp4_gemm_nt_contiguous", torch::kCUDA, &dg_m_grouped_fp8_fp4_gemm_nt_contiguous); ops.def("m_grouped_fp8_fp4_gemm_nn_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor grouped_layout, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast, bool use_psum_layout) -> ()"); ops.impl("m_grouped_fp8_fp4_gemm_nn_contiguous", torch::kCUDA, &dg_m_grouped_fp8_fp4_gemm_nn_contiguous); ops.def("m_grouped_fp8_fp4_gemm_nt_masked(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor masked_m, int expected_m, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()"); ops.impl("m_grouped_fp8_fp4_gemm_nt_masked", torch::kCUDA, &dg_m_grouped_fp8_fp4_gemm_nt_masked); ops.def("k_grouped_fp8_gemm_nt_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c, int[] recipe, str compiled_dims) -> ()"); ops.impl("k_grouped_fp8_gemm_nt_contiguous", torch::kCUDA, &dg_k_grouped_fp8_gemm_nt_contiguous); ops.def("k_grouped_fp8_gemm_tn_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c, int[] recipe, str compiled_dims) -> ()"); ops.impl("k_grouped_fp8_gemm_tn_contiguous", torch::kCUDA, &dg_k_grouped_fp8_gemm_tn_contiguous); // ---- BF16 GEMMs ---- ops.def("bf16_gemm_nt(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()"); ops.impl("bf16_gemm_nt", torch::kCUDA, &dg_bf16_gemm_nt); ops.def("bf16_gemm_nn(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()"); ops.impl("bf16_gemm_nn", torch::kCUDA, &dg_bf16_gemm_nn); ops.def("bf16_gemm_tn(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()"); ops.impl("bf16_gemm_tn", torch::kCUDA, &dg_bf16_gemm_tn); ops.def("bf16_gemm_tt(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()"); ops.impl("bf16_gemm_tt", torch::kCUDA, &dg_bf16_gemm_tt); ops.def("m_grouped_bf16_gemm_nt_contiguous(Tensor a, Tensor b, Tensor d, Tensor grouped_layout, str compiled_dims, bool use_psum_layout, int? expected_m_for_psum_layout) -> ()"); ops.impl("m_grouped_bf16_gemm_nt_contiguous", torch::kCUDA, &dg_m_grouped_bf16_gemm_nt_contiguous); ops.def("m_grouped_bf16_gemm_nn_contiguous(Tensor a, Tensor b, Tensor d, Tensor grouped_layout, str compiled_dims, bool use_psum_layout) -> ()"); ops.impl("m_grouped_bf16_gemm_nn_contiguous", torch::kCUDA, &dg_m_grouped_bf16_gemm_nn_contiguous); ops.def("m_grouped_bf16_gemm_nt_masked(Tensor a, Tensor b, Tensor d, Tensor masked_m, int expected_m, str compiled_dims) -> ()"); ops.impl("m_grouped_bf16_gemm_nt_masked", torch::kCUDA, &dg_m_grouped_bf16_gemm_nt_masked); ops.def("k_grouped_bf16_gemm_tn_contiguous(Tensor a, Tensor b, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c, str compiled_dims) -> ()"); ops.impl("k_grouped_bf16_gemm_tn_contiguous", torch::kCUDA, &dg_k_grouped_bf16_gemm_tn_contiguous); // ---- Einsum ---- ops.def("einsum(str expr, Tensor a, Tensor b, Tensor d, Tensor? c, bool use_cublaslt) -> ()"); ops.impl("einsum", torch::kCUDA, &dg_einsum); ops.def("fp8_einsum(str expr, Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[] recipe) -> ()"); ops.impl("fp8_einsum", torch::kCUDA, &dg_fp8_einsum); // ---- Attention ---- ops.def("fp8_gemm_nt_skip_head_mid(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, int[] head_splits, int[]? recipe, str compiled_dims, bool disable_ue8m0_cast) -> ()"); ops.impl("fp8_gemm_nt_skip_head_mid", torch::kCUDA, &dg_fp8_gemm_nt_skip_head_mid); ops.def("fp8_mqa_logits(Tensor q, Tensor kv, Tensor kv_sf, Tensor weights, Tensor cu_seq_len_k_start, Tensor cu_seq_len_k_end, bool clean_logits, int max_seqlen_k) -> Tensor"); ops.impl("fp8_mqa_logits", torch::kCUDA, &dg_fp8_mqa_logits); ops.def("get_paged_mqa_logits_metadata(Tensor context_lens, int block_kv, int num_sms) -> Tensor"); ops.impl("get_paged_mqa_logits_metadata", torch::kCUDA, &dg_get_paged_mqa_logits_metadata); ops.def("fp8_paged_mqa_logits(Tensor q, Tensor fused_kv_cache, Tensor weights, Tensor context_lens, Tensor block_table, Tensor schedule_meta, int max_context_len, bool clean_logits) -> Tensor"); ops.impl("fp8_paged_mqa_logits", torch::kCUDA, &dg_fp8_paged_mqa_logits); // ---- Hyperconnection ---- ops.def("tf32_hc_prenorm_gemm(Tensor a, Tensor b, Tensor d, Tensor sqr_sum, int? num_splits) -> ()"); ops.impl("tf32_hc_prenorm_gemm", torch::kCUDA, &dg_tf32_hc_prenorm_gemm); // ---- Layout ---- ops.def("transform_sf_into_required_layout(Tensor sf, int mn, int k, int[]? recipe, int[]? recipe_ab, int? num_groups, bool is_sfa, bool disable_ue8m0_cast) -> Tensor"); ops.impl("transform_sf_into_required_layout", torch::kCUDA, &dg_transform_sf_into_required_layout); ops.def("get_tma_aligned_size(int x, int element_size) -> int"); ops.impl("get_tma_aligned_size", &dg_get_tma_aligned_size); ops.def("get_mn_major_tma_aligned_tensor(Tensor sf) -> Tensor"); ops.impl("get_mn_major_tma_aligned_tensor", torch::kCUDA, &dg_get_mn_major_tma_aligned_tensor); ops.def("get_mn_major_tma_aligned_packed_ue8m0_tensor(Tensor sf) -> Tensor"); ops.impl("get_mn_major_tma_aligned_packed_ue8m0_tensor", torch::kCUDA, &dg_get_mn_major_tma_aligned_packed_ue8m0_tensor); ops.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(Tensor sf, Tensor ks_tensor, int[] ks) -> Tensor"); ops.impl("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", torch::kCUDA, &dg_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); ops.def("get_mk_alignment_for_contiguous_layout() -> int"); ops.impl("get_mk_alignment_for_contiguous_layout", &dg_get_mk_alignment_for_contiguous_layout_wrap); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)