File size: 8,468 Bytes
c67ae40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Registers every operator of this extension with PyTorch's dispatcher.
// Each entry pairs a TorchScript-style schema string (ops.def) with the C++
// wrapper that implements it (ops.impl); the dg_* wrappers are declared in
// torch_binding.h. Schema strings must match the wrapper signatures exactly —
// a mismatch is reported at library-load time, not at compile time.
// Ops registered WITHOUT a dispatch key (init, set/get_num_sms, ...) are
// catch-all / device-agnostic; the rest are bound to the CUDA backend via
// torch::kCUDA.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // ---- Runtime ----
  // Global runtime configuration: compiler/library search paths and
  // SM-count / tensor-core-utilization knobs. No dispatch key: these touch
  // process-global state, not tensors.
  ops.def("init(str library_root_path, str cuda_home_path, str cutlass_include_path) -> ()");
  ops.impl("init", &dg_init_wrap);

  ops.def("set_num_sms(int new_num_sms) -> ()");
  ops.impl("set_num_sms", &dg_set_num_sms_wrap);

  ops.def("get_num_sms() -> int");
  ops.impl("get_num_sms", &dg_get_num_sms_wrap);

  ops.def("set_tc_util(int new_tc_util) -> ()");
  ops.impl("set_tc_util", &dg_set_tc_util_wrap);

  ops.def("get_tc_util() -> int");
  ops.impl("get_tc_util", &dg_get_tc_util_wrap);

  // ---- cuBLASLt GEMMs ----
  // d = a @ b (+ c). The nt/nn/tn/tt suffix presumably encodes the operand
  // transpose combination (n = not transposed, t = transposed) — confirm
  // against the wrapper implementations. All write into the pre-allocated
  // output `d` and return nothing.
  ops.def("cublaslt_gemm_nt(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()");
  ops.impl("cublaslt_gemm_nt", torch::kCUDA, &dg_cublaslt_gemm_nt);

  ops.def("cublaslt_gemm_nn(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()");
  ops.impl("cublaslt_gemm_nn", torch::kCUDA, &dg_cublaslt_gemm_nn);

  ops.def("cublaslt_gemm_tn(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()");
  ops.impl("cublaslt_gemm_tn", torch::kCUDA, &dg_cublaslt_gemm_tn);

  ops.def("cublaslt_gemm_tt(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()");
  ops.impl("cublaslt_gemm_tt", torch::kCUDA, &dg_cublaslt_gemm_tt);

  // ---- FP8/FP4 GEMMs ----
  // Scaled low-precision GEMMs: sfa/sfb are scaling-factor tensors for a/b.
  // `recipe` (or per-operand recipe_a/recipe_b) selects the quantization
  // recipe; `compiled_dims` presumably controls which dims are baked into the
  // JIT-compiled kernel — verify in the wrapper.
  ops.def("fp8_fp4_gemm_nt(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()");
  ops.impl("fp8_fp4_gemm_nt", torch::kCUDA, &dg_fp8_fp4_gemm_nt);

  ops.def("fp8_fp4_gemm_nn(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()");
  ops.impl("fp8_fp4_gemm_nn", torch::kCUDA, &dg_fp8_fp4_gemm_nn);

  ops.def("fp8_fp4_gemm_tn(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()");
  ops.impl("fp8_fp4_gemm_tn", torch::kCUDA, &dg_fp8_fp4_gemm_tn);

  ops.def("fp8_fp4_gemm_tt(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()");
  ops.impl("fp8_fp4_gemm_tt", torch::kCUDA, &dg_fp8_fp4_gemm_tt);

  // Grouped variants: `grouped_layout` / `masked_m` / `ks` describe how the
  // batch is partitioned into groups along M or K.
  // NOTE(review): the _nt_contiguous schema has a trailing
  // `int? expected_m_for_psum_layout` that _nn_contiguous lacks — looks
  // intentional but worth confirming against torch_binding.h.
  ops.def("m_grouped_fp8_fp4_gemm_nt_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor grouped_layout, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast, bool use_psum_layout, int? expected_m_for_psum_layout) -> ()");
  ops.impl("m_grouped_fp8_fp4_gemm_nt_contiguous", torch::kCUDA, &dg_m_grouped_fp8_fp4_gemm_nt_contiguous);

  ops.def("m_grouped_fp8_fp4_gemm_nn_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor grouped_layout, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast, bool use_psum_layout) -> ()");
  ops.impl("m_grouped_fp8_fp4_gemm_nn_contiguous", torch::kCUDA, &dg_m_grouped_fp8_fp4_gemm_nn_contiguous);

  ops.def("m_grouped_fp8_fp4_gemm_nt_masked(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor masked_m, int expected_m, int[]? recipe, int[]? recipe_a, int[]? recipe_b, str compiled_dims, bool disable_ue8m0_cast) -> ()");
  ops.impl("m_grouped_fp8_fp4_gemm_nt_masked", torch::kCUDA, &dg_m_grouped_fp8_fp4_gemm_nt_masked);

  // K-grouped: group sizes are passed redundantly as a host-side int list
  // (`ks`) and a device tensor (`ks_tensor`); here `recipe` is required
  // (int[]) rather than optional.
  ops.def("k_grouped_fp8_gemm_nt_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c, int[] recipe, str compiled_dims) -> ()");
  ops.impl("k_grouped_fp8_gemm_nt_contiguous", torch::kCUDA, &dg_k_grouped_fp8_gemm_nt_contiguous);

  ops.def("k_grouped_fp8_gemm_tn_contiguous(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c, int[] recipe, str compiled_dims) -> ()");
  ops.impl("k_grouped_fp8_gemm_tn_contiguous", torch::kCUDA, &dg_k_grouped_fp8_gemm_tn_contiguous);

  // ---- BF16 GEMMs ----
  // Same layout/grouping families as the FP8/FP4 ops above, minus the
  // scaling-factor and recipe arguments (bf16 needs no quantization scales).
  ops.def("bf16_gemm_nt(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()");
  ops.impl("bf16_gemm_nt", torch::kCUDA, &dg_bf16_gemm_nt);

  ops.def("bf16_gemm_nn(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()");
  ops.impl("bf16_gemm_nn", torch::kCUDA, &dg_bf16_gemm_nn);

  ops.def("bf16_gemm_tn(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()");
  ops.impl("bf16_gemm_tn", torch::kCUDA, &dg_bf16_gemm_tn);

  ops.def("bf16_gemm_tt(Tensor a, Tensor b, Tensor d, Tensor? c, str compiled_dims) -> ()");
  ops.impl("bf16_gemm_tt", torch::kCUDA, &dg_bf16_gemm_tt);

  ops.def("m_grouped_bf16_gemm_nt_contiguous(Tensor a, Tensor b, Tensor d, Tensor grouped_layout, str compiled_dims, bool use_psum_layout, int? expected_m_for_psum_layout) -> ()");
  ops.impl("m_grouped_bf16_gemm_nt_contiguous", torch::kCUDA, &dg_m_grouped_bf16_gemm_nt_contiguous);

  ops.def("m_grouped_bf16_gemm_nn_contiguous(Tensor a, Tensor b, Tensor d, Tensor grouped_layout, str compiled_dims, bool use_psum_layout) -> ()");
  ops.impl("m_grouped_bf16_gemm_nn_contiguous", torch::kCUDA, &dg_m_grouped_bf16_gemm_nn_contiguous);

  ops.def("m_grouped_bf16_gemm_nt_masked(Tensor a, Tensor b, Tensor d, Tensor masked_m, int expected_m, str compiled_dims) -> ()");
  ops.impl("m_grouped_bf16_gemm_nt_masked", torch::kCUDA, &dg_m_grouped_bf16_gemm_nt_masked);

  ops.def("k_grouped_bf16_gemm_tn_contiguous(Tensor a, Tensor b, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c, str compiled_dims) -> ()");
  ops.impl("k_grouped_bf16_gemm_tn_contiguous", torch::kCUDA, &dg_k_grouped_bf16_gemm_tn_contiguous);

  // ---- Einsum ----
  // General contraction driven by an einsum expression string; the fp8
  // variant additionally takes per-operand scaling factors and a recipe.
  ops.def("einsum(str expr, Tensor a, Tensor b, Tensor d, Tensor? c, bool use_cublaslt) -> ()");
  ops.impl("einsum", torch::kCUDA, &dg_einsum);

  ops.def("fp8_einsum(str expr, Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, Tensor? c, int[] recipe) -> ()");
  ops.impl("fp8_einsum", torch::kCUDA, &dg_fp8_einsum);

  // ---- Attention ----
  // Attention-oriented kernels. Unlike the GEMM ops above, the mqa_logits
  // ops allocate and return their output tensor instead of writing into a
  // caller-provided `d`.
  ops.def("fp8_gemm_nt_skip_head_mid(Tensor a, Tensor sfa, Tensor b, Tensor sfb, Tensor d, int[] head_splits, int[]? recipe, str compiled_dims, bool disable_ue8m0_cast) -> ()");
  ops.impl("fp8_gemm_nt_skip_head_mid", torch::kCUDA, &dg_fp8_gemm_nt_skip_head_mid);

  ops.def("fp8_mqa_logits(Tensor q, Tensor kv, Tensor kv_sf, Tensor weights, Tensor cu_seq_len_k_start, Tensor cu_seq_len_k_end, bool clean_logits, int max_seqlen_k) -> Tensor");
  ops.impl("fp8_mqa_logits", torch::kCUDA, &dg_fp8_mqa_logits);

  // Produces the schedule metadata consumed by fp8_paged_mqa_logits below
  // (passed there as `schedule_meta`).
  ops.def("get_paged_mqa_logits_metadata(Tensor context_lens, int block_kv, int num_sms) -> Tensor");
  ops.impl("get_paged_mqa_logits_metadata", torch::kCUDA, &dg_get_paged_mqa_logits_metadata);

  ops.def("fp8_paged_mqa_logits(Tensor q, Tensor fused_kv_cache, Tensor weights, Tensor context_lens, Tensor block_table, Tensor schedule_meta, int max_context_len, bool clean_logits) -> Tensor");
  ops.impl("fp8_paged_mqa_logits", torch::kCUDA, &dg_fp8_paged_mqa_logits);

  // ---- Hyperconnection ----
  ops.def("tf32_hc_prenorm_gemm(Tensor a, Tensor b, Tensor d, Tensor sqr_sum, int? num_splits) -> ()");
  ops.impl("tf32_hc_prenorm_gemm", torch::kCUDA, &dg_tf32_hc_prenorm_gemm);

  // ---- Layout ----
  // Scaling-factor layout helpers: convert/align SF tensors into the layout
  // the GEMM kernels require (TMA alignment, packed ue8m0, ...). The two
  // pure size/alignment queries take no tensors and are registered without a
  // dispatch key.
  ops.def("transform_sf_into_required_layout(Tensor sf, int mn, int k, int[]? recipe, int[]? recipe_ab, int? num_groups, bool is_sfa, bool disable_ue8m0_cast) -> Tensor");
  ops.impl("transform_sf_into_required_layout", torch::kCUDA, &dg_transform_sf_into_required_layout);

  ops.def("get_tma_aligned_size(int x, int element_size) -> int");
  ops.impl("get_tma_aligned_size", &dg_get_tma_aligned_size);

  ops.def("get_mn_major_tma_aligned_tensor(Tensor sf) -> Tensor");
  ops.impl("get_mn_major_tma_aligned_tensor", torch::kCUDA, &dg_get_mn_major_tma_aligned_tensor);

  ops.def("get_mn_major_tma_aligned_packed_ue8m0_tensor(Tensor sf) -> Tensor");
  ops.impl("get_mn_major_tma_aligned_packed_ue8m0_tensor", torch::kCUDA, &dg_get_mn_major_tma_aligned_packed_ue8m0_tensor);

  ops.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(Tensor sf, Tensor ks_tensor, int[] ks) -> Tensor");
  ops.impl("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", torch::kCUDA, &dg_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);

  ops.def("get_mk_alignment_for_contiguous_layout() -> int");
  ops.impl("get_mk_alignment_for_contiguous_layout", &dg_get_mk_alignment_for_contiguous_layout_wrap);
}

// Emits the Python-extension module boilerplate for this library —
// presumably the PyInit_<TORCH_EXTENSION_NAME> entry point so the compiled
// .so is importable; confirm against the macro definition in registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)